Skip to content

Commit

Permalink
Add support for bearer token
Browse files Browse the repository at this point in the history
  • Loading branch information
cswinter committed Jun 8, 2024
1 parent 1052869 commit a9eb075
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/bin/db_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ fn ingest(opts: &Opts, db: &LocustDB, small_tables: &[String]) -> u64 {
addr,
64 * (1 << 20),
BufferFullPolicy::Block,
None,
);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
if !opts.large_only {
Expand Down
1 change: 1 addition & 0 deletions src/bin/load_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ async fn main() {
&addr,
1 << 28,
BufferFullPolicy::Block,
None,
);
let mut interval = time::interval(Duration::from_millis(interval));

Expand Down
1 change: 1 addition & 0 deletions src/bin/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async fn main() {
&addr,
1 << 50,
BufferFullPolicy::Block,
None,
);
let mut rng = rand::thread_rng();
let mut random_walks = (0..5)
Expand Down
32 changes: 25 additions & 7 deletions src/logging_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use locustdb_serialization::api::{AnyVal, Column, ColumnNameRequest, ColumnNameResponse, EncodingOpts, MultiQueryRequest, MultiQueryResponse, QueryResponse};
use locustdb_serialization::api::{
AnyVal, Column, ColumnNameRequest, ColumnNameResponse, EncodingOpts, MultiQueryRequest,
MultiQueryResponse, QueryResponse,
};
use locustdb_serialization::event_buffer::EventBuffer;
use reqwest::header::CONTENT_TYPE;
use reqwest::header::{self, CONTENT_TYPE};
use tokio::select;
use tokio::time::{self, MissedTickBehavior};
use tokio_util::sync::CancellationToken;


pub struct LoggingClient {
// Table -> Rows
events: Arc<Mutex<EventBuffer>>,
Expand All @@ -24,6 +26,7 @@ pub struct LoggingClient {
query_url: String,
columns_url: String,
buffer_full_policy: BufferFullPolicy,
bearer_token: Option<String>,
}

struct BackgroundWorker {
Expand Down Expand Up @@ -55,6 +58,7 @@ impl LoggingClient {
locustdb_url: &str,
max_buffer_size_bytes: usize,
buffer_full_policy: BufferFullPolicy,
bearer_token: Option<String>,
) -> LoggingClient {
let buffer: Arc<Mutex<EventBuffer>> = Arc::default();
let shutdown = CancellationToken::new();
Expand Down Expand Up @@ -84,6 +88,7 @@ impl LoggingClient {
query_url: format!("{locustdb_url}/multi_query_cols"),
columns_url: format!("{locustdb_url}/columns"),
buffer_full_policy,
bearer_token,
}
}

Expand All @@ -99,7 +104,7 @@ impl LoggingClient {
let response = self
.query_client
.post(&self.query_url)
.header(CONTENT_TYPE, "application/json")
.headers(self.headers())
.json(&request_body)
.send()
.await?;
Expand All @@ -113,7 +118,9 @@ impl LoggingClient {
rsps.iter_mut().for_each(|rsp| {
rsp.columns.iter_mut().for_each(|(_, col)| {
if let Column::Xor(data) = col {
*col = Column::Float(locustdb_compression_utils::xor_float::double::decode(&data[..]).unwrap())
*col = Column::Float(
locustdb_compression_utils::xor_float::double::decode(&data[..]).unwrap(),
)
}
});
});
Expand Down Expand Up @@ -186,7 +193,7 @@ impl LoggingClient {
let response = self
.query_client
.post(&self.columns_url)
.header(CONTENT_TYPE, "application/json")
.headers(self.headers())
.json(&request_body)
.send()
.await?;
Expand All @@ -205,6 +212,15 @@ impl LoggingClient {
tokio::time::sleep(self.flush_interval).await;
}
}

/// Builds the common header set for requests to the LocustDB server:
/// a JSON `Content-Type` plus, when a token was configured at construction,
/// an `Authorization` header carrying it.
///
/// # Panics
/// Panics if the configured token contains bytes that are not valid in an
/// HTTP header value (e.g. control characters).
fn headers(&self) -> header::HeaderMap {
    let mut headers = header::HeaderMap::new();
    // `from_static` is infallible for this constant, unlike `parse().unwrap()`.
    headers.insert(
        CONTENT_TYPE,
        header::HeaderValue::from_static("application/json"),
    );
    if let Some(bearer_token) = self.bearer_token.as_ref() {
        // NOTE(review): the token is inserted verbatim — despite the field
        // name, no "Bearer " scheme prefix is added here, so callers appear
        // responsible for including it themselves. Confirm against the
        // server's auth handling.
        headers.insert(
            header::AUTHORIZATION,
            bearer_token
                .parse()
                .expect("bearer token contains characters invalid in an HTTP header value"),
        );
    }
    headers
}
}

impl BackgroundWorker {
Expand All @@ -219,7 +235,9 @@ impl BackgroundWorker {
}
loop {
self.flush().await;
if self.request_data.lock().unwrap().is_none() && self.events.lock().unwrap().tables.is_empty() {
if self.request_data.lock().unwrap().is_none()
&& self.events.lock().unwrap().tables.is_empty()
{
break;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ fn locustdb(m: &Bound<'_, PyModule>) -> PyResult<()> {
#[pymethods]
impl Client {
#[new]
#[pyo3(signature = (url, max_buffer_size_bytes = 128 * (1 << 20), block_when_buffer_full = false, flush_interval_seconds = 1))]
#[pyo3(signature = (url, max_buffer_size_bytes = 128 * (1 << 20), block_when_buffer_full = false, flush_interval_seconds = 1, bearer_token = None))]
fn new(
url: &str,
max_buffer_size_bytes: usize,
block_when_buffer_full: bool,
flush_interval_seconds: u64,
bearer_token: Option<String>,
) -> Self {
let _guard = RT.enter();
Self {
Expand All @@ -48,6 +49,7 @@ impl Client {
} else {
BufferFullPolicy::Drop
},
bearer_token,
),
}
}
Expand Down
5 changes: 5 additions & 0 deletions tests/ingestion_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ fn ingest(offset: usize, rows: usize, random_cols: usize, tables: &[String]) {
addr,
64 * (1 << 20),
BufferFullPolicy::Block,
None,
);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
for row in 0..rows {
Expand Down Expand Up @@ -222,6 +223,7 @@ async fn test_ingest_sparse_nullable() {
// Set max buffer size to 0 to ensure we ingest one row at a time
0,
BufferFullPolicy::Block,
None,
);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
let mut vals = vec![];
Expand Down Expand Up @@ -284,6 +286,7 @@ async fn test_persist_meta_tables() {
&addr,
0,
BufferFullPolicy::Block,
None,
);
log.log("qwerty", [("value".to_string(), vf64(1.0))]);
log.log("asdf", [("value".to_string(), vf64(1.0))]);
Expand Down Expand Up @@ -340,6 +343,7 @@ async fn test_many_concurrent_requests() {
&addr,
1 << 20,
BufferFullPolicy::Block,
None,
);
for i in 0..value_count {
log.log(&table, [("value".to_string(), vf64(i))]);
Expand All @@ -365,6 +369,7 @@ async fn test_many_concurrent_requests() {
&addr,
0,
BufferFullPolicy::Block,
None,
);
let query = format!("SELECT SUM(value) AS total FROM table_{:02}", tid);
let mut last_log_time = Instant::now();
Expand Down

0 comments on commit a9eb075

Please sign in to comment.