Skip to content

Commit

Permalink
Update test error msg check
Browse files Browse the repository at this point in the history
  • Loading branch information
albertpod committed Feb 22, 2024
1 parent ad8a5d9 commit 45b1d6e
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions test/inference_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1040,25 +1040,31 @@ end
observations = [[1.0; zeros(5)], [zeros(5); 1.0]]

@testset "Test misspecified data" begin
@test_throws ErrorException infer(model = rolling_die(2), data = (y = observations))
@test_throws "Keyword argument `data` expects either `Dict` or `NamedTuple` as an input" infer(model = rolling_die(2), data = (y = observations))
result = infer(model = rolling_die(2), data = (y = observations,))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end

@testset "Test misspecified initmarginals" begin
@test_throws ErrorException infer(model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6))))
@test_throws "Keyword argument `initmarginals` expects either `Dict` or `NamedTuple` as an input" infer(
model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6)))
)
result = infer(model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6)),))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end

@testset "Test misspecified initmessages" begin
@test_throws ErrorException infer(model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6))))
@test_throws "Keyword argument `initmessages` expects either `Dict` or `NamedTuple` as an input" infer(
model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6)))
)
result = infer(model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6)),))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end

@testset "Test misspecified callbacks" begin
@test_throws ErrorException infer(model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing))
@test_throws "Keyword argument `callbacks` expects either `Dict` or `NamedTuple` as an input" infer(
model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing)
)
result = infer(model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing,))
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
end
Expand Down

0 comments on commit 45b1d6e

Please sign in to comment.