Skip to content

Commit

Permalink
core: change stdcm heuristic to typealias
Browse files Browse the repository at this point in the history
  • Loading branch information
Erashin committed Jun 26, 2024
1 parent d59bbae commit 96b8f78
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 25 deletions.
8 changes: 0 additions & 8 deletions core/src/main/kotlin/fr/sncf/osrd/graph/Interfaces.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@ import fr.sncf.osrd.utils.indexing.StaticIdx
import fr.sncf.osrd.utils.units.Length
import fr.sncf.osrd.utils.units.Offset

/**
* This interface defines a function that can be used as a heuristic for an A* pathfinding. It takes
* a node as input, and returns an estimation of the remaining distance.
*/
fun interface STDCMAStarHeuristic<NodeT> {
fun apply(node: NodeT): Double
}

/**
* This interface defines a function that can be used as a heuristic for an A* pathfinding. It takes
* an edge and an offset on this edge as inputs, and returns an estimation of the remaining
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/STDCMHeuristic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fr.sncf.osrd.stdcm
import fr.sncf.osrd.api.pathfinding.makePathProps
import fr.sncf.osrd.envelope_sim.PhysicsRollingStock
import fr.sncf.osrd.envelope_sim_infra.MRSP
import fr.sncf.osrd.graph.STDCMAStarHeuristic
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.BlockId
import fr.sncf.osrd.sim_infra.api.BlockInfra
Expand Down Expand Up @@ -45,7 +44,17 @@ private data class PendingBlock(
}
}

/** Runs all the pre-processing and initialize the STDCM A* heuristic. */
/**
* This typealias defines a function that can be used as a heuristic for an A* pathfinding. It takes
* a node as input, and returns an estimation of the remaining time needed to get to the end.
*/
typealias STDCMAStarHeuristic<NodeT> = (NodeT) -> Double

fun <NodeT> List<STDCMAStarHeuristic<NodeT>>.apply(node: NodeT, nbPassedSteps: Int): Double {
return this[nbPassedSteps](node)
}

/** Runs all the pre-processing and initializes the STDCM A* heuristic. */
fun makeSTDCMHeuristics(
blockInfra: BlockInfra,
rawInfra: RawInfra,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import fr.sncf.osrd.graph.*
import fr.sncf.osrd.reporting.exceptions.ErrorType
import fr.sncf.osrd.reporting.exceptions.OSRDError
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.stdcm.STDCMResult
import fr.sncf.osrd.stdcm.STDCMStep
import fr.sncf.osrd.stdcm.*
import fr.sncf.osrd.stdcm.infra_exploration.initInfraExplorerWithEnvelope
import fr.sncf.osrd.stdcm.makeSTDCMHeuristics
import fr.sncf.osrd.stdcm.preprocessing.interfaces.BlockAvailabilityInterface
import fr.sncf.osrd.train.RollingStock
import fr.sncf.osrd.utils.units.Offset
Expand Down Expand Up @@ -144,7 +142,7 @@ class STDCMPathfinding(
private fun findPathImpl(): Result? {
val queue = PriorityQueue<STDCMNode>()
for (location in starts) {
location.remainingTimeEstimation = remainingTimeEstimators!![0].apply(location)
location.remainingTimeEstimation = remainingTimeEstimators!!.apply(location, 0)
queue.add(location)
}
val start = Instant.now()
Expand All @@ -164,7 +162,7 @@ class STDCMPathfinding(
for (neighbor in neighbors) {
if (neighbor.waypointIndex < remainingTimeEstimators!!.size)
neighbor.remainingTimeEstimation =
remainingTimeEstimators!![neighbor.waypointIndex].apply(neighbor)
remainingTimeEstimators!!.apply(neighbor, neighbor.waypointIndex)
queue.add(neighbor)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package fr.sncf.osrd.stdcm.preprocessing

import fr.sncf.osrd.envelope_sim.SimpleRollingStock
import fr.sncf.osrd.graph.PathfindingEdgeLocationId
import fr.sncf.osrd.graph.STDCMAStarHeuristic
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.BlockId
import fr.sncf.osrd.stdcm.STDCMAStarHeuristic
import fr.sncf.osrd.stdcm.STDCMStep
import fr.sncf.osrd.stdcm.apply
import fr.sncf.osrd.stdcm.graph.STDCMEdge
import fr.sncf.osrd.stdcm.graph.STDCMNode
import fr.sncf.osrd.stdcm.infra_exploration.initInfraExplorerWithEnvelope
Expand Down Expand Up @@ -77,28 +78,28 @@ class STDCMHeuristicTests {

assertEquals(
400.0 - 50.0,
getLocationRemainingTime(infra, blocks[0], 50.meters, heuristics[0])
getLocationRemainingTime(infra, blocks[0], 50.meters, 0, heuristics)
)
assertEquals(
400.0 - 85.0,
getLocationRemainingTime(infra, blocks[0], 85.meters, heuristics[0])
getLocationRemainingTime(infra, blocks[0], 85.meters, 0, heuristics)
)
assertEquals(
400.0 - 100.0 - 25.0,
getLocationRemainingTime(infra, blocks[1], 25.meters, heuristics[1])
getLocationRemainingTime(infra, blocks[1], 25.meters, 1, heuristics)
)
assertEquals(
400.0 - 100.0 - 75.0,
getLocationRemainingTime(infra, blocks[1], 75.meters, heuristics[2])
getLocationRemainingTime(infra, blocks[1], 75.meters, 2, heuristics)
)
assertEquals(
400.0 - 200.0,
getLocationRemainingTime(infra, blocks[2], 0.meters, heuristics[3])
getLocationRemainingTime(infra, blocks[2], 0.meters, 3, heuristics)
)
assertEquals(0.0, getLocationRemainingTime(infra, blocks[3], null, heuristics[3]))
assertEquals(0.0, getLocationRemainingTime(infra, blocks[3], null, 3, heuristics))
assertEquals(
Double.POSITIVE_INFINITY,
getLocationRemainingTime(infra, blocks[3], 0.meters, heuristics[0])
getLocationRemainingTime(infra, blocks[3], 0.meters, 0, heuristics)
)
}

Expand All @@ -110,7 +111,8 @@ class STDCMHeuristicTests {
infra: DummyInfra,
block: BlockId,
nodeOffsetOnEdge: Distance?,
heuristic: STDCMAStarHeuristic<STDCMNode>
nbPassedSteps: Int,
heuristics: List<STDCMAStarHeuristic<STDCMNode>>
): Double {
val explorer =
initInfraExplorerWithEnvelope(
Expand Down Expand Up @@ -142,6 +144,6 @@ class STDCMHeuristicTests {
var locationOnEdge: Offset<Block>? = null
if (nodeOffsetOnEdge != null) locationOnEdge = Offset(nodeOffsetOnEdge)
val node = STDCMNode(0.0, 0.0, explorer, 0.0, 0.0, defaultEdge, 0, locationOnEdge, null)
return heuristic.apply(node)
return heuristics.apply(node, nbPassedSteps)
}
}

0 comments on commit 96b8f78

Please sign in to comment.