Skip to content

Commit

Permalink
Tidy up + enable all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Sep 27, 2024
1 parent 219fe22 commit 495e137
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
15 changes: 11 additions & 4 deletions ext/TemporalGPsMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
module TemporalGPsMooncakeExt

using Mooncake, TemporalGPs
import Mooncake: rrule!!, CoDual, primal, @is_primitive, zero_fcodual, MinimalCtx
import Mooncake:
rrule!!,
CoDual,
primal,
@is_primitive,
zero_fcodual,
MinimalCtx

@is_primitive MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real}
@is_primitive MinimalCtx Tuple{typeof(TemporalGPs.time_exp), Matrix{<:Real}, Real}
function rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64})
B_dB = zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t)))
_A = primal(A)
B_dB = zero_fcodual(TemporalGPs.time_exp(_A, primal(t)))
B = primal(B_dB)
dB = tangent(B_dB)
time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B))
time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (_A * B))
return B_dB, time_exp_pb
end

Expand Down
47 changes: 22 additions & 25 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ using KernelFunctions: kappa
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
using Test

_logistic(x) = 1 / (1 + exp(-x))

# Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure
# that Zygote can handle construction.
function _construction_tester(f_naive::GP, storage::StorageType, σ², t::AbstractVector)
Expand Down Expand Up @@ -92,84 +90,71 @@ println("lti_sde:")
N = 13
kernels = vcat(
# Base kernels.
(name="base-Matern12Kernel", val=Matern12Kernel(), to_vec_grad=false),
(name="base-Matern12Kernel", val=Matern12Kernel()),
map([Matern32Kernel, Matern52Kernel]) do k
(; name="base-$k", val=k(), to_vec_grad=false)
(; name="base-$k", val=k())
end,

# Scaled kernels.
map([1e-1, 1.0, 10.0, 100.0]) do σ²
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel(), to_vec_grad=false)
(; name="scaled-σ²=$σ²", val=σ² * Matern32Kernel())
end,

# Stretched kernels.
map([1e-2, 0.1, 1.0, 10.0, 100.0]) do λ
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ), to_vec_grad=false)
(; name="stretched-λ=", val=Matern32Kernel() ScaleTransform(λ))
end,

# Approx periodic kernels
map([7, 11]) do N
(
name="approx-periodic-N=$N",
val=ApproxPeriodicKernel{N}(; r=1.0),
to_vec_grad=true,
)
(name="approx-periodic-N=$N", val=ApproxPeriodicKernel{N}(; r=1.0))
end,
# TEST_TOFIX
# Gradients should be fixed on those composites.
# Error is mostly due do an incompatibility of Tangents
# between Zygote and FiniteDifferences.

# Product kernels
(
name="prod-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) * Matern32Kernel()
ScaleTransform(1.1),
to_vec_grad=nothing,
),
(
name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel",
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
to_vec_grad=nothing,
),
# THIS IS KNOWN NOT TO WORK!
# (
# name="prod-(Matern32Kernel + ConstantKernel) * Matern52Kernel",
# val=(Matern32Kernel() + ConstantKernel()) * Matern52Kernel(),
# to_vec_grad=nothing,
# ),

# Summed kernels.
(
name="sum-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) +
0.3 * Matern32Kernel() ScaleTransform(1.1),
to_vec_grad=nothing,
),
(
name="sum-Matern32Kernel-Matern52Kernel-ConstantKernel",
val=2.0 * Matern32Kernel() +
0.5 * Matern52Kernel() +
1.0 * ConstantKernel(),
to_vec_grad=nothing,
),
)

# Construct a Gauss-Markov model with either dense storage or static storage.
storages = (
(name="dense storage Float64", val=ArrayStorage(Float64)),
# (name="static storage Float64", val=SArrayStorage(Float64)),
(name="static storage Float64", val=SArrayStorage(Float64)),
)

# Either regular spacing or irregular spacing in time.
ts = (
(name="irregular spacing", val=collect(RegularSpacing(0.0, 0.3, N))),
# (name="regular spacing", val=RegularSpacing(0.0, 0.3, N)),
(name="regular spacing", val=RegularSpacing(0.0, 0.3, N)),
)

σ²s = (
(name="homoscedastic noise", val=(0.1,)),
# (name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )),
(name="heteroscedastic noise", val=(rand(rng, N) .+ 1e-1, )),
)

means = (
Expand All @@ -178,15 +163,21 @@ println("lti_sde:")
(name="Custom Mean", val=CustomMean(x -> 2x)),
)

@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for kernel in
kernels,
@testset "$(kernel.name), $(m.name), $(storage.name), $(t.name), $(σ².name)" for
kernel in kernels,
m in means,
storage in storages,
t in ts,
σ² in σ²s

println("$(kernel.name), $(storage.name), $(m.name), $(t.name), $(σ².name)")

if kernel.val isa TemporalGPs.ApproxPeriodicKernel &&
storage.val isa SArrayStorage
@info "skipping because ApproxPeriodicKernel not compatible with SArrayStorage"
continue
end

# Construct Gauss-Markov model.
f_naive = GP(m.val, kernel.val)
fx_naive = f_naive(collect(t.val), σ².val...)
Expand Down Expand Up @@ -217,4 +208,10 @@ println("lti_sde:")
)
end
end
@testset "time_exp AD" begin
test_rule(
Xoshiro(123), t -> TemporalGPs.time_exp([1.0 2.0; 3.0 4.0], t), rand();
is_primitive=false,
)
end
end

0 comments on commit 495e137

Please sign in to comment.