From dfc9e4cb8e1a13f0652928629d4ab81d59bc4420 Mon Sep 17 00:00:00 2001 From: Florian Amsallem Date: Fri, 26 Jan 2024 17:04:10 +0100 Subject: [PATCH] core: fix mrsp construction --- .../api/tracks/undirected/SpeedLimits.java | 22 ++++++++++++++----- .../tracks/undirected/RJSParsingTests.java | 15 +++++++------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/fr/sncf/osrd/infra/api/tracks/undirected/SpeedLimits.java b/core/src/main/java/fr/sncf/osrd/infra/api/tracks/undirected/SpeedLimits.java index 79c114817f2..81fef6c23dd 100644 --- a/core/src/main/java/fr/sncf/osrd/infra/api/tracks/undirected/SpeedLimits.java +++ b/core/src/main/java/fr/sncf/osrd/infra/api/tracks/undirected/SpeedLimits.java @@ -3,9 +3,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import fr.sncf.osrd.railjson.schema.infra.trackranges.RJSSpeedSection; -import fr.sncf.osrd.sim_infra.impl.SpeedSection; -import fr.sncf.osrd.utils.units.Speed; -import java.util.HashMap; +import java.util.Objects; public final class SpeedLimits { @@ -51,12 +49,26 @@ public static SpeedLimits merge(SpeedLimits a, SpeedLimits b) { var categories = Sets.union(a.speedLimitByTag.keySet(), b.speedLimitByTag.keySet()); var builder = ImmutableMap.builder(); for (var category : categories) { - Double speedA = a.speedLimitByTag.getOrDefault(category, Double.POSITIVE_INFINITY); - Double speedB = b.speedLimitByTag.getOrDefault(category, Double.POSITIVE_INFINITY); + Double speedA = a.speedLimitByTag.getOrDefault(category, a.defaultSpeedLimit); + Double speedB = b.speedLimitByTag.getOrDefault(category, b.defaultSpeedLimit); assert speedA != null && speedB != null; var speed = Double.min(speedA, speedB); builder.put(category, speed); } return new SpeedLimits(defaultSpeed, builder.build()); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SpeedLimits other = (SpeedLimits) o; + return Double.compare(defaultSpeedLimit, other.defaultSpeedLimit) == 0 + && Objects.equals(speedLimitByTag, other.speedLimitByTag); + } + + @Override + public int hashCode() { + return Objects.hash(defaultSpeedLimit, speedLimitByTag); + } } diff --git a/core/src/test/java/fr/sncf/osrd/infra/tracks/undirected/RJSParsingTests.java b/core/src/test/java/fr/sncf/osrd/infra/tracks/undirected/RJSParsingTests.java index 478bfbf0f25..60dd727ae40 100644 --- a/core/src/test/java/fr/sncf/osrd/infra/tracks/undirected/RJSParsingTests.java +++ b/core/src/test/java/fr/sncf/osrd/infra/tracks/undirected/RJSParsingTests.java @@ -67,7 +67,7 @@ public void testOverlappingSpeedSections() throws Exception { var rjsInfra = Helpers.getExampleInfra("one_line/infra.json"); var track = rjsInfra.trackSections.iterator().next(); rjsInfra.speedSections = List.of( - new RJSSpeedSection("id", 42, Map.of( + new RJSSpeedSection("id", 27, Map.of( "category1", 10., "category2", 20. ), List.of(new RJSApplicableDirectionsTrackRange( @@ -78,7 +78,7 @@ public void testOverlappingSpeedSections() throws Exception { ))), new RJSSpeedSection("id", 45, Map.of( "category2", 12., - "category3", 17. + "category3", 30. ), List.of(new RJSApplicableDirectionsTrackRange( track.id, ApplicableDirection.START_TO_STOP, @@ -88,21 +88,22 @@ public void testOverlappingSpeedSections() throws Exception { ); var parsedInfra = UndirectedInfraBuilder.parseInfra(rjsInfra, new DiagnosticRecorderImpl(true)); var expected = TreeRangeMap.create(); - expected.put(Range.closed(0., 5.), new SpeedLimits(42, ImmutableMap.of( + expected.put(Range.closed(0., 5.), new SpeedLimits(27, ImmutableMap.of( "category1", 10., "category2", 20. ))); - expected.put(Range.closed(5., 10.), new SpeedLimits(42, ImmutableMap.of( + expected.put(Range.closed(5., 10.), new SpeedLimits(27, ImmutableMap.of( "category1", 10., "category2", 12., - "category3", 17. + "category3", 27. ))); expected.put(Range.closed(10., 15.), new SpeedLimits(45, ImmutableMap.of( "category2", 12., - "category3", 17. + "category3", 30. ))); + expected.put(Range.closed(15., 1000.), new SpeedLimits(Double.POSITIVE_INFINITY, ImmutableMap.of())); var speedLimits = parsedInfra.getTrackSection(track.id).getSpeedSections().get(Direction.FORWARD); - equalsIgnoringTransitions(expected, speedLimits); + assertTrue(equalsIgnoringTransitions(expected, speedLimits)); } @Test