Skip to content

Commit

Permalink
core: use nodes in stdcm priority queue and replace weight by clear c…
Browse files Browse the repository at this point in the history
…omparison
  • Loading branch information
Erashin committed Jun 27, 2024
1 parent 22d010d commit 5acf0be
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 149 deletions.
11 changes: 0 additions & 11 deletions core/src/main/kotlin/fr/sncf/osrd/graph/Interfaces.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,7 @@ fun interface TargetsOnEdge<EdgeT, OffsetType> {
fun apply(edge: EdgeT): Collection<Pathfinding.EdgeLocation<EdgeT, OffsetType>>
}

/** Alternate way to define the cost: returns the absolute cost of a location on an edge */
fun interface TotalCostUntilEdgeLocation<EdgeT, OffsetType> {
fun apply(edgeLocation: Pathfinding.EdgeLocation<EdgeT, OffsetType>): Double
}

// Type aliases to avoid repeating `StaticIdx<T>, T` when edge types are static idx
typealias AStarHeuristicId<T> = AStarHeuristic<StaticIdx<T>, T>

typealias EdgeToLengthId<T> = EdgeToLength<StaticIdx<T>, T>

typealias PathfindingConstraint<T> = EdgeToRanges<StaticIdx<T>, T>

typealias TargetsOnEdgeId<T> = TargetsOnEdge<StaticIdx<T>, T>

typealias TotalCostUntilEdgeLocationId<T> = TotalCostUntilEdgeLocation<StaticIdx<T>, T>
43 changes: 26 additions & 17 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/STDCMHeuristic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ package fr.sncf.osrd.stdcm
import fr.sncf.osrd.api.pathfinding.makePathProps
import fr.sncf.osrd.envelope_sim.PhysicsRollingStock
import fr.sncf.osrd.envelope_sim_infra.MRSP
import fr.sncf.osrd.graph.AStarHeuristic
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.sim_infra.api.BlockId
import fr.sncf.osrd.sim_infra.api.BlockInfra
import fr.sncf.osrd.sim_infra.api.RawInfra
import fr.sncf.osrd.sim_infra.utils.getBlockEntry
import fr.sncf.osrd.stdcm.graph.STDCMEdge
import fr.sncf.osrd.stdcm.graph.STDCMNode
import fr.sncf.osrd.utils.indexing.StaticIdx
import fr.sncf.osrd.utils.units.Offset
import fr.sncf.osrd.utils.units.meters
Expand Down Expand Up @@ -45,15 +44,24 @@ private data class PendingBlock(
}
}

/** Runs all the pre-processing and initialize the STDCM A* heuristic. */
/**
* This typealias defines a function that can be used as a heuristic for an A* pathfinding. It takes
* a node as input, and returns an estimation of the remaining time needed to get to the end.
*/
typealias STDCMAStarHeuristic<NodeT> = (NodeT) -> Double

fun <NodeT> List<STDCMAStarHeuristic<NodeT>>.apply(node: NodeT, nbPassedSteps: Int): Double {
return this[nbPassedSteps](node)
}

/** Runs all the pre-processing and initializes the STDCM A* heuristic. */
fun makeSTDCMHeuristics(
blockInfra: BlockInfra,
rawInfra: RawInfra,
steps: List<STDCMStep>,
maxRunningTime: Double,
rollingStock: PhysicsRollingStock,
maxDepartureDelay: Double,
): List<AStarHeuristic<STDCMEdge, STDCMEdge>> {
): List<STDCMAStarHeuristic<STDCMNode>> {
logger.info("Start building STDCM heuristic...")
// One map per number of reached pathfinding step
val maps = mutableListOf<MutableMap<BlockId, Double>>()
Expand All @@ -76,23 +84,24 @@ fun makeSTDCMHeuristics(
}
}

// We build one function (`AStarHeuristic`) per number of reached step
val res = mutableListOf<AStarHeuristic<STDCMEdge, STDCMEdge>>()
// We build one function (`STDCMAStarHeuristic`) per number of reached step
val res = mutableListOf<STDCMAStarHeuristic<STDCMNode>>()
for (nPassedSteps in maps.indices) {
res.add { edge, offset ->
res.add { node ->
// We need to iterate through the previous maps,
// to handle cases where several steps are on the same block
for (i in (0..nPassedSteps).reversed()) {
val cachedRemainingDistance = maps[i][edge.block] ?: continue
val blockOffset = edge.envelopeStartOffset + offset.distance
val cachedRemainingTime = maps[i][node.previousEdge.block] ?: continue
val remainingTime =
cachedRemainingDistance -
getBlockTime(rawInfra, blockInfra, edge.block, rollingStock, blockOffset)

// Accounts for the math in the `costToEdgeLocation`.
// We need the resulting value to be in the same referential as the cost
// used as STDCM cost function, which scales the running time
return@add remainingTime * maxDepartureDelay
cachedRemainingTime -
getBlockTime(
rawInfra,
blockInfra,
node.previousEdge.block,
rollingStock,
node.locationOnEdge
)
return@add remainingTime
}
return@add Double.POSITIVE_INFINITY
}
Expand Down
39 changes: 1 addition & 38 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMEdge.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import fr.sncf.osrd.utils.units.Length
import fr.sncf.osrd.utils.units.Offset
import fr.sncf.osrd.utils.units.meters
import java.lang.Double.isNaN
import java.util.*

data class STDCMEdge(
val infraExplorer:
Expand Down Expand Up @@ -44,49 +43,13 @@ data class STDCMEdge(
val totalTime:
Double, // How long it takes to go from the beginning to the end of the block, taking the
// standard allowance into account
var weight: Double? = null // Weight (total distance from start + estimation to end) of the edge
) : Comparable<STDCMEdge> {
) {
val block = infraExplorer.getCurrentBlock()

init {
assert(!isNaN(timeStart)) { "STDCM edge starts at NaN time" }
}

override fun equals(other: Any?): Boolean {
if (other == null || other.javaClass != STDCMEdge::class.java) return false
val otherEdge = other as STDCMEdge
return if (
infraExplorer.getLastEdgeIdentifier() != otherEdge.infraExplorer.getLastEdgeIdentifier()
)
false
else
minuteTimeStart == otherEdge.minuteTimeStart &&
envelopeStartOffset == otherEdge.envelopeStartOffset

// We need to consider that the edges aren't equal if the times are different,
// but if we do it "naively" we end up visiting the same places a near-infinite number of
// times.
// We handle it by discretizing the start time of the edge: we round the time down to the
// minute and compare
// this value.
}

override fun compareTo(other: STDCMEdge): Int {
return if (weight != other.weight) weight!!.compareTo(other.weight!!)
else {
// If the weights are equal, we prioritize the highest number of reached targets
other.waypointIndex - waypointIndex
}
}

override fun hashCode(): Int {
return Objects.hash(
infraExplorer.getLastEdgeIdentifier(),
minuteTimeStart,
envelopeStartOffset
)
}

/** Returns the node at the end of this edge */
fun getEdgeEnd(graph: STDCMGraph): STDCMNode {
var newWaypointIndex = waypointIndex
Expand Down
32 changes: 30 additions & 2 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fr.sncf.osrd.stdcm.graph
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.stdcm.infra_exploration.InfraExplorerWithEnvelope
import fr.sncf.osrd.utils.units.Offset
import kotlin.math.abs

data class STDCMNode(
val time: Double, // Time at the transition of the edge
Expand All @@ -18,8 +19,31 @@ data class STDCMNode(
Offset<
Block
>?, // Position on a block, if this node isn't on the transition between blocks (stop)
val stopDuration: Double? // When the node is a stop, how long the train remains here
) {
val stopDuration: Double?, // When the node is a stop, how long the train remains here
var remainingTimeEstimation: Double =
0.0, // Estimation of the min time it takes to reach the end from this node
) : Comparable<STDCMNode> {

/**
* Defines the estimated better path between 2 nodes, in terms of total run time, then departure
* time, then number of reached targets. If the result is negative, the current node has a
* better path, and should be explored first. This method allows us to order the nodes in a
* priority queue, from the best path to the worst path. We then explore them in that order.
*/
override fun compareTo(other: STDCMNode): Int {
val runTimeEstimation = getCurrentRunningTime() + remainingTimeEstimation
val otherRunTimeEstimation = other.getCurrentRunningTime() + other.remainingTimeEstimation
// Firstly, minimize the total run time: highest priority node takes the least time to
// complete the path
return if (abs(runTimeEstimation - otherRunTimeEstimation) >= 1e-3)
runTimeEstimation.compareTo(otherRunTimeEstimation)
// If not, take the train which departs first, as it is the closest to the demanded
// departure time
else if (time != other.time) time.compareTo(other.time)
// In the end, prioritize the highest number of reached targets
else other.waypointIndex - waypointIndex
}

override fun toString(): String {
// Not everything is included, otherwise it may recurse a lot over edges / nodes
return String.format(
Expand All @@ -30,4 +54,8 @@ data class STDCMNode(
waypointIndex
)
}

fun getCurrentRunningTime(): Double {
return time - totalPrevAddedDelay
}
}
80 changes: 25 additions & 55 deletions core/src/main/kotlin/fr/sncf/osrd/stdcm/graph/STDCMPathfinding.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import fr.sncf.osrd.graph.*
import fr.sncf.osrd.reporting.exceptions.ErrorType
import fr.sncf.osrd.reporting.exceptions.OSRDError
import fr.sncf.osrd.sim_infra.api.Block
import fr.sncf.osrd.stdcm.STDCMResult
import fr.sncf.osrd.stdcm.STDCMStep
import fr.sncf.osrd.stdcm.*
import fr.sncf.osrd.stdcm.infra_exploration.initInfraExplorerWithEnvelope
import fr.sncf.osrd.stdcm.makeSTDCMHeuristics
import fr.sncf.osrd.stdcm.preprocessing.interfaces.BlockAvailabilityInterface
import fr.sncf.osrd.train.RollingStock
import fr.sncf.osrd.utils.units.Offset
Expand Down Expand Up @@ -73,8 +71,8 @@ class STDCMPathfinding(
private val pathfindingTimeout: Double = 120.0
) {

private var estimateRemainingDistance: List<AStarHeuristic<STDCMEdge, STDCMEdge>>? = ArrayList()
private var starts: Set<STDCMEdge> = HashSet()
private var remainingTimeEstimators: List<STDCMAStarHeuristic<STDCMNode>>? = ArrayList()
private var starts: Set<STDCMNode> = HashSet()

var graph: STDCMGraph =
STDCMGraph(
Expand All @@ -94,14 +92,13 @@ class STDCMPathfinding(
assert(steps.size >= 2) { "Not enough steps have been set to find a path" }

// Initialize the A* heuristic
estimateRemainingDistance =
remainingTimeEstimators =
makeSTDCMHeuristics(
fullInfra.blockInfra,
fullInfra.rawInfra,
steps,
maxRunTime,
rollingStock,
maxDepartureDelay
rollingStock
)

val constraints =
Expand Down Expand Up @@ -143,46 +140,37 @@ class STDCMPathfinding(
}

private fun findPathImpl(): Result? {
val queue = PriorityQueue<STDCMEdge>()
val queue = PriorityQueue<STDCMNode>()
for (location in starts) {
val totalCostUntilEdge = computeTotalCostUntilEdge(location)
val distanceLeftEstimation =
estimateRemainingDistance!![0].apply(location, location.length)
location.weight = distanceLeftEstimation + totalCostUntilEdge
location.remainingTimeEstimation = remainingTimeEstimators!!.apply(location, 0)
queue.add(location)
}
val start = Instant.now()
while (true) {
if (Duration.between(start, Instant.now()).toSeconds() >= pathfindingTimeout)
throw OSRDError(ErrorType.PathfindingTimeoutError)
val edge = queue.poll() ?: return null
if (edge.weight!!.isInfinite()) {
// TODO: filter with max running time, can't be done with abstract weight
val endNode = queue.poll() ?: return null
if (endNode.getCurrentRunningTime() + endNode.remainingTimeEstimation > maxRunTime)
return null
}
// TODO: we mostly reason in terms of endNode, we should probably change the queue.
val endNode = graph.getEdgeEnd(edge)
if (endNode.waypointIndex >= graph.steps.size - 1) {
return buildResult(edge)
return buildResult(endNode)
}
val neighbors = graph.getAdjacentEdges(endNode)
val neighbors = getAdjacentNodes(endNode)
for (neighbor in neighbors) {
val totalCostUntilEdge = computeTotalCostUntilEdge(neighbor)
var distanceLeftEstimation = 0.0
if (neighbor.waypointIndex < estimateRemainingDistance!!.size)
distanceLeftEstimation =
estimateRemainingDistance!![neighbor.waypointIndex].apply(
neighbor,
neighbor.length
)
neighbor.weight = totalCostUntilEdge + distanceLeftEstimation
if (neighbor.waypointIndex < remainingTimeEstimators!!.size)
neighbor.remainingTimeEstimation =
remainingTimeEstimators!!.apply(neighbor, neighbor.waypointIndex)
queue.add(neighbor)
}
}
}

private fun buildResult(edge: STDCMEdge): Result {
var mutLastEdge: STDCMEdge? = edge
private fun getAdjacentNodes(node: STDCMNode): Collection<STDCMNode> {
return graph.getAdjacentEdges(node).map { it.getEdgeEnd(graph) }
}

private fun buildResult(node: STDCMNode): Result {
var mutLastEdge: STDCMEdge? = node.previousEdge
val edges = ArrayDeque<STDCMEdge>()

while (mutLastEdge != null) {
Expand Down Expand Up @@ -222,26 +210,6 @@ class STDCMPathfinding(
return res
}

/**
* Compute the total cost of a path (in s) to an edge location This estimation of the total cost
* is used to compare paths in the pathfinding algorithm. We select the shortest path (in
* duration), and for 2 paths with the same duration, we select the earliest one. The path
* weight which takes into account the total duration of the path and the time shift at the
* departure (with different weights): path_duration * maxDepartureDelay + departure_time_shift.
*
* <br></br> EXAMPLE Let's assume we are trying to find a train between 9am and 10am. The
* maxDepartureDelay is 1 hour (3600s). Let's assume we have found two possible trains:
* - the first one leaves at 9:59 and lasts for 20:00 min.
* - the second one leaves at 9:00 and lasts for 20:01 min. As we are looking for the fastest
* train, the first train should have the lightest weight, which is the case with the formula
* above.
*/
private fun computeTotalCostUntilEdge(edge: STDCMEdge): Double {
val timeEnd = edge.getApproximateTimeAtLocation(edge.length)
val pathDuration = timeEnd - edge.totalDepartureTimeShift
return pathDuration * maxDepartureDelay + edge.totalDepartureTimeShift
}

/** Converts locations on a block id into a location on a STDCMGraph.Edge. */
private fun convertLocations(
graph: STDCMGraph,
Expand All @@ -251,8 +219,8 @@ class STDCMPathfinding(
rollingStock: RollingStock,
stops: List<Collection<PathfindingEdgeLocationId<Block>>> = listOf(),
constraints: List<PathfindingConstraint<Block>>
): Set<STDCMEdge> {
val res = HashSet<STDCMEdge>()
): Set<STDCMNode> {
val res = HashSet<STDCMNode>()

for (location in locations) {
val infraExplorers =
Expand All @@ -265,7 +233,9 @@ class STDCMPathfinding(
.setStartOffset(location.offset)
.setPrevMaximumAddedDelay(maxDepartureDelay)
.makeAllEdges()
for (edge in edges) res.add(edge)
for (edge in edges) {
res.add(edge.getEdgeEnd(graph))
}
}
}
return res
Expand Down
Loading

0 comments on commit 5acf0be

Please sign in to comment.