Skip to content

Commit

Permalink
fix categorical sample can't pass hypothesis test and entropy shape e…
Browse files Browse the repository at this point in the history
…rror bug
  • Loading branch information
cxxly committed Jan 28, 2022
1 parent 63ac9d6 commit 852300d
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 117 deletions.
36 changes: 21 additions & 15 deletions python/paddle/distribution/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,14 @@ class Beta(ExponentialFamily):
Args:
alpha (float|Tensor): alpha parameter of beta distribution, positive(>0), support broadcast semantic. when the parameter is tensor type, represent multiple independent distribution with batch_shape(refer to `Distribution`) equals to shape of alpha after broadcast.
beta (float|Tensor): beta parameter of beta distribution, positive(>0).when the parameter is tensor type, represent multiple independent distribution with batch_shape(refer to `Distribution`) equals to shape of beta after broadcast.
alpha (float|Tensor): Alpha parameter. It supports broadcast semantics.
The value of alpha must be positive(>0). When the parameter is a tensor,
it represents multiple independent distributions with
a batch_shape(refer to ``Distribution`` ).
beta (float|Tensor): Beta parameter. It supports broadcast semantics.
The value of beta must be positive(>0). When the parameter is a tensor,
it represents multiple independent distributions with
a batch_shape(refer to ``Distribution`` ).
Examples:
Expand Down Expand Up @@ -93,56 +99,56 @@ def __init__(self, alpha, beta):

@property
def mean(self):
"""mean of beta distribution.
"""Mean of beta distribution.
"""
return self.alpha / (self.alpha + self.beta)

@property
def variance(self):
"""variance of beta distribution
"""Variance of beta distribution
"""
sum = self.alpha + self.beta
return self.alpha * self.beta / (sum.pow(2) * (sum + 1))

def prob(self, value):
"""probability density function evaluated at value
"""Probability density function evaluated at value
Args:
value (Tensor): value to be evaluated.
value (Tensor): Value to be evaluated.
Returns:
Tensor: probability.
Tensor: Probability.
"""
return paddle.exp(self.log_prob(value))

def log_prob(self, value):
"""log probability density function evaluated at value
"""Log probability density function evaluated at value
Args:
value (Tensor): value to be evaluated
value (Tensor): Value to be evaluated
Returns:
Tensor: log probability.
Tensor: Log probability.
"""
return self._dirichlet.log_prob(paddle.stack([value, 1.0 - value], -1))

def sample(self, shape=()):
"""sample from beta distribution with sample shape.
"""Sample from beta distribution with sample shape.
Args:
shape (Sequence[int], optional): sample shape.
shape (Sequence[int], optional): Sample shape.
Returns:
sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`.
"""
shape = shape if isinstance(shape, tuple) else tuple(shape)
return paddle.squeeze(self._dirichlet.sample(shape)[..., 0], axis=-1)

def entropy(self):
"""entropy of dirichlet distribution
"""Entropy of dirichlet distribution
Returns:
Tensor: entropy.
Tensor: Entropy.
"""
return self._dirichlet.entropy()

Expand Down
61 changes: 34 additions & 27 deletions python/paddle/distribution/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import warnings

import numpy as np
import paddle
from paddle import _C_ops

from ..fluid import core
Expand Down Expand Up @@ -153,15 +154,22 @@ def sample(self, shape):
logits_shape = list(self.logits.shape)
if len(logits_shape) > 1:
sample_shape = shape + logits_shape[:-1]
logits = nn.reshape(self.logits,
[np.prod(logits_shape[:-1]), logits_shape[-1]])
logits = paddle.reshape(
self.logits, [np.prod(logits_shape[:-1]), logits_shape[-1]])
else:
sample_shape = shape
logits = self.logits

sample_index = multinomial(
self._logits_to_probs(logits), num_samples, True)
return nn.reshape(sample_index, sample_shape, name=name)

# multinomial sample shape is (logits.shape[:-1], num_samples), need to
# transpose to (num_samples, logits.shape[:-1])
permute = list(range(sample_index.dim()))
permute.insert(0, permute.pop(-1))
sample_index = sample_index.transpose(permute)

return paddle.reshape(sample_index, sample_shape, name=name)

def kl_divergence(self, other):
"""The KL-divergence between two Categorical distributions.
Expand Down Expand Up @@ -202,19 +210,19 @@ def kl_divergence(self, other):
check_type(other, 'other', Categorical, 'kl_divergence')

logits = self.logits - \
nn.reduce_max(self.logits, dim=-1, keep_dim=True)
other_logits = other.logits - nn.reduce_max(
other.logits, dim=-1, keep_dim=True)
paddle.max(self.logits, axis=-1, keepdim=True)
other_logits = other.logits - paddle.max(
other.logits, axis=-1, keepdim=True)
e_logits = ops.exp(logits)
other_e_logits = ops.exp(other_logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
other_z = nn.reduce_sum(other_e_logits, dim=-1, keep_dim=True)
z = paddle.sum(e_logits, axis=-1, keepdim=True)
other_z = paddle.sum(other_e_logits, axis=-1, keepdim=True)
prob = e_logits / z
kl = nn.reduce_sum(
prob * (logits - nn.log(z) - other_logits + nn.log(other_z)),
dim=-1,
keep_dim=True,
name=name)
kl = paddle.sum(prob * (
logits - paddle.log(z) - other_logits + paddle.log(other_z)),
axis=-1,
keepdim=True,
name=name)

return kl

Expand Down Expand Up @@ -244,14 +252,13 @@ def entropy(self):
"""
name = self.name + '_entropy'
logits = self.logits - \
nn.reduce_max(self.logits, dim=-1, keep_dim=True)
paddle.max(self.logits, axis=-1, keepdim=True)
e_logits = ops.exp(logits)
z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
z = paddle.sum(e_logits, axis=-1, keepdim=True)
prob = e_logits / z

neg_entropy = nn.reduce_sum(
prob * (logits - nn.log(z)), dim=-1, keep_dim=True)
entropy = nn.scale(neg_entropy, scale=-1.0, name=name)
neg_entropy = paddle.sum(prob * (logits - paddle.log(z)), axis=-1)
entropy = paddle.scale(neg_entropy, scale=-1.0, name=name)
return entropy

def probs(self, value):
Expand Down Expand Up @@ -291,41 +298,41 @@ def probs(self, value):
"""
name = self.name + '_probs'

dist_sum = nn.reduce_sum(self.logits, dim=-1, keep_dim=True)
dist_sum = paddle.sum(self.logits, axis=-1, keepdim=True)
prob = self.logits / dist_sum

shape = list(prob.shape)
value_shape = list(value.shape)
if len(shape) == 1:
num_value_in_one_dist = np.prod(value_shape)
index_value = nn.reshape(value, [num_value_in_one_dist, 1])
index_value = paddle.reshape(value, [num_value_in_one_dist, 1])
index = index_value
else:
num_dist = np.prod(shape[:-1])
num_value_in_one_dist = value_shape[-1]
prob = nn.reshape(prob, [num_dist, shape[-1]])
prob = paddle.reshape(prob, [num_dist, shape[-1]])
if len(value_shape) == 1:
value = nn.expand(value, [num_dist])
value_shape = shape[:-1] + value_shape
index_value = nn.reshape(value, [num_dist, -1, 1])
index_value = paddle.reshape(value, [num_dist, -1, 1])
if shape[:-1] != value_shape[:-1]:
raise ValueError(
"shape of value {} must match shape of logits {}".format(
str(value_shape[:-1]), str(shape[:-1])))

index_prefix = nn.unsqueeze(
index_prefix = paddle.unsqueeze(
arange(
num_dist, dtype=index_value.dtype), axes=-1)
num_dist, dtype=index_value.dtype), axis=-1)
index_prefix = nn.expand(index_prefix, [1, num_value_in_one_dist])
index_prefix = nn.unsqueeze(index_prefix, axes=-1)
index_prefix = paddle.unsqueeze(index_prefix, axis=-1)

if index_value.dtype != index_prefix.dtype:
tensor.cast(index_prefix, dtype=index_value.dtype)
index = concat([index_prefix, index_value], axis=-1)

# value is the category index to search for the corresponding probability.
select_prob = gather_nd(prob, index)
return nn.reshape(select_prob, value_shape, name=name)
return paddle.reshape(select_prob, value_shape, name=name)

def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.
Expand Down Expand Up @@ -357,4 +364,4 @@ def log_prob(self, value):
"""
name = self.name + '_log_prob'

return nn.log(self.probs(value), name=name)
return paddle.log(self.probs(value), name=name)
45 changes: 27 additions & 18 deletions python/paddle/distribution/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,37 @@

class Dirichlet(ExponentialFamily):
r"""
Dirichlet distribution with parameter concentration.
Dirichlet distribution with parameter "concentration".
The Dirichlet distribution is defined over the `(k-1)-simplex` using a
positive, length-k vector concentration(`k > 1`).
The Dirichlet is identically the Beta distribution when `k = 2`.
For independent and identically distributed continuous random variable :math:`\boldsymbol X \in R_k` , and support :math:`\boldsymbol X \in (0,1), ||\boldsymbol X|| = 1` , The probability density function (pdf) is
For independent and identically distributed continuous random variable
:math:`\boldsymbol X \in R_k` , and support
:math:`\boldsymbol X \in (0,1), ||\boldsymbol X|| = 1` ,
The probability density function (pdf) is
.. math::
f(\boldsymbol X; \boldsymbol \alpha) = \frac{1}{B(\boldsymbol \alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}
where :math:`\boldsymbol \alpha = {\alpha_1,...,\alpha_k}, k \ge 2` is parameter, the normalizing constant is the multivariate beta function.
where :math:`\boldsymbol \alpha = {\alpha_1,...,\alpha_k}, k \ge 2` is
parameter, the normalizing constant is the multivariate beta function.
.. math::
B(\boldsymbol \alpha) = \frac{\prod_{i=1}^{k} \Gamma(\alpha_i)}{\Gamma(\alpha_0)}
:math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of parameters, :math:`\Gamma(\alpha)` is gamma function。
:math:`\alpha_0=\sum_{i=1}^{k} \alpha_i` is the sum of parameters,
:math:`\Gamma(\alpha)` is gamma function.
Args:
concentration (Tensor): concentration parameter of dirichlet distribution, also called :math:`\alpha`. when concentration over one dimension,the last axis is parameter of distribution, ``event_shape=concentration.shape[-1:]`` , other axes is batch with ``batch_shape=concentration.shape[:-1]`` .
concentration (Tensor): "Concentration" parameter of dirichlet
distribution, also called :math:`\alpha`. When it's over one
dimension, the last axis denotes the parameter of distribution,
``event_shape=concentration.shape[-1:]`` , axes other than last are
considered batch dimensions with ``batch_shape=concentration.shape[:-1]`` .
Examples:
Expand Down Expand Up @@ -73,59 +82,59 @@ def __init__(self, concentration):

@property
def mean(self):
"""mean of Dirichlet distribution.
"""Mean of Dirichlet distribution.
Returns:
mean value of distribution.
Mean value of distribution.
"""
return self.concentration / self.concentration.sum(-1, keepdim=True)

@property
def variance(self):
"""variance of Dirichlet distribution.
"""Variance of Dirichlet distribution.
Returns:
variance value of distribution.
Variance value of distribution.
"""
concentration0 = self.concentration.sum(-1, keepdim=True)
return (self.concentration * (concentration0 - self.concentration)) / (
concentration0.pow(2) * (concentration0 + 1))

def sample(self, shape=()):
"""sample from dirichlet distribution.
"""Sample from dirichlet distribution.
Args:
shape (Sequence[int], optional): sample shape. Defaults to empty tuple.
shape (Sequence[int], optional): Sample shape. Defaults to empty tuple.
"""
shape = shape if isinstance(shape, tuple) else tuple(shape)
return _dirichlet(self.concentration.expand(self._extend_shape(shape)))

def prob(self, value):
"""Probability density function(pdf) evaluated at value.
"""Probability density function(PDF) evaluated at value.
Args:
value (Tensor): value to be evaluated.
value (Tensor): Value to be evaluated.
Returns:
pdf evaluated at value.
PDF evaluated at value.
"""
return paddle.exp(self.log_prob(value))

def log_prob(self, value):
"""log of probability density function.
"""Log of probability density function.
Args:
value (Tensor): value to be evaluated.
value (Tensor): Value to be evaluated.
"""
return ((paddle.log(value) * (self.concentration - 1.0)
).sum(-1) + paddle.lgamma(self.concentration.sum(-1)) -
paddle.lgamma(self.concentration).sum(-1))

def entropy(self):
"""entropy of Dirichlet distribution.
"""Entropy of Dirichlet distribution.
Returns:
entropy of distribution.
Entropy of distribution.
"""
concentration0 = self.concentration.sum(-1)
k = self.concentration.shape[-1]
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/distribution/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,13 @@ def log_prob(self, value):
raise NotImplementedError

def probs(self, value):
"""Probability density/mass function."""
"""Probability density/mass function.
.. note::
This method will be deprecated in the future, please use `prob`
instead.
"""
raise NotImplementedError

def _extend_shape(self, sample_shape):
Expand Down
Loading

0 comments on commit 852300d

Please sign in to comment.