From 9790c5844f6cec3f2487879426d6e47ebe243117 Mon Sep 17 00:00:00 2001 From: Encephala Date: Tue, 18 Jun 2024 22:44:59 +0200 Subject: [PATCH] Add serialiser version negotiation --- cli/src/main.rs | 28 +++++++++++++++++----- dbms/src/lib.rs | 2 +- dbms/src/persistence.rs | 2 +- dbms/src/server/connection/mod.rs | 40 +++++++++++++++++++++++-------- dbms/src/server/mod.rs | 2 -- dbms/src/utils.rs | 13 ++++++++++ 6 files changed, 67 insertions(+), 20 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index e62cead..c1954be 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -3,7 +3,7 @@ mod serverless; use std::io::Write; -use tokio::{io::BufReader, net::{ +use tokio::{io::{AsyncReadExt, AsyncWriteExt, BufReader}, net::{ TcpStream, ToSocketAddrs, }}; @@ -12,6 +12,7 @@ use dbms::{ SqlError, serialisation::{SerialisationManager, Serialiser}, server::Message, + utils::serialiser_version_to_serialiser, }; const SERIALISATION_MANAGER: SerialisationManager = SerialisationManager::new(Serialiser::V2); @@ -23,12 +24,27 @@ async fn session(address: impl ToSocketAddrs) -> Result<(), SqlError> { let mut reader = BufReader::new(reader); - // let welcome_message = Message::read(&mut reader).await?; + // TODO: How do I do this? + let number_of_serialisers = stream.read_u8().await + .map_err(SqlError::CouldNotReadFromConnection)?; - // println!( - // "Got welcome message: {}", - // String::from_utf8(welcome_message.0).unwrap() - // ); + let mut serialisers_buffer = vec![0_u8; number_of_serialisers as usize]; + + stream.read_exact(&mut serialisers_buffer).await + .map_err(SqlError::CouldNotReadFromConnection)?; + + if serialisers_buffer.is_empty() { + return Err(SqlError::InputTooShort(0, 1)); + } + + let highest_serialiser = serialisers_buffer.iter().max().unwrap(); + + let serialiser = serialiser_version_to_serialiser(*highest_serialiser)?; + + println!("Chose serialiser {serialiser:?}"); + + stream.write_u8(*highest_serialiser).await + .map_err(SqlError::CouldNotWriteToConnection)?; todo!(); diff --git a/dbms/src/lib.rs b/dbms/src/lib.rs index 6877b81..3206444 100644 --- a/dbms/src/lib.rs +++ b/dbms/src/lib.rs @@ -4,7 +4,7 @@ mod database; pub mod types; pub mod evaluate; -mod utils; +pub mod utils; pub mod persistence; pub mod serialisation; pub mod server; diff --git a/dbms/src/persistence.rs b/dbms/src/persistence.rs index 00fb63b..528e2f6 100644 --- a/dbms/src/persistence.rs +++ b/dbms/src/persistence.rs @@ -13,7 +13,7 @@ use super::serialisation::SerialisationManager; // Love me some premature abstractions #[async_trait] -pub trait PersistenceManager: Send + Sync { +pub trait PersistenceManager: std::fmt::Debug + Send + Sync { async fn save_database(&self, database: &Database) -> Result<()>; async fn delete_database(&self, name: DatabaseName) -> Result; diff --git a/dbms/src/server/connection/mod.rs b/dbms/src/server/connection/mod.rs index 71ce91d..81bf459 100644 --- a/dbms/src/server/connection/mod.rs +++ b/dbms/src/server/connection/mod.rs @@ -4,29 +4,27 @@ mod tests; use std::{net::SocketAddr, path::PathBuf}; use tokio::{ - net::TcpStream, sync::broadcast::Receiver + io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, sync::broadcast::Receiver }; use crate::{ evaluate::{ Execute, ExecutionResult - }, - persistence::{ + }, persistence::{ FileSystem, PersistenceManager - }, - serialisation::{ + }, serialisation::{ SerialisationManager, Serialiser - }, - types::DatabaseName, Database, Result, SqlError + }, types::DatabaseName, utils::serialiser_version_to_serialiser, Database, Result, SqlError }; use sql_parse::{parse_statement, parser::{CreateType, Statement}}; use super::protocol::{Message, MessageBody}; +#[derive(Debug)] struct Runtime { persistence_manager: Box, database: Option, @@ -38,6 +36,7 @@ pub struct Connection { context: Context, } +#[derive(Debug)] pub struct Context { peer_address: SocketAddr, serialiser: Serialiser, @@ -46,7 +45,6 @@ pub struct Context { impl Connection { pub async fn new(mut stream: TcpStream, shutdown_receiver: Receiver<()>) -> Result { - // TODO: Negotiate connection parameters let context = Connection::setup_context(&mut stream).await?; return Ok(Connection { @@ -60,11 +58,12 @@ impl Connection { /// /// Returns a [`Context`] object populated with these parameters /// as well as other (default) parameters. + // I don't quite like this function name async fn setup_context(stream: &mut TcpStream) -> Result { let peer_address = stream.peer_addr() .map_err(SqlError::CouldNotReadFromConnection)?; - let serialiser = todo!(); + let serialiser = Connection::negotiate_serialiser_version(stream).await?; let runtime = Runtime { persistence_manager: Box::new(FileSystem::new( @@ -81,8 +80,29 @@ impl Connection { }); } + async fn negotiate_serialiser_version(stream: &mut TcpStream) -> Result { + let available_serialiser_versions = [1, 2]; + + stream.write_all((available_serialiser_versions.len() as u8).to_le_bytes().as_slice()).await + .map_err(SqlError::CouldNotReadFromConnection)?; + + stream.write_all(available_serialiser_versions.as_slice()).await + .map_err(SqlError::CouldNotReadFromConnection)?; + + + let mut serialiser_version_buffer = [0_u8]; + + stream.read_exact(&mut serialiser_version_buffer).await + .map_err(SqlError::CouldNotReadFromConnection)?; + + let [decided_version] = serialiser_version_buffer; + + let serialiser = serialiser_version_to_serialiser(decided_version)?; + + return Ok(serialiser); + } + pub async fn handle(mut self) -> Result<()> { - println!("Handling connection in {:?}", std::thread::current()); loop { tokio::select! { diff --git a/dbms/src/server/mod.rs b/dbms/src/server/mod.rs index f8e701d..0cec7a2 100644 --- a/dbms/src/server/mod.rs +++ b/dbms/src/server/mod.rs @@ -100,8 +100,6 @@ fn spawn_new_handler( println!("New connection established from {address:?}"); return spawn(async move { - println!("Setting up new connection in {:?}", std::thread::current()); - let connection = Connection::new(stream, shutdown_receiver).await?; connection.handle().await diff --git a/dbms/src/utils.rs b/dbms/src/utils.rs index 5911633..78f231c 100644 --- a/dbms/src/utils.rs +++ b/dbms/src/utils.rs @@ -1,3 +1,16 @@ +use crate::{ + SqlError, Result, + serialisation::Serialiser +}; + +pub fn serialiser_version_to_serialiser(version: u8) -> Result{ + return match version { + 1 => Ok(Serialiser::V1), + 2 => Ok(Serialiser::V2), + other => Err(SqlError::IncompatibleVersion(other)), + }; +} + #[cfg(test)] pub mod tests { use sql_parse::parser::ColumnType;