From 512f065f4557cf64a540e8e006ee1c329f8b413a Mon Sep 17 00:00:00 2001 From: Florian Amsallem Date: Thu, 5 Dec 2024 11:08:10 +0100 Subject: [PATCH] editoast: fix projection endpoint Co-authored-by: Youness CHRIFI ALAOUI Signed-off-by: Florian Amsallem Signed-off-by: Youness CHRIFI ALAOUI --- editoast/src/views/train_schedule.rs | 101 ++++++++++++++++-- .../src/views/train_schedule/projection.rs | 61 +++++------ 2 files changed, 126 insertions(+), 36 deletions(-) diff --git a/editoast/src/views/train_schedule.rs b/editoast/src/views/train_schedule.rs index b6aa27a3ab3..03d03eb6d97 100644 --- a/editoast/src/views/train_schedule.rs +++ b/editoast/src/views/train_schedule.rs @@ -796,6 +796,8 @@ async fn get_path( #[cfg(test)] mod tests { use axum::http::StatusCode; + use chrono::DateTime; + use chrono::Utc; use pretty_assertions::assert_eq; use rstest::rstest; use serde_json::json; @@ -920,8 +922,7 @@ mod tests { ) } - async fn app_infra_id_train_schedule_id_for_simulation_tests() -> (TestApp, i64, i64) { - let db_pool = DbConnectionPoolV2::for_tests(); + fn mocked_core_pathfinding_sim_and_proj(train_id: i64) -> MockingClient { let mut core = MockingClient::new(); core.stub("/v2/pathfinding/blocks") .method(reqwest::Method::POST) @@ -975,10 +976,18 @@ mod tests { } })) .finish(); - let app = TestAppBuilder::new() - .db_pool(db_pool.clone()) - .core_client(core.into()) - .build(); + core.stub("/v2/signal_projection") + .method(reqwest::Method::POST) + .response(StatusCode::OK) + .json(json!({ + "signal_updates": {train_id.to_string(): [] }, + })) + .finish(); + core + } + + async fn app_infra_id_train_schedule_id_for_simulation_tests() -> (TestApp, i64, i64) { + let db_pool = DbConnectionPoolV2::for_tests(); let small_infra = create_small_infra(&mut db_pool.get_ok()).await; let rolling_stock = create_fast_rolling_stock(&mut db_pool.get_ok(), "simulation_rolling_stock").await; @@ -997,6 +1006,11 @@ mod tests { .create(&mut db_pool.get_ok()) .await .expect("Failed to create train schedule"); + let core = mocked_core_pathfinding_sim_and_proj(train_schedule.id); + let app = TestAppBuilder::new() + .db_pool(db_pool.clone()) + .core_client(core.into()) + .build(); (app, small_infra.id, train_schedule.id) } @@ -1024,4 +1038,79 @@ mod tests { })); app.fetch(request).assert_status(StatusCode::OK); } + + #[derive(Deserialize)] + struct PartialProjectPathTrainResult { + departure_time: DateTime, + // Ignore the rest of the payload + } + + #[rstest] + async fn train_schedule_project_path() { + // SETUP + let db_pool = DbConnectionPoolV2::for_tests(); + + let small_infra = create_small_infra(&mut db_pool.get_ok()).await; + let rolling_stock = + create_fast_rolling_stock(&mut db_pool.get_ok(), "simulation_rolling_stock").await; + let timetable = create_timetable(&mut db_pool.get_ok()).await; + let train_schedule_base: TrainScheduleBase = TrainScheduleBase { + rolling_stock_name: rolling_stock.name.clone(), + ..serde_json::from_str(include_str!("../tests/train_schedules/simple.json")) + .expect("Unable to parse") + }; + let train_schedule: Changeset = TrainScheduleForm { + timetable_id: Some(timetable.id), + train_schedule: train_schedule_base.clone(), + } + .into(); + let train_schedule_valid = train_schedule + .create(&mut db_pool.get_ok()) + .await + .expect("Failed to create train schedule"); + + let train_schedule_fail: Changeset = TrainScheduleForm { + timetable_id: Some(timetable.id), + train_schedule: TrainScheduleBase { + rolling_stock_name: "fail".to_string(), + start_time: DateTime::from_timestamp(0, 0).unwrap(), + ..train_schedule_base.clone() + }, + } + .into(); + + let train_schedule_fail = train_schedule_fail + .create(&mut db_pool.get_ok()) + .await + .expect("Failed to create train schedule"); + + let core = mocked_core_pathfinding_sim_and_proj(train_schedule_valid.id); + let app = TestAppBuilder::new() + .db_pool(db_pool.clone()) + .core_client(core.into()) + .build(); + + // TEST + let request = app.post("/train_schedule/project_path").json(&json!({ + "infra_id": small_infra.id, + "electrical_profile_set_id": null, + "ids": vec![train_schedule_fail.id, train_schedule_valid.id], + "path": { + "track_section_ranges": [ + {"track_section": "TA1", "begin": 0, "end": 100, "direction": "START_TO_STOP"} + ], + "routes": [], + "blocks": [] + } + })); + let response: HashMap = + app.fetch(request).assert_status(StatusCode::OK).json_into(); + + // EXPECT + assert_eq!(response.len(), 1); + assert_eq!( + response[&train_schedule_valid.id].departure_time, + train_schedule_base.start_time + ); + } } diff --git a/editoast/src/views/train_schedule/projection.rs b/editoast/src/views/train_schedule/projection.rs index 8bfa58abdae..872195b86d6 100644 --- a/editoast/src/views/train_schedule/projection.rs +++ b/editoast/src/views/train_schedule/projection.rs @@ -198,7 +198,7 @@ async fn project_path( let mut trains_hash_values = vec![]; let mut trains_details = vec![]; - for (sim, pathfinding_result) in simulations { + for (train, (sim, pathfinding_result)) in izip!(&trains, simulations) { let track_ranges = match pathfinding_result { PathfindingResult::Success(PathfindingResultSuccess { track_section_ranges, @@ -221,6 +221,7 @@ async fn project_path( } = report_train; let train_details = TrainSimulationDetails { + train_id: train.id, positions, times, signal_critical_positions, @@ -242,17 +243,13 @@ async fn project_path( let cached_projections: Vec> = valkey_conn.json_get_bulk(&trains_hash_values).await?; - let mut hit_cache: HashMap = HashMap::new(); - let mut miss_cache = HashMap::new(); - for (train_details, projection, train_id) in izip!( - trains_details, - cached_projections, - trains.iter().map(|t| t.id) - ) { + let mut hit_cache = vec![]; + let mut miss_cache = vec![]; + for (train_details, projection) in izip!(&trains_details, cached_projections) { if let Some(cached) = projection { - hit_cache.insert(train_id, cached); + hit_cache.push((cached, train_details.train_id)); } else { - miss_cache.insert(train_id, train_details.clone()); + miss_cache.push(train_details.clone()); } } @@ -277,25 +274,25 @@ async fn project_path( let signal_updates = signal_updates?; // 3. Store the projection in the cache (using pipeline) - let trains_hash_values: HashMap<_, _> = trains + let trains_hash_values: HashMap<_, _> = trains_details .iter() - .map(|t| t.id) + .map(|t| t.train_id) .zip(trains_hash_values) .collect(); let mut new_items = vec![]; - for id in miss_cache.keys() { - let hash = &trains_hash_values[id]; + for train_id in miss_cache.iter().map(|t| t.train_id) { + let hash = &trains_hash_values[&train_id]; let cached_value = CachedProjectPathTrainResult { space_time_curves: space_time_curves - .get(id) + .get(&train_id) .expect("Space time curves not available for train") .clone(), signal_updates: signal_updates - .get(id) + .get(&train_id) .expect("Signal update not available for train") .clone(), }; - hit_cache.insert(*id, cached_value.clone()); + hit_cache.push((cached_value.clone(), train_id)); new_items.push((hash, cached_value)); } valkey_conn.json_set_bulk(&new_items).await?; @@ -303,21 +300,21 @@ async fn project_path( let train_map: HashMap = trains.into_iter().map(|ts| (ts.id, ts)).collect(); // 4.1 Fetch rolling stock length - let mut project_path_result = HashMap::new(); let rolling_stock_length: HashMap<_, _> = rolling_stocks .into_iter() .map(|rs| (rs.name, rs.length)) .collect(); // 4.2 Build the projection response - for (id, cached) in hit_cache { - let train = train_map.get(&id).expect("Train not found"); + let mut project_path_result = HashMap::new(); + for (cached, train_id) in hit_cache { + let train = train_map.get(&train_id).expect("Train not found"); let length = rolling_stock_length .get(&train.rolling_stock_name) .expect("Rolling stock length not found"); project_path_result.insert( - id, + train_id, ProjectPathTrainResult { departure_time: train.start_time, rolling_stock_length: (length * 1000.).round() as u64, @@ -332,6 +329,7 @@ async fn project_path( /// Input for the projection of a train schedule on a path #[derive(Debug, Clone, Hash)] struct TrainSimulationDetails { + train_id: i64, positions: Vec, times: Vec, train_path: Vec, @@ -346,7 +344,7 @@ async fn compute_batch_signal_updates<'a>( path_track_ranges: &'a Vec, path_routes: &'a Vec, path_blocks: &'a Vec, - trains_details: &'a HashMap, + trains_details: &'a [TrainSimulationDetails], ) -> Result>> { if trains_details.is_empty() { return Ok(HashMap::new()); @@ -359,13 +357,13 @@ async fn compute_batch_signal_updates<'a>( blocks: path_blocks, train_simulations: trains_details .iter() - .map(|(id, details)| { + .map(|detail| { ( - *id, + detail.train_id, TrainSimulation { - signal_critical_positions: &details.signal_critical_positions, - zone_updates: &details.zone_updates, - simulation_end_time: details.times[details.times.len() - 1], + signal_critical_positions: &detail.signal_critical_positions, + zone_updates: &detail.zone_updates, + simulation_end_time: detail.times[detail.times.len() - 1], }, ) }) @@ -377,14 +375,14 @@ async fn compute_batch_signal_updates<'a>( /// Compute space time curves of a list of train schedules async fn compute_batch_space_time_curves<'a>( - trains_details: &HashMap, + trains_details: &Vec, path_projection: &PathProjection<'a>, ) -> HashMap> { let mut space_time_curves = HashMap::new(); - for (train_id, train_detail) in trains_details { + for train_detail in trains_details { space_time_curves.insert( - *train_id, + train_detail.train_id, compute_space_time_curves(train_detail, path_projection), ); } @@ -584,6 +582,7 @@ mod tests { ]; let project_path_input = TrainSimulationDetails { + train_id: 0, positions, times, train_path, @@ -618,6 +617,7 @@ mod tests { ]; let project_path_input = TrainSimulationDetails { + train_id: 0, positions: positions.clone(), times: times.clone(), train_path, @@ -655,6 +655,7 @@ mod tests { let path_projection = PathProjection::new(&path); let project_path_input = TrainSimulationDetails { + train_id: 0, positions, times, train_path,