From 05472dfbd0a599b52b2af82bd58a202770993e65 Mon Sep 17 00:00:00 2001 From: Encephala Date: Mon, 10 Jun 2024 16:59:16 +0200 Subject: [PATCH] Playing with tokio/futures/async --- Cargo.toml | 11 ++-- cli/Cargo.toml | 14 +++++ cli/src/main.rs | 114 ++++++++++++++++++++++++++++++++++ dbms/Cargo.toml | 2 +- dbms/src/lib.rs | 5 ++ dbms/src/server/mod.rs | 22 +++++++ sql-parse/Cargo.toml | 2 +- src/main.rs | 138 +++++++++-------------------------------- 8 files changed, 193 insertions(+), 115 deletions(-) create mode 100644 cli/Cargo.toml create mode 100644 cli/src/main.rs create mode 100644 dbms/src/server/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 6704c7b..aae954a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,13 +2,16 @@ name = "rusty-db" version = "0.1.0" edition = "2021" -license = "MIT" description = "A non-production-ready-and-will-never-be database written for my own learning purposes" +[workspace.package] +license = "MIT" + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] tokio = { version = "*", features = ["full"]} +futures = "*" sql-parse.workspace = true dbms.workspace = true @@ -17,9 +20,5 @@ dbms.workspace = true sql-parse = { path = "./sql-parse", version = "0.1.0" } dbms = { path = "./dbms", version = "0.1.0" } - [workspace] -members = [ - "sql-parse", - "dbms" -] +members = ["cli"] diff --git a/cli/Cargo.toml b/cli/Cargo.toml new file mode 100644 index 0000000..f9a2beb --- /dev/null +++ b/cli/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "cli" +version = "0.1.0" +edition = "2021" +license.workspace = true +description = "CLI client for the database" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "*", features = ["full"]} + +sql-parse.workspace = true +dbms.workspace = true diff --git a/cli/src/main.rs b/cli/src/main.rs new file mode 100644 index 0000000..1abca12 --- /dev/null +++ b/cli/src/main.rs @@ -0,0 +1,114 @@ +#![allow(clippy::needless_return)] + +use std::io::Write; +use std::path::PathBuf; + +use sql_parse::{Lexer, parse_statement, Statement, CreateType}; +use dbms::{Execute, Database, DatabaseName, ExecutionResult, PersistenceManager, FileSystem, SerialisationManager, Serialiser}; + +async fn repl() { + let stdin = std::io::stdin(); + let mut stdout = std::io::stdout(); + + let mut database: Option = None; + + let persistence_manager: Box<_> = FileSystem::new( + SerialisationManager::new(Serialiser::V2), + PathBuf::from("/tmp/rusty-db"), + ).into(); + + loop { + print!(">> "); + + stdout.flush().unwrap(); + + let mut input = String::new(); + + stdin.read_line(&mut input).unwrap(); + + if input == "\\q\n" { + break; + } else if input.is_empty() { + println!(); + break; + } + + // TODO: Standardise handling these special commands + if input.starts_with("\\c ") { + let database_name = input.strip_prefix("\\c ").unwrap().strip_suffix('\n').unwrap(); + + database = match persistence_manager.load_database(DatabaseName(database_name.into())).await { + Ok(db) => { + println!("Connected to database {}", db.name.0); + + Some(db) + }, + Err(error) => { + println!("Got execution error: {error:?}"); + + None + }, + }; + + continue; + } + + if input.starts_with("\\l ") { + let tokens = Lexer::lex(input.strip_prefix("\\l ").unwrap()); + + println!("Lexed: {tokens:?}"); + + continue; + } + + let statement = parse_statement(&input); + + if input.starts_with("\\p ") { + let statement = parse_statement(input.strip_prefix("\\p ").unwrap()); + + println!("Parsed: {statement:?}"); + + continue; + } + + if let Some(statement) = statement { + let is_create_database = matches!(statement, Statement::Create { what: CreateType::Database, .. }); + let is_drop_database = matches!(statement, Statement::Drop { what: CreateType::Database, .. }); + + let result = statement.execute(database.as_mut(), persistence_manager.as_ref()).await; + + match result { + Ok(result) => { + match result { + ExecutionResult::None => (), + an_actual_result => println!("Executed:\n{an_actual_result:?}"), + } + }, + Err(error) => { + println!("Got execution error: {error:?}"); + + // Don't persist storage if statement failed + continue; + } + } + + if is_create_database || is_drop_database { + continue; + } + + // TODO: doing this properly, should only write changed things + // Also I can probably do better than the `is_drop_database` above + match persistence_manager.save_database(database.as_ref().unwrap()).await { + Ok(_) => (), + Err(error) => println!("Failed saving to disk: {error:?}"), + } + } else { + println!("Failed to parse: {input}"); + } + } +} + +#[tokio::main] +async fn main() { + repl().await; +} diff --git a/dbms/Cargo.toml b/dbms/Cargo.toml index e2ff99c..ee03869 100644 --- a/dbms/Cargo.toml +++ b/dbms/Cargo.toml @@ -2,7 +2,7 @@ name = "dbms" version = "0.1.0" edition = "2021" -license = "MIT" +license.workspace = true description = "Database management system" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/dbms/src/lib.rs b/dbms/src/lib.rs index f7e37b6..4ea962c 100644 --- a/dbms/src/lib.rs +++ b/dbms/src/lib.rs @@ -6,6 +6,7 @@ mod types; mod evaluate; mod utils; mod persistence; +mod server; use types::{ColumnName, ColumnValue, TableName}; use sql_parse::{ColumnType, Expression, InfixOperator}; @@ -15,6 +16,7 @@ pub use database::Database; pub use types::DatabaseName; pub use evaluate::{Execute, ExecutionResult}; pub use persistence::{PersistenceManager, FileSystem, SerialisationManager, Serialiser}; +pub use server::handle_connection; @@ -49,6 +51,9 @@ pub enum SqlError { NotABoolean(u8), IncompatibleVersion(u8), + + CouldNotWriteToConnection(std::io::Error), + CouldNotReadFromConnection(std::io::Error), } pub type Result = std::result::Result; diff --git a/dbms/src/server/mod.rs b/dbms/src/server/mod.rs new file mode 100644 index 0000000..8ec07fc --- /dev/null +++ b/dbms/src/server/mod.rs @@ -0,0 +1,22 @@ +use std::{io::{Read, Write}, net::TcpStream}; + +use crate::{Result, SqlError}; +// use sql_parse::parse_statement; + +pub async fn handle_connection(mut stream: TcpStream) -> Result<()> { + write_welcome(&mut stream)?; + + let buf = &mut vec![]; + stream.read_to_end(buf) + .map_err(SqlError::CouldNotReadFromConnection)?; + + // Handle message + println!("Got message {}", std::str::from_utf8(buf).unwrap()); + + return Ok(()); +} + +fn write_welcome(stream: &mut TcpStream) -> Result<()> { + return stream.write_all(&[0x48, 0x45, 0x4C, 0x4C, 0x4F]) + .map_err(SqlError::CouldNotWriteToConnection); +} diff --git a/sql-parse/Cargo.toml b/sql-parse/Cargo.toml index f08f2dd..eaeb5ba 100644 --- a/sql-parse/Cargo.toml +++ b/sql-parse/Cargo.toml @@ -2,7 +2,7 @@ name = "sql-parse" version = "0.1.0" edition = "2021" -license = "MIT" +license.workspace = true description = "SQL parser" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/main.rs b/src/main.rs index 3b45cbd..3890884 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,114 +1,38 @@ -#![allow(clippy::needless_return)] +use std::net::TcpListener; -use std::io::Write; -use std::path::PathBuf; +use tokio::{spawn, task::JoinHandle}; +use futures::future::join_all; -use sql_parse::{Lexer, parse_statement, Statement, CreateType}; -use dbms::{Execute, Database, DatabaseName, ExecutionResult, PersistenceManager, FileSystem, SerialisationManager, Serialiser}; +use dbms::handle_connection; -async fn repl() { - let stdin = std::io::stdin(); - let mut stdout = std::io::stdout(); - - let mut database: Option = None; - - let persistence_manager: Box<_> = FileSystem::new( - SerialisationManager::new(Serialiser::V2), - PathBuf::from("/tmp/rusty-db"), - ).into(); - - loop { - print!(">> "); - - stdout.flush().unwrap(); - - let mut input = String::new(); - - stdin.read_line(&mut input).unwrap(); - - if input == "\\q\n" { - break; - } else if input.is_empty() { - println!(); - break; - } - - // TODO: Standardise handling these special commands - if input.starts_with("\\l ") { - let tokens = Lexer::lex(input.strip_prefix("\\l ").unwrap()); - - println!("Lexed: {tokens:?}"); - - continue; - } - - let statement = parse_statement(&input); - - if input.starts_with("\\p ") { - let statement = parse_statement(input.strip_prefix("\\p ").unwrap()); - - println!("Parsed: {statement:?}"); - - continue; - } - - if input.starts_with("\\c ") { - let database_name = input.strip_prefix("\\c ").unwrap().strip_suffix('\n').unwrap(); - - database = match persistence_manager.load_database(DatabaseName(database_name.into())).await { - Ok(db) => { - println!("Connected to database {}", db.name.0); - - Some(db) - }, - Err(error) => { - println!("Got execution error: {error:?}"); - - None - }, - }; - - continue; - } - - if let Some(statement) = statement { - let is_create_database = matches!(statement, Statement::Create { what: CreateType::Database, .. }); - let is_drop_database = matches!(statement, Statement::Drop { what: CreateType::Database, .. }); - - let result = statement.execute(database.as_mut(), persistence_manager.as_ref()).await; - - match result { - Ok(result) => { - match result { - ExecutionResult::None => (), - an_actual_result => println!("Executed:\n{an_actual_result:?}"), - } - }, - Err(error) => { - println!("Got execution error: {error:?}"); - - // Don't persist storage if statement failed - continue; - } - } - - if is_create_database || is_drop_database { - continue; - } - - // TODO: doing this properly, should only write changed things - // Also I can probably do better than the `is_drop_database` above - match persistence_manager.save_database(database.as_ref().unwrap()).await { - Ok(_) => (), - Err(error) => println!("Failed saving to disk: {error:?}"), - } - } else { - println!("Failed to parse: {input}"); +#[tokio::main] +async fn main() { + let listener = TcpListener::bind("localhost:42069").unwrap(); + println!("Listening on localhost:42069 (of course)"); + + let mut join_handles = vec![]; + + for stream in listener.incoming() { + join_handles.retain(|handle: &JoinHandle<_>| { + !handle.is_finished() + }); + + match stream { + Ok(stream) => { + println!("New connection established from {}", stream.peer_addr().unwrap()); + println!("Now have {} connections", join_handles.len() + 1); + + join_handles.push(spawn(async move { + handle_connection(stream).await + })); + }, + Err(error) => panic!("{error}"), } } -} -#[tokio::main] -async fn main() { - repl().await; + join_all(join_handles).await + .into_iter() + .collect::, _>, _>>().unwrap().unwrap(); + + println!("Main thread exiting"); }