Skip to content

Commit

Permalink
feat: sql proxy endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Nov 20, 2024
1 parent d005475 commit 77ece4c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
1 change: 1 addition & 0 deletions bin/torii/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ async fn main() -> anyhow::Result<()> {
Some(grpc_addr),
None,
Some(artifacts_addr),
Arc::new(pool.clone()),
));

let graphql_server = spawn_rebuilding_graphql_server(
Expand Down
8 changes: 6 additions & 2 deletions crates/torii/core/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ impl Sql {
}

pub async fn model(&self, selector: Felt) -> Result<Model> {
self.model_cache.model(&selector).await.map_err(|e| e.into())
self.model_cache.model(&selector).await.map_err(|e| e.into())
}

pub async fn does_entity_exist(&self, model: String, key: Felt) -> Result<bool> {
Expand Down Expand Up @@ -828,7 +828,11 @@ impl Sql {
Ty::Enum(e) => {
if e.options.iter().all(
|o| {
if let Ty::Tuple(t) = &o.ty { t.is_empty() } else { false }
if let Ty::Tuple(t) = &o.ty {
t.is_empty()
} else {
false
}
},
) {
return Ok(());
Expand Down
51 changes: 48 additions & 3 deletions crates/torii/server/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tokio::sync::RwLock;
use tower::ServiceBuilder;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::error;
use sqlx::SqlitePool;

const DEFAULT_ALLOW_HEADERS: [&str; 13] = [
"accept",
Expand Down Expand Up @@ -60,6 +61,7 @@ pub struct Proxy {
grpc_addr: Option<SocketAddr>,
artifacts_addr: Option<SocketAddr>,
graphql_addr: Arc<RwLock<Option<SocketAddr>>>,
pool: Arc<SqlitePool>,
}

impl Proxy {
Expand All @@ -69,13 +71,15 @@ impl Proxy {
grpc_addr: Option<SocketAddr>,
graphql_addr: Option<SocketAddr>,
artifacts_addr: Option<SocketAddr>,
pool: Arc<SqlitePool>,
) -> Self {
Self {
addr,
allowed_origins,
grpc_addr,
graphql_addr: Arc::new(RwLock::new(graphql_addr)),
artifacts_addr,
pool,
}
}

Expand All @@ -93,6 +97,7 @@ impl Proxy {
let grpc_addr = self.grpc_addr;
let graphql_addr = self.graphql_addr.clone();
let artifacts_addr = self.artifacts_addr;
let pool = self.pool.clone();

let make_svc = make_service_fn(move |conn: &AddrStream| {
let remote_addr = conn.remote_addr().ip();
Expand Down Expand Up @@ -131,10 +136,10 @@ impl Proxy {

let graphql_addr_clone = graphql_addr.clone();
let service = ServiceBuilder::new().option_layer(cors).service_fn(move |req| {
let graphql_addr = graphql_addr_clone.clone();
let pool = pool.clone();
async move {
let graphql_addr = graphql_addr.read().await;
handle(remote_addr, grpc_addr, artifacts_addr, *graphql_addr, req).await
let graphql_addr = graphql_addr_clone.clone();
handle(remote_addr, grpc_addr, artifacts_addr, *graphql_addr, pool, req).await
}
});

Expand All @@ -156,6 +161,7 @@ async fn handle(
grpc_addr: Option<SocketAddr>,
artifacts_addr: Option<SocketAddr>,
graphql_addr: Option<SocketAddr>,
pool: Arc<SqlitePool>,
req: Request<Body>,
) -> Result<Response<Body>, Infallible> {
if req.uri().path().starts_with("/static") {
Expand Down Expand Up @@ -224,6 +230,45 @@ async fn handle(
}
}

if req.uri().path().starts_with("/sql") {
if req.method() != Method::POST {
return Ok(Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::from("Only POST method is allowed"))
.unwrap());
}

let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap_or_default();
let query = String::from_utf8(body_bytes.to_vec()).unwrap_or_default();

if !query.trim().to_uppercase().starts_with("SELECT") {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Only SELECT queries are allowed"))
.unwrap());
}

return match sqlx::query(&query).fetch_all(&*pool).await {
Ok(rows) => {
let json = serde_json::to_string(
&rows.iter()
.map(|row| row.columns().iter().map(|col| col.name()).collect::<Vec<_>>())
.collect::<Vec<_>>()
).unwrap();

Ok(Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(json))
.unwrap())
}
Err(e) => Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Query error: {}", e)))
.unwrap()),
};
}

let json = json!({
"service": "torii",
"success": true
Expand Down

0 comments on commit 77ece4c

Please sign in to comment.