Subspace Inference for Bayesian Deep Learning

This package constructs a low-dimensional subspace of a model's parameters and performs Bayesian inference within that subspace.

This implementation is based on the following publication:

Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., & Wilson, A. G. (2020, August). Subspace inference for Bayesian deep learning. In Uncertainty in Artificial Intelligence (pp. 1169-1179). PMLR.

Subspace Inference

Use subspace_inference to quantify the uncertainty of a machine learning model with the subspace inference method:

subspace_inference(model, cost, data, opt; callback =()->(return 0),
	σ_z = 1.0,	σ_m = 1.0, σ_p = 1.0,
	itr =1000, T=25, c=1, M=20, print_freq=1)

Input Arguments

  • model : Machine learning model, created with Flux's Chain. E.g.: Chain(Dense(10,2))
  • cost : Cost function taking the model, inputs, and outputs. E.g.: L(m, x, y) = Flux.Losses.mse(m(x), y)
  • data : Inputs and outputs. E.g.: X = rand(10,100); Y = rand(2,100); data = DataLoader(X,Y);
  • opt : Optimizer. E.g.: opt = ADAM(0.1)

Keyword Arguments

  • callback : Callback function called during training. E.g.: callback() = @show(L(X,Y))
  • σ_z : Standard deviation of the subspace
  • σ_m : Standard deviation of the likelihood model
  • σ_p : Standard deviation of the prior
  • itr : Number of sampling iterations
  • T : Number of steps for the subspace calculation. E.g.: T = 1
  • c : Moment update frequency. E.g.: c = 1
  • M : Maximum number of columns in the deviation matrix. E.g.: M = 3
  • print_freq : Loss printing frequency
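
For orientation, the model described by Izmailov et al. (2020) restricts the weights to an affine subspace around a center ŵ (here W_swa) spanned by the columns of a projection matrix P, and performs inference over the low-dimensional coordinates z. Assuming a Gaussian likelihood with noise standard deviation σ_m and an isotropic Gaussian prior on z with standard deviation σ_p (an assumption about how this package uses these arguments; the role of σ_z is not covered here), the unnormalized log posterior can be written as:

```latex
w(z) = \hat{w} + P z, \qquad z \in \mathbb{R}^{K}, \quad P \in \mathbb{R}^{d \times K}

p(z \mid \mathcal{D}) \propto p(\mathcal{D} \mid \hat{w} + P z)\, p(z)

\log p(z \mid \mathcal{D}) = -\frac{1}{2\sigma_m^{2}} \sum_{i=1}^{N} \lVert y_i - f(x_i; \hat{w} + P z) \rVert^{2} - \frac{1}{2\sigma_p^{2}} \lVert z \rVert^{2} + \text{const}
```

Here d is the number of network weights, K the subspace dimension, f the network, and (x_i, y_i) the N training pairs.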


Outputs

  • chn : Chain of samples representing the model uncertainty
  • lp : Log probabilities of all samples
  • W_swa : Mean weights
  • re : Model reconstruction function

Example

using SubspaceInference
using Flux
using Flux: @epochs
using Flux.Data: DataLoader

l_m = 10  # input dimension
l_n = 100 # number of samples
O = 2     # output dimension

X = rand(l_m, l_n) # inputs
Y = rand(O, l_n)   # outputs

data = DataLoader(X, Y, shuffle = true)

m = Chain(Dense(l_m, 20), Dense(20, 20), Dense(20, O)) # model

L(x, y) = Flux.Losses.mse(m(x), y) # training loss

ps = Flux.params(m) # model parameters

opt = ADAM(0.1) # optimizer

callback() = @show(L(X, Y)) # callback function

@epochs 1 Flux.train!(L, ps, data, opt, cb = callback) # pre-training

M = 3    # maximum number of deviation-matrix columns
T = 10   # steps for the subspace calculation
c = 1    # moment update frequency
itr = 10 # sampling iterations
L1(m, x, y) = Flux.Losses.mse(m(x), y) # cost function taking the model as first argument
chn, lp, W_swa = subspace_inference(m, L1, data, opt, itr = itr, T = T, c = c, M = M)
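
Once posterior samples are available, predictive uncertainty can be summarized by evaluating the model at each weight sample and taking the mean and standard deviation across samples. The sketch below is a hypothetical helper, not part of SubspaceInference.jl; it assumes the samples have been arranged as flattened weight vectors, one per column (check how chn actually stores its samples), and that predict(w, x) evaluates the model for weights w, for example via re(w)(x).

```julia
using Statistics

# Hypothetical helper (not part of SubspaceInference.jl): summarizes the
# predictions of a model over a set of posterior weight samples.
#   predict(w, x) : model output vector for flattened weights w and input x;
#                   with Flux this could be (w, x) -> re(w)(x)
#   W             : d × S matrix with one flattened weight sample per column
#                   (assumed layout; samples from chn may need converting)
#   x             : a single input
function predictive_stats(predict, W::AbstractMatrix, x)
    # evaluate the model for every sample: an (outputs × S) matrix
    preds = reduce(hcat, [predict(w, x) for w in eachcol(W)])
    μ = vec(mean(preds; dims = 2))  # predictive mean of each output
    σ = vec(std(preds; dims = 2))   # predictive standard deviation of each output
    return μ, σ
end
```

std uses the Bessel-corrected sample estimate, so at least two samples are needed for σ to be defined.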

Subspace Construction

If you only need the subspace itself, generate it with the subspace_construction function:

subspace_construction(model, cost, data, opt;
	callback = ()->(return 0), T = 10, c = 1, M = 3,
	LR_init = 0.01, print_freq = 1)
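
The keyword arguments T, c and M suggest a SWA-style construction: the running mean of the weights is updated every c steps, the most recent deviations from that mean (at most M) form the deviation matrix, and its leading left singular vectors span the subspace, similar to the PCA subspace described by Izmailov et al. (2020). The sketch below works on weight snapshots already flattened into vectors (with Flux, Flux.destructure returns such a vector); build_subspace and its K argument are hypothetical, and the package's actual procedure (including the cyclic learning rate controlled by LR_init) may differ.

```julia
using LinearAlgebra

# Hypothetical helper (not the package's API): builds a PCA subspace from a
# sequence of flattened weight snapshots, in the spirit of the SWA-based
# construction of Izmailov et al. (2020).
#   snapshots : flattened weight vectors, one per training step
#   c         : update the moments only every c-th step
#   M         : maximum number of deviation columns kept
#   K         : number of principal directions returned (K <= M)
function build_subspace(snapshots::AbstractVector{<:AbstractVector{<:Real}};
                        c::Int = 1, M::Int = 3, K::Int = M)
    w_swa = zeros(length(first(snapshots)))  # running (SWA) mean of the weights
    n = 0                                    # number of snapshots averaged so far
    devs = Vector{Vector{Float64}}()         # most recent deviations from the mean
    for (t, w) in enumerate(snapshots)
        t % c == 0 || continue                  # skip steps between moment updates
        n += 1
        w_swa .= (w_swa .* (n - 1) .+ w) ./ n   # incremental mean update
        push!(devs, w .- w_swa)                 # deviation from the current mean
        length(devs) > M && popfirst!(devs)     # keep at most M columns
    end
    isempty(devs) && error("no snapshots were collected; check c")
    A = hcat(devs...)                # d × m deviation matrix, m <= M
    U = svd(A).U                     # left singular vectors, by decreasing singular value
    P = U[:, 1:min(K, size(U, 2))]   # orthonormal basis of the subspace
    return w_swa, P
end
```

For example, build_subspace([randn(5) for _ in 1:10]; M = 3, K = 2) returns a 5-element mean vector and a 5×2 matrix with orthonormal columns.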

Input Arguments

  • model : Machine learning model, created with Flux's Chain. E.g.: Chain(Dense(10,2))
  • cost : Cost function taking the model, inputs, and outputs. E.g.: L(m, x, y) = Flux.Losses.mse(m(x), y)
  • data : Inputs and outputs. E.g.: X = rand(10,100); Y = rand(2,100); data = DataLoader(X,Y);
  • opt : Optimizer. E.g.: opt = ADAM(0.1)

Keyword Arguments

  • callback : Callback function called during training. E.g.: callback() = @show(L(X,Y))
  • T : Number of steps for the subspace calculation. E.g.: T = 1
  • c : Moment update frequency. E.g.: c = 1
  • M : Maximum number of columns in the deviation matrix. E.g.: M = 2
  • LR_init : Initial learning rate of the cyclic learning rate schedule
  • print_freq : Loss printing frequency


Outputs

  • W_swa : Mean weights
  • P : Projection matrix
  • re : Model reconstruction function
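
When P has orthonormal columns, as an SVD-based construction produces, W_swa and P map between full weight vectors and subspace coordinates in both directions; a model can then be rebuilt from a full weight vector with re. The helpers below are hypothetical illustrations, not functions exported by SubspaceInference.jl, and that the package's P is orthonormal is an assumption.

```julia
using LinearAlgebra

# Hypothetical helpers (not exported by SubspaceInference.jl) relating full
# weight vectors w to subspace coordinates z. They assume P has orthonormal
# columns, as an SVD-based construction produces.
to_subspace(w, w_swa, P) = P' * (w .- w_swa)   # z = P' (w - W_swa)
from_subspace(z, w_swa, P) = w_swa .+ P * z    # w = W_swa + P z

# Component of w lying outside the subspace (zero when w is inside it)
residual(w, w_swa, P) = w .- from_subspace(to_subspace(w, w_swa, P), w_swa, P)
```

Projecting and reconstructing is lossless only for weights inside the subspace; for any other w, residual returns the part orthogonal to the columns of P.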


Example

using SubspaceInference
using Flux
using Flux: @epochs
using Flux.Data: DataLoader

l_m = 10  # input dimension
l_n = 100 # number of samples
O = 2     # output dimension

X = rand(l_m, l_n) # inputs
Y = rand(O, l_n)   # outputs

data = DataLoader(X, Y, shuffle = true)

m = Chain(Dense(l_m, 20), Dense(20, 20), Dense(20, O)) # model

L(x, y) = Flux.Losses.mse(m(x), y) # training loss

ps = Flux.params(m) # model parameters

opt = ADAM(0.1) # optimizer

callback() = @show(L(X, Y)) # callback function

@epochs 1 Flux.train!(L, ps, data, opt, cb = callback) # pre-training

M = 3  # maximum number of deviation-matrix columns
T = 10 # steps for the subspace calculation
c = 1  # moment update frequency
L1(m, x, y) = Flux.Losses.mse(m(x), y) # cost function taking the model as first argument
W_swa, P = subspace_construction(m, L1, data, opt, T = T, c = c, M = M)


The constructed subspace can then be used for sampling with SubspaceInference.inference:

chn, lp = SubspaceInference.inference(m, data, W_swa, P; σ_z = 1.0,
	σ_m = 1.0, σ_p = 1.0, itr = 100, M = 3, alg = :mh)
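
The alg = :mh option presumably selects Metropolis-Hastings sampling. As a generic illustration, the sketch below runs random-walk Metropolis-Hastings over subspace coordinates z for any unnormalized log posterior logpost; rw_metropolis, its step size and the Gaussian proposal are hypothetical choices, not the package's implementation.

```julia
# Hypothetical sketch (not the package's alg = :mh implementation): random-walk
# Metropolis-Hastings over subspace coordinates z.
#   logpost : function returning the unnormalized log posterior log p(z | D)
#   z0      : starting point, e.g. zeros(K) for the subspace center
#   itr     : number of iterations; step : proposal standard deviation
function rw_metropolis(logpost, z0::AbstractVector{<:Real}; itr::Int = 1000, step::Real = 0.5)
    z = float.(collect(z0))
    lp = logpost(z)
    samples = Matrix{Float64}(undef, length(z), itr)  # one column per iteration
    lps = Vector{Float64}(undef, itr)                 # log posterior of each sample
    for i in 1:itr
        z_prop = z .+ step .* randn(length(z))   # symmetric Gaussian proposal
        lp_prop = logpost(z_prop)
        # accept with probability min(1, p(z_prop) / p(z)); the symmetric proposal cancels
        if log(rand()) < lp_prop - lp
            z, lp = z_prop, lp_prop
        end
        samples[:, i] = z   # a rejected proposal repeats the current state
        lps[i] = lp
    end
    return samples, lps
end
```

In the subspace setting, logpost(z) would evaluate the log likelihood of the model rebuilt from W_swa + P*z plus the log prior of z. The step size trades the acceptance rate against how far each proposal moves.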