Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible symbolic backend support via ParametricMCPs.SymbolicUtils #21

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ version = "0.1.1"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"

[compat]
BlockArrays = "0.16"
ChainRulesCore = "1"
ParametricMCPs = "0.1.5"
Symbolics = "4,5"
ParametricMCPs = "0.1.14"
TrajectoryGamesBase = "0.3.6"
julia = "1.7"
4 changes: 2 additions & 2 deletions src/MCPTrajectoryGameSolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ using TrajectoryGamesBase:
unflatten_trajectory,
unstack_trajectory

using Symbolics: Symbolics
using ParametricMCPs: ParametricMCPs
using ParametricMCPs: ParametricMCPs, SymbolicUtils

using BlockArrays: BlockArrays, mortar, blocks, eachblock
using ChainRulesCore: ChainRulesCore

Expand Down
66 changes: 30 additions & 36 deletions src/solver_setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
context_dimension = 0,
compute_sensitivities = true,
parametric_mcp_options = (;),
symbolic_backend = SymbolicUtils.SymbolicsBackend(),
)
dimensions = let
state_blocks =
Expand All @@ -22,28 +23,19 @@
(; state_blocks, state, control_blocks, control, context = context_dimension, horizon)
end

initial_state_symbolic = let
Symbolics.@variables(x0[1:(dimensions.state)]) |>
only |>
scalarize |>
initial_state_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :x0, dimensions.state) |>
to_blockvector(dimensions.state_blocks)
end

xs_symbolic = let
Symbolics.@variables(X[1:(dimensions.state * horizon)]) |>
only |>
scalarize |>
xs_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :X, dimensions.state * horizon) |>
to_vector_of_blockvectors(dimensions.state_blocks)
end

us_symbolic = let
Symbolics.@variables(U[1:(dimensions.control * horizon)]) |>
only |>
scalarize |>
us_symbolic =
SymbolicUtils.make_variables(symbolic_backend, :U, dimensions.control * horizon) |>
to_vector_of_blockvectors(dimensions.control_blocks)
end

context_symbolic = Symbolics.@variables(context[1:context_dimension]) |> only |> scalarize
context_symbolic = SymbolicUtils.make_variables(symbolic_backend, :context, context_dimension)

cost_per_player_symbolic = game.cost(xs_symbolic, us_symbolic, context_symbolic)

Expand Down Expand Up @@ -90,7 +82,8 @@
end

if isnothing(game.coupling_constraints)
coupling_constraints_symbolic = Symbolics.Num[]
coupling_constraints_symbolic =

Check warning on line 85 in src/solver_setup.jl

View check run for this annotation

Codecov / codecov/patch

src/solver_setup.jl#L85

Added line #L85 was not covered by tests
SymbolicUtils.make_variables(symbolic_backend, :coupling_constraints, 0)
else
# Note: we don't constrain the first state because we have no control authority over that anyway
coupling_constraints_symbolic =
Expand All @@ -100,19 +93,30 @@
# set up the duals for all constraints
# private constraints
μ_private_symbolic =
Symbolics.@variables(μ[1:length(equality_constraints_symbolic)]) |> only |> scalarize
λ_private_symbolic =
Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |>
only |>
scalarize
SymbolicUtils.make_variables(symbolic_backend, :μ, length(equality_constraints_symbolic))

#λ_private_symbolic =
# Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |>
# only |>
# scalarize
λ_private_symbolic = SymbolicUtils.make_variables(
symbolic_backend,
:λ_private,
length(inequality_constraints_symoblic),
)

# shared constraints
λ_shared_symbolic =
Symbolics.@variables(λ_shared[1:length(coupling_constraints_symbolic)]) |> only |> scalarize
λ_shared_symbolic = SymbolicUtils.make_variables(
symbolic_backend,
:λ_shared,
length(coupling_constraints_symbolic),
)

# multiplier scaling per player as a runtime parameter
# TODO: technically, we could have this scaling for *every* element of the constraint and
# actually every constraint but for now let's keep it simple
shared_constraint_premultipliers_symbolic =
Symbolics.@variables(γ_scaling[1:num_players(game)]) |> only |> scalarize
SymbolicUtils.make_variables(symbolic_backend, :γ_scaling, num_players(game))

private_variables_per_player_symbolic =
flatten_trajetory_per_player((; xs = xs_symbolic, us = us_symbolic))
Expand All @@ -129,7 +133,7 @@
λ_private_symbolic' * inequality_constraints_symoblic -
λ_shared_symbolic' * coupling_constraints_symbolic * γ_ii

Symbolics.gradient(L_ii, τ_ii)
SymbolicUtils.gradient(L_ii, τ_ii)
end

# set up the full KKT system as an MCP
Expand Down Expand Up @@ -181,13 +185,3 @@
"""
    compose_parameter_vector(; initial_state, context, shared_constraint_premultipliers)

Concatenate the three keyword arguments vertically, in the order `initial_state`, `context`,
`shared_constraint_premultipliers`, and return the result.
"""
function compose_parameter_vector(; initial_state, context, shared_constraint_premultipliers)
    # `[a; b; c]` is syntax for `vcat(a, b, c)`; spell the call out explicitly.
    return vcat(initial_state, context, shared_constraint_premultipliers)
end

"""
Like `Symbolics.scalarize`, but robustly handles empty arrays: when `length(num) == 0`, an
empty `Symbolics.Num[]` is returned instead of forwarding to `Symbolics.scalarize`.
"""
function scalarize(num)
if length(num) == 0
return Symbolics.Num[]
end
Symbolics.scalarize(num)
end
87 changes: 51 additions & 36 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ using Symbolics: Symbolics

include("Demo.jl")

function isfeasible(game::TrajectoryGamesBase.TrajectoryGame, trajectory; tol=1e-4)
function isfeasible(game::TrajectoryGamesBase.TrajectoryGame, trajectory; tol = 1e-4)
isfeasible(game.dynamics, trajectory; tol) &&
isfeasible(game.env, trajectory; tol) &&
all(game.coupling_constraints(trajectory.xs, trajectory.us) .>= 0 - tol)
end

function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; tol=1e-4)
function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; tol = 1e-4)
dynamics_steps_consistent = all(
map(2:length(trajectory.xs)) do t
residual =
trajectory.xs[t] - dynamics(trajectory.xs[t-1], trajectory.us[t-1], t - 1)
trajectory.xs[t] - dynamics(trajectory.xs[t - 1], trajectory.us[t - 1], t - 1)
sum(abs, residual) < tol
end,
)
Expand All @@ -45,7 +45,7 @@ function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory;
dynamics_steps_consistent && state_bounds_feasible && control_bounds_feasible
end

function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol=1e-4)
function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol = 1e-4)
trajectory_per_player = MCPTrajectoryGameSolver.unstack_trajectory(trajectory)

map(enumerate(trajectory_per_player)) do (ii, trajectory)
Expand All @@ -68,20 +68,20 @@ function input_sanity(; solver, game, initial_state, context)
solver,
game,
initial_state;
context=context_with_wrong_size,
context = context_with_wrong_size,
)
multipliers_despite_no_shared_constraints = [1]
@test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state;
context,
shared_constraint_premultipliers=multipliers_despite_no_shared_constraints,
shared_constraint_premultipliers = multipliers_despite_no_shared_constraints,
)
end
end

function forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy, tol=1e-4)
function forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy, tol = 1e-4)
@testset "forwardpass sanity" begin
nash_trajectory =
TrajectoryGamesBase.rollout(game.dynamics, strategy, initial_state, horizon)
Expand Down Expand Up @@ -112,8 +112,8 @@ function backward_pass_sanity(;
solver,
game,
initial_state,
rng=Random.MersenneTwister(1),
θs=[randn(rng, 4) for _ in 1:10],
rng = Random.MersenneTwister(1),
θs = [randn(rng, 4) for _ in 1:10],
)
@testset "backward pass sanity" begin
function loss(θ)
Expand All @@ -122,7 +122,7 @@ function backward_pass_sanity(;
solver,
game,
initial_state;
context=θ,
context = θ,
)

sum(strategy.substrategies) do substrategy
Expand All @@ -136,7 +136,7 @@ function backward_pass_sanity(;
for θ in θs
∇_zygote = Zygote.gradient(loss, θ) |> only
∇_finitediff = FiniteDiff.finite_difference_gradient(loss, θ)
@test isapprox(∇_zygote, ∇_finitediff; atol=1e-4)
@test isapprox(∇_zygote, ∇_finitediff; atol = 1e-4)
end
end
end
Expand All @@ -147,34 +147,49 @@ function main()
context = [0.0, 1.0, 0.0, 1.0]
initial_state = mortar([[1.0, 0, 0, 0], [-1.0, 0, 0, 0]])

local solver, solver_parallel

@testset "Tests" begin
@testset "solver setup" begin
solver =
MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=length(context))
# exercise some inner solver options...
solver_parallel = MCPTrajectoryGameSolver.Solver(
game,
horizon;
context_dimension=length(context),
parametric_mcp_options=(; parallel=Symbolics.ShardedForm()),
)
end
for options in [
(; symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.SymbolicsBackend(),),
(;
symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.SymbolicsBackend(),
parametric_mcp_options = (;
backend_options = (; parallel = Symbolics.ShardedForm())
),
),
(;
symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.FastDifferentiationBackend(),
),
]
local solver

@testset "$options" begin
@testset "solver setup" begin
solver = nothing
solver = MCPTrajectoryGameSolver.Solver(
game,
horizon;
context_dimension = length(context),
options...,
)
end

@testset "solve" begin
for solver in [solver, solver_parallel]
input_sanity(; solver, game, initial_state, context)
strategy =
TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context)
forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy)
backward_pass_sanity(; solver, game, initial_state)
end
end
@testset "solve" begin
input_sanity(; solver, game, initial_state, context)
strategy = TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state;
context,
)
forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy)
backward_pass_sanity(; solver, game, initial_state)
end

@testset "integration test" begin
Demo.demo_model_predictive_game_play()
Demo.demo_inverse_game()
@testset "integration test" begin
Demo.demo_model_predictive_game_play()
Demo.demo_inverse_game()
end
end
end
end
end
Expand Down
Loading