Skip to content

Commit

Permalink
Merge pull request #228 from ReactiveBayes/fix_infer_dicttype
Browse files Browse the repository at this point in the history
Fix missing check for infer function
  • Loading branch information
bvdmitri authored Feb 23, 2024
2 parents 23f95bb + 9cd0d84 commit 495baa7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,7 @@ function infer(;
__infer_check_dicttype(:initmarginals, initmarginals)
__infer_check_dicttype(:initmessages, initmessages)
__infer_check_dicttype(:callbacks, callbacks)
__infer_check_dicttype(:data, data)

if isnothing(autoupdates)
__check_available_callbacks(warn, callbacks, available_callbacks(__inference))
Expand Down
43 changes: 43 additions & 0 deletions test/inference_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1026,3 +1026,46 @@ end

@test all(result.predictions[:y] .== Bernoulli(mean(Beta(1.0, 1.0))))
end

@testitem "Test misspecified types in infer function" begin
@model function rolling_die(n)
y = datavar(Vector{Float64}, n)

θ ~ Dirichlet(ones(6))
for i in 1:n
y[i] ~ Categorical(θ)
end
end

observations = [[1.0; zeros(5)], [zeros(5); 1.0]]

@testset "Test misspecified data" begin
@test_throws "Keyword argument `data` expects either `Dict` or `NamedTuple` as an input" infer(model = rolling_die(2), data = (y = observations))
result = infer(model = rolling_die(2), data = (y = observations,))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end

@testset "Test misspecified initmarginals" begin
@test_throws "Keyword argument `initmarginals` expects either `Dict` or `NamedTuple` as an input" infer(
model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6)))
)
result = infer(model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6)),))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end

@testset "Test misspecified initmessages" begin
@test_throws "Keyword argument `initmessages` expects either `Dict` or `NamedTuple` as an input" infer(
model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6)))
)
result = infer(model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6)),))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end

@testset "Test misspecified callbacks" begin
@test_throws "Keyword argument `callbacks` expects either `Dict` or `NamedTuple` as an input" infer(
model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing)
)
result = infer(model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing,))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end
end

0 comments on commit 495baa7

Please sign in to comment.