This is the prior over the data. Each dimension is treated as independent, and the input distribution has the same dimensionality as the data. One point of confusion is the terminology:
Note! In training we never use the raw prior. The point of BFNs is that we instead receive an 'updated' prior, so our 'input distribution' is always the prior after a Bayesian update with a noisy sample of known variance.
Gaussian for discretised/continuous. Categorical for discrete.
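For the continuous/discretised (Gaussian) case the Bayesian update has a simple closed form: precisions add, and the new mean is a precision-weighted average of the old mean and the noisy sample. A minimal sketch under that assumption (the function name and argument names are my own):
'''
import torch

def bayesian_update(mu_prev, rho_prev, y, alpha):
    # Update a factorised Gaussian input distribution N(mu_prev, 1/rho_prev) with a
    # noisy sample y received at accuracy (precision) alpha: precisions add, and the
    # new mean is the precision-weighted average of the old mean and the sample.
    rho_new = rho_prev + alpha
    mu_new = (rho_prev * mu_prev + alpha * y) / rho_new
    return mu_new, rho_new
'''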
In BFNs the data is represented by a distribution whose variance (equivalently, accuracy/precision) is known at every point in time. When the variance is infinite there is no information about the original data; when the variance is 0, a delta function centred on the data is recovered. Information is progressively revealed according to a variance (accuracy) schedule.
A neural network takes as input the updated prior and the time, which tells the network the current accuracy/precision level.
Discretised data is data that is truly continuous underneath but must fit within K bins, e.g. an image captures a continuous light spectrum but each RGB channel is binned into 256 intensity values (K = 256).
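With the data rescaled to $[-1, 1]$, the $K$ bins have evenly spaced centres and edges. A small sketch of how the `k_centers`, `k_lower` and `k_upper` tensors used in the snippets below might be precomputed (the bin formula follows the paper; the variable names match this document's code):
'''
import torch

K = 16                                 # number of bins (e.g. 256 for 8-bit image channels)
k = torch.arange(1, K + 1)
k_centers = (2 * k - 1) / K - 1        # bin centres k_c in [-1, 1]
k_lower = k_centers - 1 / K            # lower bin edges k_l
k_upper = k_centers + 1 / K            # upper bin edges k_u
'''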
Accuracy scheduler
For continuous/discretised data the scaling term is $\gamma(t) = 1 - \sigma_1^{2t}$: it scales the data in the mean of the input distribution, so $\gamma(0) = 0$ (pure prior) and $\gamma(1) = 1 - \sigma_1^2$ (almost all of the data).
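A minimal sketch of drawing the 'updated prior' at time $t$ for continuous/discretised data under this schedule (function and argument names are illustrative):
'''
import torch

def sample_input_distribution(x, t, sigma_one):
    # mu ~ N(gamma(t) * x, gamma(t) * (1 - gamma(t)) * I), with gamma(t) = 1 - sigma_one^(2t)
    t = torch.as_tensor(t, dtype=x.dtype)
    gamma = 1 - sigma_one ** (2 * t)
    mu = gamma * x + (gamma * (1 - gamma)).sqrt() * torch.randn_like(x)
    return mu, gamma
'''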
'''
# Pass the noisy sample to the model: it outputs a continuous distribution which is re-discretised
# over the K bins. Shape -> Tensor[B, D, K]
output_distribution = self.discretised_output_distribution(sender_mu_sample, t=t, gamma=gamma)
# Weight each bin centre by its probability
weighted_centers = output_distribution * self.k_centers
# Expected reconstruction k_hat: sum over the K bins. Shape -> Tensor[B, D]
k_hat = torch.sum(weighted_centers, dim=-1).view(batch_size, -1)
# Squared error between the data and its expected reconstruction. Shape -> Tensor[B, D]
diff = (discretised_data - k_hat).pow(2)
# Continuous-time loss L_infinity (Algorithm 5)
loss = -safe_log(self.sigma_one) * self.sigma_one.pow(-2 * t) * diff
loss = torch.mean(loss)
'''
Discrete-time loss

which cannot be calculated in closed form, but can be estimated with Monte Carlo sampling. Substituting into Eq. 24,

$$ L^n(\mathbf{x}) = n \, \mathbb{E}_{i \sim U\{1,n\},\, p_F(\theta \mid \mathbf{x}; t_{i-1}),\, \mathcal{N}(\mathbf{y} \mid \mathbf{x}, \alpha_i^{-1}\mathbf{I})} \Bigg[ \ln \mathcal{N}(\mathbf{y} \mid \mathbf{x}, \alpha_i^{-1}\mathbf{I}) - \sum_{d=1}^{D} \ln \left( \sum_{k=1}^{K} p_O^{(d)}(k \mid \theta; t_{i-1}) \, \mathcal{N}(y^{(d)} \mid k_c, \alpha_i^{-1}) \right) \Bigg]. \tag{119} $$
'''
# dist is torch.distributions
# Sender distribution: N(y | x, 1/alpha), estimated with Monte Carlo samples
y_sender_distribution = dist.Normal(discretised_data, torch.sqrt(1 / alpha))
y_sender_samples = y_sender_distribution.sample(torch.Size([monte_carlo_samples]))
# Receiver distribution: a Gaussian mixture with one component per bin centre,
# weighted by the output distribution
receiver_mix_dist = dist.Categorical(probs=output_distribution)
receiver_components = dist.Normal(self.k_centers, torch.sqrt(1 / alpha).unsqueeze(-1))
receiver_dist = dist.MixtureSameFamily(receiver_mix_dist, receiver_components)
# KL(sender || receiver), estimated as the mean difference of log-probs over the samples
log_prob_y_sender = y_sender_distribution.log_prob(y_sender_samples)
log_prob_y_receiver = receiver_dist.log_prob(y_sender_samples)
loss = n * torch.mean(log_prob_y_sender - log_prob_y_receiver)
'''
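For reference, the per-step accuracy $\alpha_i$ and previous time $t_{i-1}$ fed into this estimator can be computed from the schedule; a sketch for $n$ uniformly spaced steps, following the paper's $\alpha_i = \sigma_1^{-2i/n}(1 - \sigma_1^{2/n})$ (the helper name is my own):
'''
import torch

def discretised_alpha_and_time(i, n, sigma_one):
    # Accuracy alpha_i = sigma_1^(-2i/n) * (1 - sigma_1^(2/n)) and time t_{i-1} = (i - 1) / n
    # for step i in {1, ..., n} (continuous/discretised data).
    i = torch.as_tensor(i, dtype=torch.float32)
    alpha = sigma_one ** (-2 * i / n) * (1 - sigma_one ** (2 / n))
    t_prev = (i - 1) / n
    return alpha, t_prev
'''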
and clip at $[-1, 1]$ to obtain

$$ G(x \mid \mu_x^{(d)}, \sigma_x^{(d)}) = \begin{cases} 0 & \text{if } x \leq -1, \\ 1 & \text{if } x \geq 1, \\ F(x \mid \mu_x^{(d)}, \sigma_x^{(d)}) & \text{otherwise}. \end{cases} \tag{108} $$

Then, for $k \in \{1, K\}$,

$$ p_O^{(d)}(k \mid \theta; t) \triangleq G(k_u \mid \mu_x^{(d)}, \sigma_x^{(d)}) - G(k_l \mid \mu_x^{(d)}, \sigma_x^{(d)}), \tag{109} $$

where $k_l$ and $k_u$ are the lower and upper edges of bin $k$.
'''
# Run the noisy sample through the network to predict the mean and log(sigma) of the noise -> Tensor[B, D, 2]
mu_eps, ln_sigma_eps = self.forward(mu, t)
# Recover a prediction of the data from the noise prediction
var_scale = torch.sqrt((1 - gamma) / gamma)
mu_x = (mu / gamma) - (var_scale * mu_eps)
sigma_x = torch.clamp(var_scale * safe_exp(ln_sigma_eps), min=self.sigma_one)
# Below the minimum time threshold, fall back to a standard normal (mean 0, std 1)
mu_x = torch.where(t < t_min, torch.zeros_like(mu_x), mu_x)
sigma_x = torch.where(t < t_min, torch.ones_like(sigma_x), sigma_x)
normal_dist = dist.Normal(mu_x, sigma_x)
# Broadcast the bin edges against the [B, D] normal parameters -> Tensor[K, B, D]
broadcast_k_lower = self.k_lower.repeat(mu_x.shape[1], mu_x.shape[0], 1).transpose(0, 2)
broadcast_k_upper = self.k_upper.repeat(mu_x.shape[1], mu_x.shape[0], 1).transpose(0, 2)
cdf_values_lower = normal_dist.cdf(broadcast_k_lower)
cdf_values_upper = normal_dist.cdf(broadcast_k_upper)
# Bound the lowest edge's CDF at 0 and the highest edge's CDF at 1. This clips the
# distribution to [-1, 1] (see Eq. 108) and ensures the bin probabilities sum to 1.
cdf_values_lower = torch.where(broadcast_k_lower <= -1, torch.zeros_like(cdf_values_lower), cdf_values_lower)
cdf_values_upper = torch.where(broadcast_k_upper >= 1, torch.ones_like(cdf_values_upper), cdf_values_upper)
# Probability mass in each bin (Eq. 109) -> Tensor[B, D, K]
discretised_output_dist = (cdf_values_upper - cdf_values_lower).permute(1, 2, 0)
'''
The guiding heuristic for the accuracy schedule β(t) is to decrease the expected entropy of the input distribution linearly with t.
Figure 9: Accuracy schedule vs. expected entropy for discrete data (surface plot omitted).
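For discrete data this heuristic leads to the quadratic schedule used in Algorithm 8 below, $\beta(t) = \beta(1)\,t^2$, with accuracy rate $\alpha(t) = \beta'(t) = 2\beta(1)\,t$; a sketch (the helper name is my own):
'''
import torch

def discrete_accuracy_schedule(t, beta_1):
    # beta(t) = beta(1) * t^2 and alpha(t) = d beta / dt = 2 * beta(1) * t for discrete data
    t = torch.as_tensor(t, dtype=torch.float32)
    return beta_1 * t ** 2, 2 * beta_1 * t
'''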
Discrete-time loss:
$$ L^n(\mathbf{x}) = n \, \mathbb{E}_{i \sim U\{1,n\},\, p_F(\theta \mid \mathbf{x}; t_{i-1}),\, \mathcal{N}(\mathbf{y} \mid \alpha_i(K\mathbf{e_x} - \mathbf{1}), \alpha_i K\mathbf{I})} \Bigg[ \ln \mathcal{N}(\mathbf{y} \mid \alpha_i(K\mathbf{e_x} - \mathbf{1}), \alpha_i K\mathbf{I}) \tag{189} $$
$$ - \sum_{d=1}^{D} \ln \left( \sum_{k=1}^{K} p_O^{(d)}(k \mid \theta; t_{i-1}) \, \mathcal{N}\left(y^{(d)} \mid \alpha_i(K\mathbf{e}_k - \mathbf{1}), \alpha_i K\mathbf{I}\right) \right) \Bigg], \tag{190} $$
where $\mathbf{e}_k$ is the one-hot vector for class $k$ and $\mathbf{e_x}$ the one-hot encoding of the data.
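This can be estimated with Monte Carlo sampling in the same way as the discretised case above. A sketch, assuming `x` holds 0-indexed class labels (Tensor[B, D]), `output_distribution` is the network's [B, D, K] output at time $t_{i-1}$, `alpha` is the scalar per-step accuracy $\alpha_i$ as a tensor, and `dist` is torch.distributions (all names illustrative):
'''
# Sender: y ~ N(alpha * (K * e_x - 1), alpha * K * I), one K-vector per data dimension
e_x = torch.nn.functional.one_hot(x, num_classes=K).float()          # Tensor[B, D, K]
sender_dist = dist.Normal(alpha * (K * e_x - 1), torch.sqrt(alpha * K))
y_samples = sender_dist.sample(torch.Size([monte_carlo_samples]))     # Tensor[S, B, D, K]

# Receiver: mixture over classes k of N(alpha * (K * e_k - 1), alpha * K * I),
# weighted by the output distribution p_O^(d)(k | theta; t_{i-1})
e_k = torch.eye(K)                                                     # one-hot vectors, Tensor[K, K]
receiver_mix = dist.Categorical(probs=output_distribution)
receiver_components = dist.Independent(
    dist.Normal(alpha * (K * e_k - 1), torch.sqrt(alpha * K)), 1)     # K components, each K-dimensional
receiver_dist = dist.MixtureSameFamily(receiver_mix, receiver_components)

# KL(sender || receiver) estimated over the Monte Carlo samples, scaled by n (Eqs. 189-190)
log_prob_sender = sender_dist.log_prob(y_samples).sum(-1)              # sum over the K event dims
log_prob_receiver = receiver_dist.log_prob(y_samples)
loss = n * torch.mean(log_prob_sender - log_prob_receiver)
'''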
Algorithm 8 Continuous-Time Loss $L^{\infty}(\mathbf{x})$ for Discrete Data
Require: $\beta(1) \in \mathbb{R}^+$, number of classes $K \in \mathbb{N}$
Input: discrete data $\mathbf{x} \in \{1, K\}^D$
- $t \sim U(0, 1)$
- $\beta \leftarrow \beta(1)\, t^2$
- $\mathbf{e_x} \leftarrow \text{one\_hot}(\mathbf{x}, \text{num\_classes} = K)$
- $\mathbf{y} \sim \mathcal{N}(\beta(K\mathbf{e_x} - \mathbf{1}), \beta K \mathbf{I})$
- $\theta \leftarrow \text{softmax}(\mathbf{y})$
- $p_O(\cdot \mid \theta; t) \leftarrow \text{DISCRETE\_OUTPUT\_DISTRIBUTION}(\theta, t)$ -- output distribution
- $\hat{\mathbf{e}}(\theta, t) \leftarrow \left( \sum_k p_O^{(1)}(k \mid \theta; t)\,\mathbf{e}_k, \dots, \sum_k p_O^{(D)}(k \mid \theta; t)\,\mathbf{e}_k \right)$ -- data expectation
- $L^{\infty}(\mathbf{x}) \leftarrow K\beta(1)\, t \left\| \mathbf{e_x} - \hat{\mathbf{e}}(\theta, t) \right\|^2$

function DISCRETE_OUTPUT_DISTRIBUTION($\theta \in [0, 1]^{KD}$, $t \in [0, 1]$)
  Input $(\theta, t)$ to network, receive $\Psi(\theta, t)$ as output
  for $d \in \{1, D\}$ do
    if $K = 2$ then
      $p_O^{(d)}(1 \mid \theta; t) \leftarrow \sigma\left(\Psi^{(d)}(\theta, t)\right)$
      $p_O^{(d)}(2 \mid \theta; t) \leftarrow 1 - p_O^{(d)}(1 \mid \theta; t)$
    else
      $p_O^{(d)}(\cdot \mid \theta; t) \leftarrow \text{softmax}(\Psi^{(d)}(\theta, t))$
    end if
  end for
  Return $p_O(\cdot \mid \theta; t)$
end function
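A minimal PyTorch sketch of Algorithm 8, assuming a network `net(theta, t)` that returns logits of shape [B, D, K]; the function names and the always-softmax output (skipping the K = 2 sigmoid branch) are my own simplifications:
'''
import torch
import torch.nn.functional as F

def discrete_output_distribution(net, theta, t):
    # DISCRETE_OUTPUT_DISTRIBUTION: run the network and normalise its logits per dimension.
    # (The K = 2 sigmoid branch of Algorithm 8 is folded into the softmax here.)
    logits = net(theta, t)                        # Tensor[B, D, K]
    return torch.softmax(logits, dim=-1)          # p_O(. | theta; t)

def continuous_time_loss_discrete(net, x, K, beta_1):
    # Algorithm 8: continuous-time loss for discrete data x (class indices, Tensor[B, D])
    B, D = x.shape
    t = torch.rand(B, 1, 1)                       # t ~ U(0, 1), broadcast over D and K
    beta = beta_1 * t ** 2                        # beta(t) = beta(1) * t^2
    e_x = F.one_hot(x, num_classes=K).float()     # Tensor[B, D, K]
    # y ~ N(beta * (K * e_x - 1), beta * K * I)
    y = beta * (K * e_x - 1) + torch.sqrt(beta * K) * torch.randn_like(e_x)
    theta = torch.softmax(y, dim=-1)              # parameters of the updated input distribution
    p_o = discrete_output_distribution(net, theta, t)
    e_hat = p_o                                   # sum_k p_O(k) e_k is just the probability vector
    # L_infinity(x) = K * beta(1) * t * ||e_x - e_hat||^2, averaged over the batch
    l_inf = K * beta_1 * t.view(B) * (e_x - e_hat).pow(2).sum(dim=(1, 2))
    return l_inf.mean()
'''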