-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(torii): sql proxy endpoint for querying #2706
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
use std::sync::Arc; | ||
use std::time::Duration; | ||
|
||
use base64::engine::general_purpose::STANDARD; | ||
use base64::Engine; | ||
use http::header::CONTENT_TYPE; | ||
use http::{HeaderName, Method}; | ||
use hyper::client::connect::dns::GaiResolver; | ||
|
@@ -12,11 +14,14 @@ | |
use hyper::{Body, Client, Request, Response, Server, StatusCode}; | ||
use hyper_reverse_proxy::ReverseProxy; | ||
use serde_json::json; | ||
use sqlx::{Column, Row, SqlitePool, TypeInfo}; | ||
use tokio::sync::RwLock; | ||
use tower::ServiceBuilder; | ||
use tower_http::cors::{AllowOrigin, CorsLayer}; | ||
use tracing::error; | ||
|
||
pub(crate) const LOG_TARGET: &str = "torii::server::proxy"; | ||
|
||
const DEFAULT_ALLOW_HEADERS: [&str; 13] = [ | ||
"accept", | ||
"origin", | ||
|
@@ -60,6 +65,7 @@ | |
grpc_addr: Option<SocketAddr>, | ||
artifacts_addr: Option<SocketAddr>, | ||
graphql_addr: Arc<RwLock<Option<SocketAddr>>>, | ||
pool: Arc<SqlitePool>, | ||
} | ||
|
||
impl Proxy { | ||
|
@@ -69,13 +75,15 @@ | |
grpc_addr: Option<SocketAddr>, | ||
graphql_addr: Option<SocketAddr>, | ||
artifacts_addr: Option<SocketAddr>, | ||
pool: Arc<SqlitePool>, | ||
) -> Self { | ||
Self { | ||
addr, | ||
allowed_origins, | ||
grpc_addr, | ||
graphql_addr: Arc::new(RwLock::new(graphql_addr)), | ||
artifacts_addr, | ||
pool, | ||
} | ||
} | ||
|
||
|
@@ -93,6 +101,7 @@ | |
let grpc_addr = self.grpc_addr; | ||
let graphql_addr = self.graphql_addr.clone(); | ||
let artifacts_addr = self.artifacts_addr; | ||
let pool = self.pool.clone(); | ||
|
||
let make_svc = make_service_fn(move |conn: &AddrStream| { | ||
let remote_addr = conn.remote_addr().ip(); | ||
|
@@ -129,12 +138,14 @@ | |
), | ||
}); | ||
|
||
let pool_clone = pool.clone(); | ||
let graphql_addr_clone = graphql_addr.clone(); | ||
let service = ServiceBuilder::new().option_layer(cors).service_fn(move |req| { | ||
let pool = pool_clone.clone(); | ||
let graphql_addr = graphql_addr_clone.clone(); | ||
async move { | ||
let graphql_addr = graphql_addr.read().await; | ||
handle(remote_addr, grpc_addr, artifacts_addr, *graphql_addr, req).await | ||
handle(remote_addr, grpc_addr, artifacts_addr, *graphql_addr, pool, req).await | ||
} | ||
}); | ||
|
||
|
@@ -156,6 +167,7 @@ | |
grpc_addr: Option<SocketAddr>, | ||
artifacts_addr: Option<SocketAddr>, | ||
graphql_addr: Option<SocketAddr>, | ||
pool: Arc<SqlitePool>, | ||
req: Request<Body>, | ||
) -> Result<Response<Body>, Infallible> { | ||
if req.uri().path().starts_with("/static") { | ||
|
@@ -165,7 +177,7 @@ | |
return match GRAPHQL_PROXY_CLIENT.call(client_ip, &artifacts_addr, req).await { | ||
Ok(response) => Ok(response), | ||
Err(_error) => { | ||
error!("{:?}", _error); | ||
error!(target: LOG_TARGET, "Artifacts proxy error: {:?}", _error); | ||
Ok(Response::builder() | ||
.status(StatusCode::INTERNAL_SERVER_ERROR) | ||
.body(Body::empty()) | ||
|
@@ -186,7 +198,7 @@ | |
return match GRAPHQL_PROXY_CLIENT.call(client_ip, &graphql_addr, req).await { | ||
Ok(response) => Ok(response), | ||
Err(_error) => { | ||
error!("{:?}", _error); | ||
error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error); | ||
Ok(Response::builder() | ||
.status(StatusCode::INTERNAL_SERVER_ERROR) | ||
.body(Body::empty()) | ||
|
@@ -208,7 +220,7 @@ | |
return match GRPC_PROXY_CLIENT.call(client_ip, &grpc_addr, req).await { | ||
Ok(response) => Ok(response), | ||
Err(_error) => { | ||
error!("{:?}", _error); | ||
error!(target: LOG_TARGET, "GRPC proxy error: {:?}", _error); | ||
Ok(Response::builder() | ||
.status(StatusCode::INTERNAL_SERVER_ERROR) | ||
.body(Body::empty()) | ||
|
@@ -224,6 +236,83 @@ | |
} | ||
} | ||
|
||
if req.uri().path().starts_with("/sql") { | ||
let query = if req.method() == Method::GET { | ||
// Get the query from URL parameters | ||
let params = req.uri().query().unwrap_or_default(); | ||
form_urlencoded::parse(params.as_bytes()) | ||
.find(|(key, _)| key == "q") | ||
.map(|(_, value)| value.to_string()) | ||
.unwrap_or_default() | ||
} else if req.method() == Method::POST { | ||
// Get the query from request body | ||
let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap_or_default(); | ||
String::from_utf8(body_bytes.to_vec()).unwrap_or_default() | ||
} else { | ||
return Ok(Response::builder() | ||
.status(StatusCode::METHOD_NOT_ALLOWED) | ||
.body(Body::from("Only GET and POST methods are allowed")) | ||
.unwrap()); | ||
}; | ||
|
||
// Execute the query | ||
return match sqlx::query(&query).fetch_all(&*pool).await { | ||
Ok(rows) => { | ||
let result: Vec<_> = rows | ||
.iter() | ||
.map(|row| { | ||
let mut obj = serde_json::Map::new(); | ||
for (i, column) in row.columns().iter().enumerate() { | ||
let value: serde_json::Value = match column.type_info().name() { | ||
"TEXT" => row | ||
.get::<Option<String>, _>(i) | ||
.map_or(serde_json::Value::Null, serde_json::Value::String), | ||
// for operators like count(*) the type info is NULL | ||
// so we default to a number | ||
"INTEGER" | "NULL" => row | ||
.get::<Option<i64>, _>(i) | ||
.map_or(serde_json::Value::Null, |n| { | ||
serde_json::Value::Number(n.into()) | ||
}), | ||
"REAL" => row.get::<Option<f64>, _>(i).map_or( | ||
serde_json::Value::Null, | ||
|f| { | ||
serde_json::Number::from_f64(f).map_or( | ||
serde_json::Value::Null, | ||
serde_json::Value::Number, | ||
) | ||
}, | ||
), | ||
"BLOB" => row | ||
.get::<Option<Vec<u8>>, _>(i) | ||
.map_or(serde_json::Value::Null, |bytes| { | ||
serde_json::Value::String(STANDARD.encode(bytes)) | ||
}), | ||
_ => row | ||
.get::<Option<String>, _>(i) | ||
.map_or(serde_json::Value::Null, serde_json::Value::String), | ||
}; | ||
obj.insert(column.name().to_string(), value); | ||
} | ||
serde_json::Value::Object(obj) | ||
}) | ||
.collect(); | ||
|
||
let json = serde_json::to_string(&result).unwrap(); | ||
|
||
Ok(Response::builder() | ||
.status(StatusCode::OK) | ||
.header(CONTENT_TYPE, "application/json") | ||
.body(Body::from(json)) | ||
.unwrap()) | ||
} | ||
Err(e) => Ok(Response::builder() | ||
.status(StatusCode::BAD_REQUEST) | ||
.body(Body::from(format!("Query error: {:?}", e))) | ||
.unwrap()), | ||
Comment on lines
+310
to
+312
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid exposing detailed error messages to clients Including detailed error information in HTTP responses can reveal sensitive internal details and pose a security risk. It's better to log the error internally and return a generic message to the client. Apply this diff to return a generic error message: - Err(e) => Ok(Response::builder()
- .status(StatusCode::BAD_REQUEST)
- .body(Body::from(format!("Query error: {:?}", e)))
- .unwrap()),
+ Err(e) => {
+ error!(target: LOG_TARGET, "Query execution error: {:?}", e);
+ Ok(Response::builder()
+ .status(StatusCode::BAD_REQUEST)
+ .body(Body::from("Invalid query"))
+ .unwrap())
+ },
|
||
}; | ||
Comment on lines
+259
to
+313
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Implement safeguards against resource exhaustion Ohayo, sensei! Executing arbitrary SQL queries without limitations can lead to performance issues, such as high memory usage or slow response times if large result sets are returned. Consider implementing safeguards like limiting the number of rows returned or restricting the types of queries that can be executed. For example, you could limit the number of rows by appending |
||
} | ||
Comment on lines
+239
to
+314
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider restricting or securing the Ohayo, sensei! Exposing an endpoint that allows execution of arbitrary SQL queries can be a significant security risk, even with a read-only database connection. It could lead to data leakage or other unintended consequences. Consider the following options:
Would you like assistance in refactoring this endpoint to enhance security? |
||
|
||
let json = json!({ | ||
"service": "torii", | ||
"success": true | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Ohayo, sensei! Avoid silently ignoring errors when reading request body
Using
unwrap_or_default()
when reading the request body may mask errors that should be handled appropriately. Consider properly handling errors when converting the request body to bytes and when converting bytes to a UTF-8 string.Apply this diff to handle errors explicitly: