Skip to content

Commit

Permalink
editoast: fix projection endpoint
Browse files Browse the repository at this point in the history
Co-authored-by: Youness CHRIFI ALAOUI <youness.chrifi@gmail.com>
Signed-off-by: Florian Amsallem <florian.amsallem@gmail.com>
  • Loading branch information
flomonster and younesschrifi committed Dec 9, 2024
1 parent 8f6131e commit 61dde6b
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 32 deletions.
157 changes: 155 additions & 2 deletions editoast/src/views/train_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,8 @@ async fn get_path(
#[cfg(test)]
mod tests {
use axum::http::StatusCode;
use chrono::DateTime;
use chrono::Utc;
use pretty_assertions::assert_eq;
use rstest::rstest;
use serde_json::json;
Expand Down Expand Up @@ -920,8 +922,7 @@ mod tests {
)
}

async fn app_infra_id_train_schedule_id_for_simulation_tests() -> (TestApp, i64, i64) {
let db_pool = DbConnectionPoolV2::for_tests();
fn mocked_core_pathfinding_and_sim() -> MockingClient {
let mut core = MockingClient::new();
core.stub("/v2/pathfinding/blocks")
.method(reqwest::Method::POST)
Expand Down Expand Up @@ -975,6 +976,19 @@ mod tests {
}
}))
.finish();
core.stub("/v2/signal_projection")
.method(reqwest::Method::POST)
.response(StatusCode::OK)
.json(json!({
"signal_updates": {},
}))
.finish();
core
}

async fn app_infra_id_train_schedule_id_for_simulation_tests() -> (TestApp, i64, i64) {
let db_pool = DbConnectionPoolV2::for_tests();
let core = mocked_core_pathfinding_and_sim();
let app = TestAppBuilder::new()
.db_pool(db_pool.clone())
.core_client(core.into())
Expand Down Expand Up @@ -1024,4 +1038,143 @@ mod tests {
}));
app.fetch(request).assert_status(StatusCode::OK);
}

#[derive(Deserialize)]
struct PartialProjectPathTrainResult {
departure_time: DateTime<Utc>,
// Ignore the rest of the payload
}

fn mocked_core_for_project_path(train_id: i64) -> MockingClient {
let mut core = MockingClient::new();
core.stub("/v2/pathfinding/blocks")
.method(reqwest::Method::POST)
.response(StatusCode::OK)
.json(json!({
"blocks":[],
"routes": [],
"track_section_ranges": [],
"path_item_positions": [0,1,2,3],
"length": 0,
"status": "success"
}))
.finish();
core.stub("/v2/standalone_simulation")
.method(reqwest::Method::POST)
.response(StatusCode::OK)
.json(json!({
"status": "success",
"base": {
"positions": [],
"times": [],
"speeds": [],
"energy_consumption": 0.0,
"path_item_times": [0, 1000, 2000, 3000]
},
"provisional": {
"positions": [],
"times": [],
"speeds": [],
"energy_consumption": 0.0,
"path_item_times": [0, 1000, 2000, 3000]
},
"final_output": {
"positions": [0],
"times": [0],
"speeds": [],
"energy_consumption": 0.0,
"path_item_times": [0, 1000, 2000, 3000],
"signal_critical_positions": [],
"zone_updates": [],
"spacing_requirements": [],
"routing_requirements": []
},
"mrsp": {
"boundaries": [],
"values": []
},
"electrical_profiles": {
"boundaries": [],
"values": []
}
}))
.finish();
core.stub("/v2/signal_projection")
.method(reqwest::Method::POST)
.response(StatusCode::OK)
.json(json!({
"signal_updates": {train_id.to_string(): [] },
}))
.finish();
core
}

#[rstest]
async fn train_schedule_project_path() {
// SETUP
let db_pool = DbConnectionPoolV2::for_tests();

let small_infra = create_small_infra(&mut db_pool.get_ok()).await;
let rolling_stock =
create_fast_rolling_stock(&mut db_pool.get_ok(), "simulation_rolling_stock").await;
let timetable = create_timetable(&mut db_pool.get_ok()).await;
let train_schedule_base: TrainScheduleBase = TrainScheduleBase {
rolling_stock_name: rolling_stock.name.clone(),
..serde_json::from_str(include_str!("../tests/train_schedules/simple.json"))
.expect("Unable to parse")
};
let train_schedule: Changeset<TrainSchedule> = TrainScheduleForm {
timetable_id: Some(timetable.id),
train_schedule: train_schedule_base.clone(),
}
.into();
let train_schedule_valid = train_schedule
.create(&mut db_pool.get_ok())
.await
.expect("Failed to create train schedule");

let train_schedule_fail: Changeset<TrainSchedule> = TrainScheduleForm {
timetable_id: Some(timetable.id),
train_schedule: TrainScheduleBase {
rolling_stock_name: "fail".to_string(),
start_time: DateTime::from_timestamp(0, 0).unwrap(),
..train_schedule_base.clone()
},
}
.into();

let train_schedule_fail = train_schedule_fail
.create(&mut db_pool.get_ok())
.await
.expect("Failed to create train schedule");

let core = mocked_core_for_project_path(train_schedule_valid.id);
let app = TestAppBuilder::new()
.db_pool(db_pool.clone())
.core_client(core.into())
.build();

// TEST
let request = app.post("/train_schedule/project_path").json(&json!({
"infra_id": small_infra.id,
"electrical_profile_set_id": null,
"ids": vec![train_schedule_fail.id, train_schedule_valid.id],
"path": {
"track_section_ranges": [
{"track_section": "TA1", "begin": 0, "end": 100, "direction": "START_TO_STOP"}
],
"routes": [],
"blocks": []
}
}));
let response: HashMap<i64, PartialProjectPathTrainResult> =
app.fetch(request).assert_status(StatusCode::OK).json_into();

// EXPECT
assert_eq!(response.len(), 1);
assert_eq!(
response[&train_schedule_valid.id].departure_time,
train_schedule_base.start_time
);
}
}
61 changes: 31 additions & 30 deletions editoast/src/views/train_schedule/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ async fn project_path(
let mut trains_hash_values = vec![];
let mut trains_details = vec![];

for (sim, pathfinding_result) in simulations {
for (train, (sim, pathfinding_result)) in izip!(&trains, simulations) {
let track_ranges = match pathfinding_result {
PathfindingResult::Success(PathfindingResultSuccess {
track_section_ranges,
Expand All @@ -221,6 +221,7 @@ async fn project_path(
} = report_train;

let train_details = TrainSimulationDetails {
train_id: train.id,
positions,
times,
signal_critical_positions,
Expand All @@ -242,17 +243,13 @@ async fn project_path(
let cached_projections: Vec<Option<CachedProjectPathTrainResult>> =
valkey_conn.json_get_bulk(&trains_hash_values).await?;

let mut hit_cache: HashMap<i64, CachedProjectPathTrainResult> = HashMap::new();
let mut miss_cache = HashMap::new();
for (train_details, projection, train_id) in izip!(
trains_details,
cached_projections,
trains.iter().map(|t| t.id)
) {
let mut hit_cache = vec![];
let mut miss_cache = vec![];
for (train_details, projection) in izip!(&trains_details, cached_projections) {
if let Some(cached) = projection {
hit_cache.insert(train_id, cached);
hit_cache.push((cached, train_details.train_id));
} else {
miss_cache.insert(train_id, train_details.clone());
miss_cache.push(train_details.clone());
}
}

Expand All @@ -277,47 +274,47 @@ async fn project_path(
let signal_updates = signal_updates?;

// 3. Store the projection in the cache (using pipeline)
let trains_hash_values: HashMap<_, _> = trains
let trains_hash_values: HashMap<_, _> = trains_details
.iter()
.map(|t| t.id)
.map(|t| t.train_id)
.zip(trains_hash_values)
.collect();
let mut new_items = vec![];
for id in miss_cache.keys() {
let hash = &trains_hash_values[id];
for train_id in miss_cache.iter().map(|t| t.train_id) {
let hash = &trains_hash_values[&train_id];
let cached_value = CachedProjectPathTrainResult {
space_time_curves: space_time_curves
.get(id)
.get(&train_id)
.expect("Space time curves not available for train")
.clone(),
signal_updates: signal_updates
.get(id)
.get(&train_id)
.expect("Signal update not available for train")
.clone(),
};
hit_cache.insert(*id, cached_value.clone());
hit_cache.push((cached_value.clone(), train_id));
new_items.push((hash, cached_value));
}
valkey_conn.json_set_bulk(&new_items).await?;

let train_map: HashMap<i64, TrainSchedule> = trains.into_iter().map(|ts| (ts.id, ts)).collect();

// 4.1 Fetch rolling stock length
let mut project_path_result = HashMap::new();
let rolling_stock_length: HashMap<_, _> = rolling_stocks
.into_iter()
.map(|rs| (rs.name, rs.length))
.collect();

// 4.2 Build the projection response
for (id, cached) in hit_cache {
let train = train_map.get(&id).expect("Train not found");
let mut project_path_result = HashMap::new();
for (cached, train_id) in hit_cache {
let train = train_map.get(&train_id).expect("Train not found");
let length = rolling_stock_length
.get(&train.rolling_stock_name)
.expect("Rolling stock length not found");

project_path_result.insert(
id,
train_id,
ProjectPathTrainResult {
departure_time: train.start_time,
rolling_stock_length: (length * 1000.).round() as u64,
Expand All @@ -332,6 +329,7 @@ async fn project_path(
/// Input for the projection of a train schedule on a path
#[derive(Debug, Clone, Hash)]
struct TrainSimulationDetails {
train_id: i64,
positions: Vec<u64>,
times: Vec<u64>,
train_path: Vec<TrackRange>,
Expand All @@ -346,7 +344,7 @@ async fn compute_batch_signal_updates<'a>(
path_track_ranges: &'a Vec<TrackRange>,
path_routes: &'a Vec<Identifier>,
path_blocks: &'a Vec<Identifier>,
trains_details: &'a HashMap<i64, TrainSimulationDetails>,
trains_details: &'a [TrainSimulationDetails],
) -> Result<HashMap<i64, Vec<SignalUpdate>>> {
if trains_details.is_empty() {
return Ok(HashMap::new());
Expand All @@ -359,13 +357,13 @@ async fn compute_batch_signal_updates<'a>(
blocks: path_blocks,
train_simulations: trains_details
.iter()
.map(|(id, details)| {
.map(|detail| {
(
*id,
detail.train_id,
TrainSimulation {
signal_critical_positions: &details.signal_critical_positions,
zone_updates: &details.zone_updates,
simulation_end_time: details.times[details.times.len() - 1],
signal_critical_positions: &detail.signal_critical_positions,
zone_updates: &detail.zone_updates,
simulation_end_time: detail.times[detail.times.len() - 1],
},
)
})
Expand All @@ -377,14 +375,14 @@ async fn compute_batch_signal_updates<'a>(

/// Compute space time curves of a list of train schedules
async fn compute_batch_space_time_curves<'a>(
trains_details: &HashMap<i64, TrainSimulationDetails>,
trains_details: &Vec<TrainSimulationDetails>,
path_projection: &PathProjection<'a>,
) -> HashMap<i64, Vec<SpaceTimeCurve>> {
let mut space_time_curves = HashMap::new();

for (train_id, train_detail) in trains_details {
for train_detail in trains_details {
space_time_curves.insert(
*train_id,
train_detail.train_id,
compute_space_time_curves(train_detail, path_projection),
);
}
Expand Down Expand Up @@ -584,6 +582,7 @@ mod tests {
];

let project_path_input = TrainSimulationDetails {
train_id: 0,
positions,
times,
train_path,
Expand Down Expand Up @@ -618,6 +617,7 @@ mod tests {
];

let project_path_input = TrainSimulationDetails {
train_id: 0,
positions: positions.clone(),
times: times.clone(),
train_path,
Expand Down Expand Up @@ -655,6 +655,7 @@ mod tests {
let path_projection = PathProjection::new(&path);

let project_path_input = TrainSimulationDetails {
train_id: 0,
positions,
times,
train_path,
Expand Down

0 comments on commit 61dde6b

Please sign in to comment.