Skip to content

Commit

Permalink
Add serialiser version negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
Encephala committed Jun 18, 2024
1 parent 9983fed commit 9790c58
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 20 deletions.
28 changes: 22 additions & 6 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod serverless;

use std::io::Write;

use tokio::{io::BufReader, net::{
use tokio::{io::{AsyncReadExt, AsyncWriteExt, BufReader}, net::{
TcpStream,
ToSocketAddrs,
}};
Expand All @@ -12,6 +12,7 @@ use dbms::{
SqlError,
serialisation::{SerialisationManager, Serialiser},
server::Message,
utils::serialiser_version_to_serialiser,
};

const SERIALISATION_MANAGER: SerialisationManager = SerialisationManager::new(Serialiser::V2);
Expand All @@ -23,12 +24,27 @@ async fn session(address: impl ToSocketAddrs) -> Result<(), SqlError> {

let mut reader = BufReader::new(reader);

// let welcome_message = Message::read(&mut reader).await?;
// TODO: How do I do this?
let number_of_serialisers = stream.read_u8().await
.map_err(SqlError::CouldNotReadFromConnection)?;

// println!(
// "Got welcome message: {}",
// String::from_utf8(welcome_message.0).unwrap()
// );
let mut serialisers_buffer = vec![0_u8; number_of_serialisers as usize];

stream.read_exact(&mut serialisers_buffer).await
.map_err(SqlError::CouldNotReadFromConnection)?;

if serialisers_buffer.is_empty() {
return Err(SqlError::InputTooShort(0, 1));
}

let highest_serialiser = serialisers_buffer.iter().max().unwrap();

let serialiser = serialiser_version_to_serialiser(*highest_serialiser)?;

println!("Chose serialiser {serialiser:?}");

stream.write_u8(*highest_serialiser).await
.map_err(SqlError::CouldNotWriteToConnection)?;

todo!();

Expand Down
2 changes: 1 addition & 1 deletion dbms/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
mod database;
pub mod types;
pub mod evaluate;
mod utils;
pub mod utils;
pub mod persistence;
pub mod serialisation;
pub mod server;
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/persistence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::serialisation::SerialisationManager;

// Love me some premature abstractions
#[async_trait]
pub trait PersistenceManager: Send + Sync {
pub trait PersistenceManager: std::fmt::Debug + Send + Sync {
async fn save_database(&self, database: &Database) -> Result<()>;
async fn delete_database(&self, name: DatabaseName) -> Result<DatabaseName>;

Expand Down
40 changes: 30 additions & 10 deletions dbms/src/server/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@ mod tests;
use std::{net::SocketAddr, path::PathBuf};

use tokio::{
net::TcpStream, sync::broadcast::Receiver
io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, sync::broadcast::Receiver
};

use crate::{
evaluate::{
Execute,
ExecutionResult
},
persistence::{
}, persistence::{
FileSystem,
PersistenceManager
},
serialisation::{
}, serialisation::{
SerialisationManager,
Serialiser
},
types::DatabaseName, Database, Result, SqlError
}, types::DatabaseName, utils::serialiser_version_to_serialiser, Database, Result, SqlError
};

use sql_parse::{parse_statement, parser::{CreateType, Statement}};

use super::protocol::{Message, MessageBody};

#[derive(Debug)]
struct Runtime {
persistence_manager: Box<dyn PersistenceManager>,
database: Option<Database>,
Expand All @@ -38,6 +36,7 @@ pub struct Connection {
context: Context,
}

#[derive(Debug)]
pub struct Context {
peer_address: SocketAddr,
serialiser: Serialiser,
Expand All @@ -46,7 +45,6 @@ pub struct Context {

impl Connection {
pub async fn new(mut stream: TcpStream, shutdown_receiver: Receiver<()>) -> Result<Self> {
// TODO: Negotiate connection parameters
let context = Connection::setup_context(&mut stream).await?;

return Ok(Connection {
Expand All @@ -60,11 +58,12 @@ impl Connection {
///
/// Returns a [`Context`] object populated with these parameters
/// as well as other (default) parameters.
// I don't quite like this function name
async fn setup_context(stream: &mut TcpStream) -> Result<Context> {
let peer_address = stream.peer_addr()
.map_err(SqlError::CouldNotReadFromConnection)?;

let serialiser = todo!();
let serialiser = Connection::negotiate_serialiser_version(stream).await?;

let runtime = Runtime {
persistence_manager: Box::new(FileSystem::new(
Expand All @@ -81,8 +80,29 @@ impl Connection {
});
}

async fn negotiate_serialiser_version(stream: &mut TcpStream) -> Result<Serialiser> {
let available_serialiser_versions = [1, 2];

stream.write_all((available_serialiser_versions.len() as u8).to_le_bytes().as_slice()).await
.map_err(SqlError::CouldNotReadFromConnection)?;

stream.write_all(available_serialiser_versions.as_slice()).await
.map_err(SqlError::CouldNotReadFromConnection)?;


let mut serialiser_version_buffer = [0_u8];

stream.read_exact(&mut serialiser_version_buffer).await
.map_err(SqlError::CouldNotReadFromConnection)?;

let [decided_version] = serialiser_version_buffer;

let serialiser = serialiser_version_to_serialiser(decided_version)?;

return Ok(serialiser);
}

pub async fn handle(mut self) -> Result<()> {
println!("Handling connection in {:?}", std::thread::current());

loop {
tokio::select! {
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ fn spawn_new_handler(
println!("New connection established from {address:?}");

return spawn(async move {
println!("Setting up new connection in {:?}", std::thread::current());

let connection = Connection::new(stream, shutdown_receiver).await?;

connection.handle().await
Expand Down
13 changes: 13 additions & 0 deletions dbms/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
use crate::{
SqlError, Result,
serialisation::Serialiser
};

pub fn serialiser_version_to_serialiser(version: u8) -> Result<Serialiser>{
return match version {
1 => Ok(Serialiser::V1),
2 => Ok(Serialiser::V2),
other => Err(SqlError::IncompatibleVersion(other)),
};
}

#[cfg(test)]
pub mod tests {
use sql_parse::parser::ColumnType;
Expand Down

0 comments on commit 9790c58

Please sign in to comment.