diff --git a/Project.toml b/Project.toml index 42641ed..b804185 100644 --- a/Project.toml +++ b/Project.toml @@ -7,13 +7,11 @@ version = "0.1.1" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad" -Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574" [compat] BlockArrays = "0.16" ChainRulesCore = "1" -ParametricMCPs = "0.1.5" -Symbolics = "4,5" +ParametricMCPs = "0.1.14" TrajectoryGamesBase = "0.3.6" julia = "1.7" diff --git a/src/MCPTrajectoryGameSolver.jl b/src/MCPTrajectoryGameSolver.jl index b497cd1..1518957 100644 --- a/src/MCPTrajectoryGameSolver.jl +++ b/src/MCPTrajectoryGameSolver.jl @@ -18,8 +18,8 @@ using TrajectoryGamesBase: unflatten_trajectory, unstack_trajectory -using Symbolics: Symbolics -using ParametricMCPs: ParametricMCPs +using ParametricMCPs: ParametricMCPs, SymbolicUtils + using BlockArrays: BlockArrays, mortar, blocks, eachblock using ChainRulesCore: ChainRulesCore diff --git a/src/solver_setup.jl b/src/solver_setup.jl index 45f19c9..44f1d03 100644 --- a/src/solver_setup.jl +++ b/src/solver_setup.jl @@ -11,6 +11,7 @@ function Solver( context_dimension = 0, compute_sensitivities = true, parametric_mcp_options = (;), + symbolic_backend = SymbolicUtils.SymbolicsBackend(), ) dimensions = let state_blocks = @@ -22,28 +23,19 @@ function Solver( (; state_blocks, state, control_blocks, control, context = context_dimension, horizon) end - initial_state_symbolic = let - Symbolics.@variables(x0[1:(dimensions.state)]) |> - only |> - scalarize |> + initial_state_symbolic = + SymbolicUtils.make_variables(symbolic_backend, :x0, dimensions.state) |> to_blockvector(dimensions.state_blocks) - end - xs_symbolic = let - Symbolics.@variables(X[1:(dimensions.state * horizon)]) |> - only |> - scalarize |> + xs_symbolic = + SymbolicUtils.make_variables(symbolic_backend, :X, dimensions.state * horizon) |> to_vector_of_blockvectors(dimensions.state_blocks) - end - us_symbolic = let - Symbolics.@variables(U[1:(dimensions.control * horizon)]) |> - only |> - scalarize |> + us_symbolic = + SymbolicUtils.make_variables(symbolic_backend, :U, dimensions.control * horizon) |> to_vector_of_blockvectors(dimensions.control_blocks) - end - context_symbolic = Symbolics.@variables(context[1:context_dimension]) |> only |> scalarize + context_symbolic = SymbolicUtils.make_variables(symbolic_backend, :context, context_dimension) cost_per_player_symbolic = game.cost(xs_symbolic, us_symbolic, context_symbolic) @@ -90,7 +82,8 @@ function Solver( end if isnothing(game.coupling_constraints) - coupling_constraints_symbolic = Symbolics.Num[] + coupling_constraints_symbolic = + SymbolicUtils.make_variables(symbolic_backend, :coupling_constraints, 0) else # Note: we don't constraint the first state because we have no control authority over that anyway coupling_constraints_symbolic = @@ -100,19 +93,30 @@ function Solver( # set up the duals for all constraints # private constraints μ_private_symbolic = - Symbolics.@variables(μ[1:length(equality_constraints_symbolic)]) |> only |> scalarize - λ_private_symbolic = - Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |> - only |> - scalarize + SymbolicUtils.make_variables(symbolic_backend, :μ, length(equality_constraints_symbolic)) + + #λ_private_symbolic = + # Symbolics.@variables(λ_private[1:length(inequality_constraints_symoblic)]) |> + # only |> + # scalarize + λ_private_symbolic = SymbolicUtils.make_variables( + symbolic_backend, + :λ_private, + length(inequality_constraints_symoblic), + ) + # shared constraints - λ_shared_symbolic = - Symbolics.@variables(λ_shared[1:length(coupling_constraints_symbolic)]) |> only |> scalarize + λ_shared_symbolic = SymbolicUtils.make_variables( + symbolic_backend, + :λ_shared, + length(coupling_constraints_symbolic), + ) + # multiplier scaling per player as a runtime parameter # TODO: technically, we could have this scaling for *every* element of the constraint and # actually every constraint but for now let's keep it simple shared_constraint_premultipliers_symbolic = - Symbolics.@variables(γ_scaling[1:num_players(game)]) |> only |> scalarize + SymbolicUtils.make_variables(symbolic_backend, :γ_scaling, num_players(game)) private_variables_per_player_symbolic = flatten_trajetory_per_player((; xs = xs_symbolic, us = us_symbolic)) @@ -129,7 +133,7 @@ function Solver( λ_private_symbolic' * inequality_constraints_symoblic - λ_shared_symbolic' * coupling_constraints_symbolic * γ_ii - Symbolics.gradient(L_ii, τ_ii) + SymbolicUtils.gradient(L_ii, τ_ii) end # set up the full KKT system as an MCP @@ -181,13 +185,3 @@ end function compose_parameter_vector(; initial_state, context, shared_constraint_premultipliers) [initial_state; context; shared_constraint_premultipliers] end - -""" -Like Symbolics.scalarize but robusutly handle empty arrays. -""" -function scalarize(num) - if length(num) == 0 - return Symbolics.Num[] - end - Symbolics.scalarize(num) -end diff --git a/test/runtests.jl b/test/runtests.jl index 37e4094..627fce0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,17 +13,17 @@ using Symbolics: Symbolics include("Demo.jl") -function isfeasible(game::TrajectoryGamesBase.TrajectoryGame, trajectory; tol=1e-4) +function isfeasible(game::TrajectoryGamesBase.TrajectoryGame, trajectory; tol = 1e-4) isfeasible(game.dynamics, trajectory; tol) && isfeasible(game.env, trajectory; tol) && all(game.coupling_constraints(trajectory.xs, trajectory.us) .>= 0 - tol) end -function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; tol=1e-4) +function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; tol = 1e-4) dynamics_steps_consistent = all( map(2:length(trajectory.xs)) do t residual = - trajectory.xs[t] - dynamics(trajectory.xs[t-1], trajectory.us[t-1], t - 1) + trajectory.xs[t] - dynamics(trajectory.xs[t - 1], trajectory.us[t - 1], t - 1) sum(abs, residual) < tol end, ) @@ -45,7 +45,7 @@ function isfeasible(dynamics::TrajectoryGamesBase.AbstractDynamics, trajectory; dynamics_steps_consistent && state_bounds_feasible && control_bounds_feasible end -function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol=1e-4) +function isfeasible(env::TrajectoryGamesBase.PolygonEnvironment, trajectory; tol = 1e-4) trajectory_per_player = MCPTrajectoryGameSolver.unstack_trajectory(trajectory) map(enumerate(trajectory_per_player)) do (ii, trajectory) @@ -68,7 +68,7 @@ function input_sanity(; solver, game, initial_state, context) solver, game, initial_state; - context=context_with_wrong_size, + context = context_with_wrong_size, ) multipliers_despite_no_shared_constraints = [1] @test_throws ArgumentError TrajectoryGamesBase.solve_trajectory_game!( @@ -76,12 +76,12 @@ function input_sanity(; solver, game, initial_state, context) game, initial_state; context, - shared_constraint_premultipliers=multipliers_despite_no_shared_constraints, + shared_constraint_premultipliers = multipliers_despite_no_shared_constraints, ) end end -function forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy, tol=1e-4) +function forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy, tol = 1e-4) @testset "forwardpass sanity" begin nash_trajectory = TrajectoryGamesBase.rollout(game.dynamics, strategy, initial_state, horizon) @@ -112,8 +112,8 @@ function backward_pass_sanity(; solver, game, initial_state, - rng=Random.MersenneTwister(1), - θs=[randn(rng, 4) for _ in 1:10], + rng = Random.MersenneTwister(1), + θs = [randn(rng, 4) for _ in 1:10], ) @testset "backward pass sanity" begin function loss(θ) @@ -122,7 +122,7 @@ function backward_pass_sanity(; solver, game, initial_state; - context=θ, + context = θ, ) sum(strategy.substrategies) do substrategy @@ -136,7 +136,7 @@ function backward_pass_sanity(; for θ in θs ∇_zygote = Zygote.gradient(loss, θ) |> only ∇_finitediff = FiniteDiff.finite_difference_gradient(loss, θ) - @test isapprox(∇_zygote, ∇_finitediff; atol=1e-4) + @test isapprox(∇_zygote, ∇_finitediff; atol = 1e-4) end end end @@ -147,34 +147,49 @@ function main() context = [0.0, 1.0, 0.0, 1.0] initial_state = mortar([[1.0, 0, 0, 0], [-1.0, 0, 0, 0]]) - local solver, solver_parallel - @testset "Tests" begin - @testset "solver setup" begin - solver = - MCPTrajectoryGameSolver.Solver(game, horizon; context_dimension=length(context)) - # exercise some inner solver options... - solver_parallel = MCPTrajectoryGameSolver.Solver( - game, - horizon; - context_dimension=length(context), - parametric_mcp_options=(; parallel=Symbolics.ShardedForm()), - ) - end + for options in [ + (; symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.SymbolicsBackend(),), + (; + symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.SymbolicsBackend(), + parametric_mcp_options = (; + backend_options = (; parallel = Symbolics.ShardedForm()) + ), + ), + (; + symbolic_backend = MCPTrajectoryGameSolver.SymbolicUtils.FastDifferentiationBackend(), + ), + ] + local solver + + @testset "$options" begin + @testset "solver setup" begin + solver = nothing + solver = MCPTrajectoryGameSolver.Solver( + game, + horizon; + context_dimension = length(context), + options..., + ) + end - @testset "solve" begin - for solver in [solver, solver_parallel] - input_sanity(; solver, game, initial_state, context) - strategy = - TrajectoryGamesBase.solve_trajectory_game!(solver, game, initial_state; context) - forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy) - backward_pass_sanity(; solver, game, initial_state) - end - end + @testset "solve" begin + input_sanity(; solver, game, initial_state, context) + strategy = TrajectoryGamesBase.solve_trajectory_game!( + solver, + game, + initial_state; + context, + ) + forward_pass_sanity(; solver, game, initial_state, context, horizon, strategy) + backward_pass_sanity(; solver, game, initial_state) + end - @testset "integration test" begin - Demo.demo_model_predictive_game_play() - Demo.demo_inverse_game() + @testset "integration test" begin + Demo.demo_model_predictive_game_play() + Demo.demo_inverse_game() + end + end end end end