Skip to content

Commit

Permalink
editoast: add configurable authorization to TestAppBuilder
Browse files Browse the repository at this point in the history
Signed-off-by: hamz2a <atrari.hamza@gmail.com>
  • Loading branch information
hamz2a committed Dec 30, 2024
1 parent 6a105bf commit 1314060
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 10 deletions.
10 changes: 2 additions & 8 deletions editoast/src/views/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,13 @@ async fn authenticate(

async fn authentication_middleware(
State(AppState {
db_pool,
disable_authorization,
..
db_pool, config, ..
}): State<AppState>,
mut req: Request,
next: Next,
) -> Result<Response> {
let headers = req.headers();
let authorizer = authenticate(disable_authorization, headers, db_pool).await?;
let authorizer = authenticate(config.disable_authorization, headers, db_pool).await?;
req.extensions_mut().insert(authorizer);
Ok(next.run(req).await)
}
Expand Down Expand Up @@ -365,7 +363,6 @@ pub struct ServerConfig {
pub health_check_timeout: Duration,
pub map_layers_max_zoom: u8,
pub disable_authorization: bool,

pub postgres_config: PostgresConfig,
pub osrdyne_config: OsrdyneConfig,
pub valkey_config: ValkeyConfig,
Expand All @@ -382,13 +379,11 @@ pub struct Server {
#[derive(Clone)]
pub struct AppState {
pub config: Arc<ServerConfig>,

pub db_pool: Arc<DbConnectionPoolV2>,
pub valkey: Arc<ValkeyClient>,
pub infra_caches: Arc<DashMap<i64, InfraCache>>,
pub map_layers: Arc<MapLayers>,
pub speed_limit_tag_ids: Arc<SpeedLimitTagIds>,
pub disable_authorization: bool,
pub core_client: Arc<CoreClient>,
pub osrdyne_client: Arc<OsrdyneClient>,
pub health_check_timeout: Duration,
Expand Down Expand Up @@ -458,7 +453,6 @@ impl AppState {
osrdyne_client,
map_layers: Arc::new(MapLayers::default()),
speed_limit_tag_ids,
disable_authorization: config.disable_authorization,
health_check_timeout: config.health_check_timeout,
config: Arc::new(config),
})
Expand Down
32 changes: 32 additions & 0 deletions editoast/src/views/projects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,21 @@ async fn patch(
#[cfg(test)]
pub mod tests {

use std::collections::HashSet;

use axum::http::StatusCode;
use editoast_authz::authorizer::UserInfo;
use pretty_assertions::assert_eq;
use rstest::rstest;
use serde_json::json;

use super::*;
use crate::core::mocking::MockingClient;
use crate::core::CoreClient;
use crate::models::fixtures::create_project;
use crate::models::prelude::*;
use crate::views::test_app::TestAppBuilder;
use crate::views::test_app::TestRequestExt;

#[rstest]
async fn project_post() {
Expand All @@ -394,6 +400,32 @@ pub mod tests {
assert_eq!(project.name, project_name);
}

#[rstest]
async fn project_post_should_fail_when_authorization_is_enabled() {
let pool = DbConnectionPoolV2::for_tests();
let user = UserInfo {
identity: "user_identity".to_string(),
name: "user_name".to_string(),
};
let app = TestAppBuilder::new()
.db_pool(pool)
.core_client(CoreClient::Mocked(MockingClient::default()))
.enable_authorization(true)
.user(user.clone())
.roles(HashSet::from([BuiltinRole::OpsRead]))
.build();

let request = app.post("/projects").with_user(user).json(&json!({
"name": "test_project_failed",
"description": "",
"objectives": "",
"funders": "",
}));

// OpsWrite is required to complete this request successfully.
app.fetch(request).assert_status(StatusCode::FORBIDDEN);
}

#[rstest]
async fn project_list() {
let app = TestAppBuilder::default_app();
Expand Down
64 changes: 62 additions & 2 deletions editoast/src/views/test_app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
//! test actix server, database connection pool, and different mocking
//! components.
use std::collections::HashSet;
use std::sync::Arc;

use axum::Router;
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use dashmap::DashMap;
use editoast_authz::authorizer::StorageDriver;
use editoast_authz::authorizer::UserInfo;
use editoast_authz::BuiltinRole;
use editoast_models::DbConnectionPoolV2;
use editoast_osrdyne_client::OsrdyneClient;
use futures::executor::block_on;
use serde::de::DeserializeOwned;
use tower_http::trace::TraceLayer;
use url::Url;
Expand All @@ -18,6 +23,7 @@ use crate::{
generated_data::speed_limit_tags_config::SpeedLimitTagIds,
infra_cache::InfraCache,
map::MapLayers,
models::auth::PgAuthDriver,
valkey_utils::ValkeyConfig,
AppState, ValkeyClient,
};
Expand All @@ -39,6 +45,9 @@ pub(crate) struct TestAppBuilder {
db_pool: Option<DbConnectionPoolV2>,
core_client: Option<CoreClient>,
osrdyne_client: Option<OsrdyneClient>,
enable_authorization: bool,
user: Option<UserInfo>,
roles: HashSet<BuiltinRole>,
}

impl TestAppBuilder {
Expand All @@ -47,6 +56,9 @@ impl TestAppBuilder {
db_pool: None,
core_client: None,
osrdyne_client: None,
enable_authorization: false,
user: None,
roles: HashSet::new(),
}
}

Expand All @@ -68,6 +80,23 @@ impl TestAppBuilder {
self
}

pub fn enable_authorization(mut self, enable_authorization: bool) -> Self {
self.enable_authorization = enable_authorization;
self
}

pub fn user(mut self, user: UserInfo) -> Self {
assert!(self.user.is_none());
self.user = Some(user);
self
}

pub fn roles(mut self, roles: HashSet<BuiltinRole>) -> Self {
assert!(self.roles.is_empty());
self.roles = roles;
self
}

pub fn default_app() -> TestApp {
let pool = DbConnectionPoolV2::for_tests();
let core_client = CoreClient::Mocked(MockingClient::default());
Expand All @@ -83,7 +112,7 @@ impl TestAppBuilder {
port: 0,
address: String::default(),
health_check_timeout: chrono::Duration::milliseconds(500),
disable_authorization: true,
disable_authorization: !self.enable_authorization,
map_layers_max_zoom: 18,
postgres_config: PostgresConfig {
database_url: Url::parse("postgres://osrd:password@localhost:5432/osrd").unwrap(),
Expand Down Expand Up @@ -152,7 +181,6 @@ impl TestAppBuilder {
infra_caches,
map_layers: Arc::new(MapLayers::default()),
speed_limit_tag_ids,
disable_authorization: true,
health_check_timeout: config.health_check_timeout,
config: Arc::new(config),
};
Expand All @@ -171,6 +199,23 @@ impl TestAppBuilder {
// Run server
let server = TestServer::new(router).expect("test server should build properly");

// Setup user and roles
let driver = PgAuthDriver::<BuiltinRole>::new(db_pool_v2.clone());
if let Some(ref user) = self.user {
let uid = block_on(async {
driver
.ensure_user(user)
.await
.expect("User should be created successfully")
});
block_on(async {
driver
.ensure_subject_roles(uid, self.roles)
.await
.expect("Roles should be updated successfully")
});
};

TestApp {
server,
db_pool: db_pool_v2,
Expand Down Expand Up @@ -213,20 +258,35 @@ impl TestApp {
pub fn get(&self, path: &str) -> TestRequest {
self.server.get(&trim_path(path))
}

pub fn post(&self, path: &str) -> TestRequest {
self.server.post(&trim_path(path))
}

pub fn put(&self, path: &str) -> TestRequest {
self.server.put(&trim_path(path))
}

pub fn patch(&self, path: &str) -> TestRequest {
self.server.patch(&trim_path(path))
}

pub fn delete(&self, path: &str) -> TestRequest {
self.server.delete(&trim_path(path))
}
}

pub trait TestRequestExt {
fn with_user(self, user: UserInfo) -> Self;
}

impl TestRequestExt for TestRequest {
fn with_user(self, user: UserInfo) -> Self {
self.add_header("x-remote-user-identity", user.identity)
.add_header("x-remote-user-name", user.name)
}
}

// For technical reasons, we had a hard time trying to configure the normalizing layer
// in the test server. Since we have control over the paths configured in our unit tests,
// doing this manually is probably a good enough solution for now.
Expand Down

0 comments on commit 1314060

Please sign in to comment.