Skip to content

Commit

Permalink
Merge pull request #418 from ReactiveBayes/multinomial_polya_model
Browse files Browse the repository at this point in the history
Multinomial polya model
  • Loading branch information
bvdmitri authored Feb 26, 2025
2 parents c07338a + ab2337d commit 5d55c53
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RxInfer"
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
authors = ["Bagaev Dmitry <d.v.bagaev@tue.nl> and contributors"]
version = "4.0.1"
version = "4.1.0"

[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
Expand Down Expand Up @@ -56,7 +56,7 @@ Preferences = "1.4.3"
PrettyTables = "2"
ProgressMeter = "1.0.0"
Random = "1.9"
ReactiveMP = "~5.0.0"
ReactiveMP = "~5.1.0"
Reexport = "1.2.0"
Rocket = "1.8.0"
Static = "0.8.10, 1"
Expand Down
4 changes: 2 additions & 2 deletions codemeta.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
"downloadUrl": "/~https://github.com/reactivebayes/RxInfer.jl/releases",
"issueTracker": "/~https://github.com/reactivebayes/RxInfer.jl/issues",
"name": "RxInfer.jl",
"version": "4.0.1",
"version": "4.1.0",
"description": "Julia package for automated, scalable and efficient Bayesian inference on factor graphs with reactive message passing. ",
"applicationCategory": "Statistics",
"developmentStatus": "active",
"readme": "https://rxinfer.ml",
"softwareVersion": "4.0.1",
"softwareVersion": "4.1.0",
"keywords": [
"Bayesian inference",
"message passing",
Expand Down
81 changes: 81 additions & 0 deletions test/models/regression/multinomialreg_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
@testitem "Multinomial regression with MultinomialPolya (offline inference) node" begin
using BenchmarkTools, Plots, Distributions, LinearAlgebra, StableRNGs, ExponentialFamily.LogExpFunctions

include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))

N = 20
k = 10
nsamples = 1000
X, ψ, p = generate_multinomial_data(; N = N, k = k, nsamples = nsamples)

@model function multinomial_model(y, N, ξ_ψ, W_ψ)
ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
for i in eachindex(y)
y[i] ~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies= MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ))}
end
end

result = infer(
model = multinomial_model(ξ_ψ = zeros(k - 1), W_ψ = rand(Wishart(k, diageye(k - 1))), N = N),
data = (y = X,),
iterations = 100,
free_energy = true,
showprogress = false,
returnvars = KeepLast(),
options = (limit_stack_depth = 100,)
)

m = mean(result.posteriors[])
pest = logistic_stic_breaking(m)

mse = mean((pest - p) .^ 2)
@test mse < 2e-5

@test result.free_energy[end] < result.free_energy[1]
@test result.free_energy[end] <= result.free_energy[end - 1]
@test abs(result.free_energy[end - 1] - result.free_energy[end]) < 1e-8
end

@testitem "Multinomial regression - online inference" begin
using BenchmarkTools, Plots, Distributions, LinearAlgebra, StableRNGs, ExponentialFamily.LogExpFunctions

include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))

N = 50
k = 40
nsamples = 5000
X, ψ, p = generate_multinomial_data(; N = N, k = k, nsamples = nsamples)

@model function multinomial_model(y, N, ξ_ψ, W_ψ, k)
ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
y ~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies= MvNormalWeightedMeanPrecision(zeros(k - 1), diageye(k - 1)))}
end

@autoupdates function auto()
ξ_ψ, W_ψ = weightedmean_precision(q(ψ))
end
init = @initialization begin
q(ψ) = MvNormalWeightedMeanPrecision(zeros(k - 1), rand(Wishart(k, diageye(k - 1))))
end

result = infer(
model = multinomial_model(N = N, k = k),
data = (y = X,),
initialization = init,
iterations = 1,
autoupdates = auto(),
keephistory = length(X),
free_energy = true,
showprogress = false
)

m = result.history[][end]

pest = logistic_stic_breaking(mean(m))
mse = mean((pest - p) .^ 2)
@test mse < 1e-3

@test result.free_energy_final_only_history[end] < result.free_energy_final_only_history[1]
#Free energy over time decreases in a noisy way. It is not a monotonic decrease.

end
21 changes: 21 additions & 0 deletions test/utiltests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,24 @@ macro test_expression_generating(lhs, rhs)
end
)
end

function generate_multinomial_data(rng = StableRNG(123); N = 3, k = 3, nsamples = 5000)
ψ = randn(rng, k)
p = ReactiveMP.softmax(ψ)

X = rand(rng, Multinomial(N, p), nsamples)
X = [X[:, i] for i in axes(X, 2)]
return X, ψ, p
end

function logistic_stic_breaking(m)
Km1 = length(m)

p = Array{Float64}(undef, Km1 + 1)
p[1] = logistic(m[1])
for i in 2:Km1
p[i] = logistic(m[i]) * (1 - sum(p[1:(i - 1)]))
end
p[end] = 1 - sum(p[1:(end - 1)])
return p
end

2 comments on commit 5d55c53

@bvdmitri
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/125871

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.1.0 -m "<description of version>" 5d55c537ceac9d4691b2d8d09331f36cb61905d4
git push origin v4.1.0

Please sign in to comment.