From a9eb075d1ed20c72326f89dd6db169a261a6b75b Mon Sep 17 00:00:00 2001 From: Clemens Winter Date: Sat, 8 Jun 2024 13:26:42 -0700 Subject: [PATCH] Add support for bearer token --- src/bin/db_bench.rs | 1 + src/bin/load_generator.rs | 1 + src/bin/log.rs | 1 + src/logging_client/mod.rs | 32 +++++++++++++++++++++++++------- src/python.rs | 4 +++- tests/ingestion_test.rs | 5 +++++ 6 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/bin/db_bench.rs b/src/bin/db_bench.rs index be5e95b1..99dd3abc 100644 --- a/src/bin/db_bench.rs +++ b/src/bin/db_bench.rs @@ -189,6 +189,7 @@ fn ingest(opts: &Opts, db: &LocustDB, small_tables: &[String]) -> u64 { addr, 64 * (1 << 20), BufferFullPolicy::Block, + None, ); let mut rng = rand::rngs::SmallRng::seed_from_u64(0); if !opts.large_only { diff --git a/src/bin/load_generator.rs b/src/bin/load_generator.rs index 05847d9e..d2e9fa5a 100644 --- a/src/bin/load_generator.rs +++ b/src/bin/load_generator.rs @@ -57,6 +57,7 @@ async fn main() { &addr, 1 << 28, BufferFullPolicy::Block, + None, ); let mut interval = time::interval(Duration::from_millis(interval)); diff --git a/src/bin/log.rs b/src/bin/log.rs index 65f4f458..ea695165 100644 --- a/src/bin/log.rs +++ b/src/bin/log.rs @@ -44,6 +44,7 @@ async fn main() { &addr, 1 << 50, BufferFullPolicy::Block, + None, ); let mut rng = rand::thread_rng(); let mut random_walks = (0..5) diff --git a/src/logging_client/mod.rs b/src/logging_client/mod.rs index 10970268..607952a3 100644 --- a/src/logging_client/mod.rs +++ b/src/logging_client/mod.rs @@ -3,14 +3,16 @@ use std::sync::atomic::AtomicU64; use std::sync::{Arc, Condvar, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use locustdb_serialization::api::{AnyVal, Column, ColumnNameRequest, ColumnNameResponse, EncodingOpts, MultiQueryRequest, MultiQueryResponse, QueryResponse}; +use locustdb_serialization::api::{ + AnyVal, Column, ColumnNameRequest, ColumnNameResponse, EncodingOpts, MultiQueryRequest, + MultiQueryResponse, 
QueryResponse, +}; use locustdb_serialization::event_buffer::EventBuffer; -use reqwest::header::CONTENT_TYPE; +use reqwest::header::{self, CONTENT_TYPE}; use tokio::select; use tokio::time::{self, MissedTickBehavior}; use tokio_util::sync::CancellationToken; - pub struct LoggingClient { // Table -> Rows events: Arc<Mutex<EventBuffer>>, @@ -24,6 +26,7 @@ pub struct LoggingClient { query_url: String, columns_url: String, buffer_full_policy: BufferFullPolicy, + bearer_token: Option<String>, } struct BackgroundWorker { @@ -55,6 +58,7 @@ impl LoggingClient { locustdb_url: &str, max_buffer_size_bytes: usize, buffer_full_policy: BufferFullPolicy, + bearer_token: Option<String>, ) -> LoggingClient { let buffer: Arc<Mutex<EventBuffer>> = Arc::default(); let shutdown = CancellationToken::new(); @@ -84,6 +88,7 @@ impl LoggingClient { query_url: format!("{locustdb_url}/multi_query_cols"), columns_url: format!("{locustdb_url}/columns"), buffer_full_policy, + bearer_token, } } @@ -99,7 +104,7 @@ impl LoggingClient { let response = self .query_client .post(&self.query_url) - .header(CONTENT_TYPE, "application/json") + .headers(self.headers()) .json(&request_body) .send() .await?; @@ -113,7 +118,9 @@ impl LoggingClient { rsps.iter_mut().for_each(|rsp| { rsp.columns.iter_mut().for_each(|(_, col)| { if let Column::Xor(data) = col { - *col = Column::Float(locustdb_compression_utils::xor_float::double::decode(&data[..]).unwrap()) + *col = Column::Float( + locustdb_compression_utils::xor_float::double::decode(&data[..]).unwrap(), + ) } }); }); @@ -186,7 +193,7 @@ impl LoggingClient { let response = self .query_client .post(&self.columns_url) - .header(CONTENT_TYPE, "application/json") + .headers(self.headers()) .json(&request_body) .send() .await?; @@ -205,6 +212,15 @@ impl LoggingClient { tokio::time::sleep(self.flush_interval).await; } } + + fn headers(&self) -> header::HeaderMap { + let mut headers = header::HeaderMap::new(); + headers.insert(CONTENT_TYPE, "application/json".parse().unwrap()); + if let Some(bearer_token) = 
self.bearer_token.as_ref() { + headers.insert("Authorization", bearer_token.parse().unwrap()); + } + headers + } } impl BackgroundWorker { @@ -219,7 +235,9 @@ impl BackgroundWorker { } loop { self.flush().await; - if self.request_data.lock().unwrap().is_none() && self.events.lock().unwrap().tables.is_empty() { + if self.request_data.lock().unwrap().is_none() + && self.events.lock().unwrap().tables.is_empty() + { break; } } diff --git a/src/python.rs b/src/python.rs index cdf80b22..0c754542 100644 --- a/src/python.rs +++ b/src/python.rs @@ -30,12 +30,13 @@ fn locustdb(m: &Bound<'_, PyModule>) -> PyResult<()> { #[pymethods] impl Client { #[new] - #[pyo3(signature = (url, max_buffer_size_bytes = 128 * (1 << 20), block_when_buffer_full = false, flush_interval_seconds = 1))] + #[pyo3(signature = (url, max_buffer_size_bytes = 128 * (1 << 20), block_when_buffer_full = false, flush_interval_seconds = 1, bearer_token = None))] fn new( url: &str, max_buffer_size_bytes: usize, block_when_buffer_full: bool, flush_interval_seconds: u64, + bearer_token: Option<String>, ) -> Self { let _guard = RT.enter(); Self { @@ -48,6 +49,7 @@ impl Client { } else { BufferFullPolicy::Drop }, + bearer_token, ), } } diff --git a/tests/ingestion_test.rs b/tests/ingestion_test.rs index e5e3c016..24494a69 100644 --- a/tests/ingestion_test.rs +++ b/tests/ingestion_test.rs @@ -165,6 +165,7 @@ fn ingest(offset: usize, rows: usize, random_cols: usize, tables: &[String]) { addr, 64 * (1 << 20), BufferFullPolicy::Block, + None, ); let mut rng = rand::rngs::SmallRng::seed_from_u64(0); for row in 0..rows { @@ -222,6 +223,7 @@ async fn test_ingest_sparse_nullable() { // Set max buffer size to 0 to ensure we ingest one row at a time 0, BufferFullPolicy::Block, + None, ); let mut rng = rand::rngs::SmallRng::seed_from_u64(0); let mut vals = vec![]; @@ -284,6 +286,7 @@ async fn test_persist_meta_tables() { &addr, 0, BufferFullPolicy::Block, + None, ); log.log("qwerty", [("value".to_string(), vf64(1.0))]); 
log.log("asdf", [("value".to_string(), vf64(1.0))]); @@ -340,6 +343,7 @@ async fn test_many_concurrent_requests() { &addr, 1 << 20, BufferFullPolicy::Block, + None, ); for i in 0..value_count { log.log(&table, [("value".to_string(), vf64(i))]); @@ -365,6 +369,7 @@ async fn test_many_concurrent_requests() { &addr, 0, BufferFullPolicy::Block, + None, ); let query = format!("SELECT SUM(value) AS total FROM table_{:02}", tid); let mut last_log_time = Instant::now();