Skip to content

Commit

Permalink
Fix the tests to use Lux and ForwardDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 23, 2023
1 parent a5cc09e commit 116188d
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 48 deletions.
4 changes: 2 additions & 2 deletions docs/src/examples/hamiltonian_nn.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ model = NeuralHamiltonianDE(
save_start = true, saveat = t
)
pred = Array(model(data[:, 1], ps_c, st))
pred = Array(first(model(data[:, 1], ps_c, st)))
plot(data[1, :], data[2, :], lw=4, label="Original")
plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
xlabel!("Position (q)")
Expand Down Expand Up @@ -112,7 +112,7 @@ model = NeuralHamiltonianDE(
save_start = true, saveat = t
)
pred = Array(model(data[:, 1], ps_c, st))
pred = Array(first(model(data[:, 1], ps_c, st)))
plot(data[1, :], data[2, :], lw=4, label="Original")
plot!(pred[1, :], pred[2, :], lw=4, label="Predicted")
xlabel!("Position (q)")
Expand Down
39 changes: 21 additions & 18 deletions src/hnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,20 @@ struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer
tspan::T
args::A
kwargs::K
end

function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...)
hnn = HamiltonianNN(model, p=p)
new{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re),
typeof(tspan),typeof(args),typeof(kwargs)}(
hnn, hnn.p, tspan, args, kwargs)
end
# TODO: Make sensealg an argument
function NeuralHamiltonianDE(model, tspan, args...; p=nothing, kwargs...)
hnn = HamiltonianNN(model, p=p)
return NeuralHamiltonianDE{typeof(hnn.model),typeof(hnn.p),typeof(hnn.re),
typeof(tspan),typeof(args),typeof(kwargs)}(
hnn, hnn.p, tspan, args, kwargs)
end

function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
p=hnn.p, kwargs...) where {M,RE,P}
new{M,P,RE,typeof(tspan),typeof(args),
typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
end
function NeuralHamiltonianDE(hnn::HamiltonianNN{M,RE,P}, tspan, args...;
p=hnn.p, kwargs...) where {M,RE,P}
return NeuralHamiltonianDE{M,P,RE,typeof(tspan),typeof(args),
typeof(kwargs)}(hnn, hnn.p, tspan, args, kwargs)
end

function (nhde::NeuralHamiltonianDE)(x, p=nhde.p)
Expand All @@ -112,19 +113,21 @@ function (nhde::NeuralHamiltonianDE)(x, p=nhde.p)
end
prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, p)
# NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
# ForwardDiff.jl internally.
sense = InterpolatingAdjoint(autojacvec=true)
return solve(prob, nhde.args...; sensealg=sense, nhde.kwargs...)
# FiniteDiff.jl internally.
# FIXME: Using ForwardDiff.jl is erroring in SciMLSensitivity.jl. We should fix that.
sensealg = InterpolatingAdjoint(; autojacvec=false)
return solve(prob, nhde.args...; sensealg, nhde.kwargs...)
end

function (nhde::NeuralHamiltonianDE{<:LuxCore.AbstractExplicitLayer})(x, ps, st)
function neural_hamiltonian!(du, u, p, t)
y, st = nhde.model(u, ps, st)
y, st = nhde.model(u, p, st)
du .= reshape(y, size(du))
end
prob = ODEProblem(ODEFunction{true}(neural_hamiltonian!), x, nhde.tspan, ps)
# NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP. Instead we use
# ForwardDiff.jl internally.
sense = InterpolatingAdjoint(autojacvec=true)
return solve(prob, nhde.args...; sensealg=sense, nhde.kwargs...)
# FiniteDiff.jl internally.
# FIXME: Using ForwardDiff.jl is erroring in SciMLSensitivity.jl. We should fix that.
sensealg = InterpolatingAdjoint(; autojacvec=false)
return solve(prob, nhde.args...; sensealg, nhde.kwargs...), st
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Expand Down
59 changes: 31 additions & 28 deletions test/hamiltonian_nn.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,66 @@
using DiffEqFlux, Zygote, OrdinaryDiffEq, ReverseDiff, Test
using DiffEqFlux, Zygote, OrdinaryDiffEq, ForwardDiff, Test, Optimisers, Random, Lux, ComponentArrays

# Checks for Shapes and Non-Zero Gradients
u0 = rand(Float32, 6, 1)

hnn = HamiltonianNN(Flux.Chain(Flux.Dense(6, 12, relu), Flux.Dense(12, 1)))
p = hnn.p
hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1)))
ps, st = Lux.setup(Random.default_rng(), hnn)
ps = ps |> ComponentArray

@test size(hnn(u0)) == (6, 1)
@test size(first(hnn(u0, ps, st))) == (6, 1)

@test ! iszero(ReverseDiff.gradient(p -> sum(hnn(u0, p)), p))
@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))

hnn = HamiltonianNN(Flux.Chain(Flux.Dense(6, 12, relu), Flux.Dense(12, 1)))
p = hnn.p
hnn = HamiltonianNN(Lux.Chain(Lux.Dense(6, 12, relu), Lux.Dense(12, 1)))
ps, st = Lux.setup(Random.default_rng(), hnn)
ps = ps |> ComponentArray

@test size(hnn(u0)) == (6, 1)
@test size(first(hnn(u0, ps, st))) == (6, 1)

@test ! iszero(ReverseDiff.gradient(p -> sum(hnn(u0, p)), p))
@test !iszero(ForwardDiff.gradient(ps -> sum(first(hnn(u0, ps, st))), ps))

# Test Convergence on a toy problem
t = range(0.0f0, 1.0f0, length = 64)
t = range(0.0f0, 1.0f0, length=64)
π_32 = Float32(π)
q_t = reshape(sin.(π_32 * t), 1, :)
p_t = reshape(cos.(π_32 * t), 1, :)
dqdt = π_32 .* p_t
dpdt = -π_32 .* q_t

data = cat(q_t, p_t, dims = 1)
target = cat(dqdt, dpdt, dims = 1)
data = vcat(q_t, p_t)
target = vcat(dqdt, dpdt)

hnn = HamiltonianNN(Flux.Chain(Flux.Dense(2, 16, relu), Flux.Dense(16, 1)))
p = hnn.p
hnn = HamiltonianNN(Lux.Chain(Lux.Dense(2, 16, relu), Lux.Dense(16, 1)))
ps, st = Lux.setup(Random.default_rng(), hnn)
ps = ps |> ComponentArray

opt = ADAM(0.01)
loss(x, y, p) = sum((hnn(x, p) .- y) .^ 2)
st_opt = Optimisers.setup(opt, ps)
loss(data, target, ps) = mean(abs2, first(hnn(data, ps, st)) .- target)

initial_loss = loss(data, target, p)
initial_loss = loss(data, target, ps)

epochs = 100
for epoch in 1:epochs
gs = ReverseDiff.gradient(p -> loss(data, target, p), p)
Flux.Optimise.update!(opt, p, gs)
for epoch in 1:100
# Forward Mode over Reverse Mode for Training
gs = ForwardDiff.gradient(ps -> loss(data, target, ps), ps)
st_opt, ps = Optimisers.update!(st_opt, ps, gs)
end

final_loss = loss(data, target, p)
final_loss = loss(data, target, ps)

@test initial_loss > final_loss
@test initial_loss > 5 * final_loss

# Test output and gradient of NeuralHamiltonianDE Layer
tspan = (0.0f0, 1.0f0)

model = NeuralHamiltonianDE(
hnn, tspan, Tsit5(),
save_everystep = false, save_start = true,
saveat = range(tspan[1], tspan[2], length=10)
save_everystep=false, save_start=true,
saveat=range(tspan[1], tspan[2], length=10)
)
sol = Array(model(data[:, 1]))
sol = Array(first(model(data[:, 1], ps, st)))
@test size(sol) == (2, 10)

ps = Flux.params(model)
gs = Flux.gradient(() -> sum(Array(model(data[:, 1]))), ps)
gs = only(Zygote.gradient(ps -> sum(Array(first(model(data[:, 1], ps, st)))), ps))

@test ! iszero(gs[model.p])
@test !iszero(gs)

0 comments on commit 116188d

Please sign in to comment.