diff --git a/crates/torii/server/src/cli.rs b/crates/torii/server/src/cli.rs index 2e7963e3c2..06e7c785c9 100644 --- a/crates/torii/server/src/cli.rs +++ b/crates/torii/server/src/cli.rs @@ -7,20 +7,28 @@ use std::sync::Arc; use clap::Parser; use dojo_world::contracts::world::WorldContractReader; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; +use sqlx::SqlitePool; use starknet::core::types::FieldElement; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::broadcast; +use tokio::sync::broadcast::Sender; +use tokio_stream::StreamExt; use torii_core::engine::{Engine, EngineConfig, Processors}; use torii_core::processors::metadata_update::MetadataUpdateProcessor; use torii_core::processors::register_model::RegisterModelProcessor; use torii_core::processors::store_set_record::StoreSetRecordProcessor; use torii_core::processors::store_transaction::StoreTransactionProcessor; +use torii_core::simple_broker::SimpleBroker; use torii_core::sql::Sql; +use torii_core::types::Model; +use tracing::info; use tracing_subscriber::fmt; use url::Url; +use crate::proxy::Proxy; + /// Dojo World Indexer #[derive(Parser, Debug)] #[command(name = "torii", author, version, about, long_about = None)] @@ -111,10 +119,6 @@ async fn main() -> anyhow::Result<()> { let addr: SocketAddr = format!("{}:{}", args.host, args.port).parse()?; - let shutdown_rx = shutdown_tx.subscribe(); - let (graphql_addr, graphql_server) = - torii_graphql::server::new(shutdown_rx, &pool, args.external_url).await; - let shutdown_rx = shutdown_tx.subscribe(); let (grpc_addr, grpc_server) = torii_grpc::server::new( shutdown_rx, @@ -125,8 +129,18 @@ async fn main() -> anyhow::Result<()> { ) .await?; - let shutdown_rx = shutdown_tx.subscribe(); - let proxy_server = proxy::new(shutdown_rx, addr, args.allowed_origins, grpc_addr, graphql_addr); + let proxy_server = Arc::new(Proxy::new(addr, args.allowed_origins, Some(grpc_addr), None)); + + let graphql_server = spawn_rebuilding_graphql_server( + shutdown_tx.clone(), + pool.into(), + args.external_url, + proxy_server.clone(), + ); + + info!("🚀 Torii listening at {}", format!("http://{}", addr)); + info!("Graphql playground: {}\n", format!("http://{}/graphql", addr)); + info!("GRPC playground: {}\n", format!("http://{}/grpc", addr)); tokio::select! { _ = sigterm.recv() => { @@ -137,10 +151,34 @@ async fn main() -> anyhow::Result<()> { } _ = engine.start() => {}, - _ = proxy_server.await => {}, + _ = proxy_server.start(shutdown_tx.subscribe()) => {}, _ = graphql_server => {}, _ = grpc_server => {}, }; Ok(()) } + +async fn spawn_rebuilding_graphql_server( + shutdown_tx: Sender<()>, + pool: Arc, + external_url: Option, + proxy_server: Arc, +) { + let mut broker = SimpleBroker::::subscribe(); + + loop { + let shutdown_rx = shutdown_tx.subscribe(); + let (new_addr, new_server) = + torii_graphql::server::new(shutdown_rx, &pool, external_url.clone()).await; + + tokio::spawn(new_server); + + proxy_server.set_graphql_addr(new_addr).await; + + // Break the loop if there are no more events + if broker.next().await.is_none() { + break; + } + } +} diff --git a/crates/torii/server/src/proxy.rs b/crates/torii/server/src/proxy.rs index 53a5932098..ea842129cd 100644 --- a/crates/torii/server/src/proxy.rs +++ b/crates/torii/server/src/proxy.rs @@ -1,12 +1,15 @@ use std::convert::Infallible; -use std::future::Future; use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::RwLock; +use http::header::CONTENT_TYPE; use http::{HeaderName, Method}; use hyper::server::conn::AddrStream; use hyper::service::make_service_fn; use hyper::{Body, Request, Response, Server, StatusCode}; +use serde_json::json; use tower::ServiceBuilder; use tower_http::cors::{AllowOrigin, CorsLayer}; @@ -23,91 +26,137 @@ const DEFAULT_EXPOSED_HEADERS: [&str; 3] = ["grpc-status", "grpc-message", "grpc-status-details-bin"]; const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); -// async fn model_registered_listener(notify_restart: Arc) { -// while (SimpleBroker::::subscribe().next().await).is_some() { -// notify_restart.notify_one(); -// } -// } +pub struct Proxy { + addr: SocketAddr, + allowed_origins: Vec, + grpc_addr: Option, + graphql_addr: Arc>>, +} + +impl Proxy { + pub fn new( + addr: SocketAddr, + allowed_origins: Vec, + grpc_addr: Option, + graphql_addr: Option, + ) -> Self { + Self { addr, allowed_origins, grpc_addr, graphql_addr: Arc::new(RwLock::new(graphql_addr)) } + } + + pub async fn set_graphql_addr(&self, addr: SocketAddr) { + let mut graphql_addr = self.graphql_addr.write().await; + *graphql_addr = Some(addr); + } + + pub async fn start( + &self, + mut shutdown_rx: tokio::sync::broadcast::Receiver<()>, + ) -> Result<(), hyper::Error> { + let addr = self.addr; + let allowed_origins = self.allowed_origins.clone(); + let grpc_addr = self.grpc_addr; + let graphql_addr = self.graphql_addr.clone(); + + let make_svc = make_service_fn(move |conn: &AddrStream| { + let remote_addr = conn.remote_addr().ip(); + let cors = CorsLayer::new() + .max_age(DEFAULT_MAX_AGE) + .allow_methods([Method::GET, Method::POST]) + .allow_headers( + DEFAULT_ALLOW_HEADERS + .iter() + .cloned() + .map(HeaderName::from_static) + .collect::>(), + ) + .expose_headers( + DEFAULT_EXPOSED_HEADERS + .iter() + .cloned() + .map(HeaderName::from_static) + .collect::>(), + ); + + let cors = match allowed_origins.as_slice() { + [origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()), + origins => cors.allow_origin( + origins.iter().map(|o| o.parse().expect("valid origin")).collect::>(), + ), + }; + + let graphql_addr_clone = graphql_addr.clone(); + let service = ServiceBuilder::new().layer(cors).service_fn(move |req| { + let graphql_addr = graphql_addr_clone.clone(); + async move { + let graphql_addr = graphql_addr.read().await.clone(); + let graphql_addr = graphql_addr.map(|addr| format!("http://{}", addr)); + let grpc_addr = grpc_addr.map(|addr| format!("http://{}", addr)); + handle(remote_addr, grpc_addr, graphql_addr, req).await + } + }); + + async { Ok::<_, Infallible>(service) } + }); -fn debug_request(req: Request) -> Result, Infallible> { - let body_str = format!("{:?}", req); - Ok(Response::new(Body::from(body_str))) + let server = Server::bind(&addr).serve(make_svc); + server + .with_graceful_shutdown(async move { + // Wait for the shutdown signal + shutdown_rx.recv().await.ok(); + }) + .await + } } async fn handle( client_ip: IpAddr, - grpc_addr: String, - graphql_addr: String, + grpc_addr: Option, + graphql_addr: Option, req: Request, ) -> Result, Infallible> { if req.uri().path().starts_with("/grpc") { - match hyper_reverse_proxy::call(client_ip, &grpc_addr, req).await { - Ok(response) => Ok(response), - Err(_error) => Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) + if let Some(grpc_addr) = grpc_addr { + return match hyper_reverse_proxy::call(client_ip, &grpc_addr, req).await { + Ok(response) => Ok(response), + Err(_error) => Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()), + }; + } else { + return Ok(Response::builder() + .status(StatusCode::NOT_FOUND) .body(Body::empty()) - .unwrap()), + .unwrap()); } - } else if req.uri().path().starts_with("/graphql") { - match hyper_reverse_proxy::call(client_ip, &graphql_addr, req).await { - Ok(response) => Ok(response), - Err(_error) => Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) + } + + if req.uri().path().starts_with("/graphql") { + if let Some(graphql_addr) = graphql_addr { + return match hyper_reverse_proxy::call(client_ip, &graphql_addr, req).await { + Ok(response) => Ok(response), + Err(_error) => Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()), + }; + } else { + return Ok(Response::builder() + .status(StatusCode::NOT_FOUND) .body(Body::empty()) - .unwrap()), + .unwrap()); } - } else { - debug_request(req) } -} - -pub async fn new( - mut shutdown_rx: tokio::sync::broadcast::Receiver<()>, - addr: SocketAddr, - allowed_origins: Vec, - grpc_addr: SocketAddr, - graphql_addr: SocketAddr, -) -> impl Future> + 'static { - let make_svc = make_service_fn(move |conn: &AddrStream| { - let remote_addr = conn.remote_addr().ip(); - let grpc_addr = format!("http://{}", grpc_addr); - let graphql_addr = format!("http://{}", graphql_addr); - - let cors = CorsLayer::new() - .max_age(DEFAULT_MAX_AGE) - .allow_methods([Method::GET, Method::POST]) - .allow_headers( - DEFAULT_ALLOW_HEADERS - .iter() - .cloned() - .map(HeaderName::from_static) - .collect::>(), - ) - .expose_headers( - DEFAULT_EXPOSED_HEADERS - .iter() - .cloned() - .map(HeaderName::from_static) - .collect::>(), - ); - - let cors = match allowed_origins.as_slice() { - [origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()), - origins => cors.allow_origin( - origins.iter().map(|o| o.parse().expect("valid origin")).collect::>(), - ), - }; - let service = ServiceBuilder::new().layer(cors).service_fn(move |req| { - handle(remote_addr, grpc_addr.clone(), graphql_addr.clone(), req) - }); - - async { Ok::<_, Infallible>(service) } + let json = json!({ + "service": "torii", + "success": true }); - - let server = Server::bind(&addr).serve(make_svc); - server.with_graceful_shutdown(async move { - // Wait for the shutdown signal - shutdown_rx.recv().await.ok(); - }) + let body = Body::from(json.to_string()); + let response = Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .unwrap(); + Ok(response) }