forked from blackjax-devs/blackjax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
280 lines (228 loc) · 10.5 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Metric space in which the Hamiltonian dynamic is embedded.
An important particular case (and the most used in practice) of metric for the
position space is the Euclidean metric. It is defined by a positive definite
matrix :math:`M` with fixed value so that the kinetic energy of the Hamiltonian
dynamic is independent of the position and only depends on the momentum
:math:`p` :cite:p:`betancourt2017geometric`.
For a Newtonian Hamiltonian dynamic the kinetic energy is given by:

.. math::

    K(p) = \frac{1}{2} p^T M^{-1} p

We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
from jax.scipy import stats as sp_stats
from blackjax.types import Array, PRNGKey, PyTree
from blackjax.util import generate_gaussian_noise
# Names exported as this module's public API.
__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"]
class KineticEnergy(Protocol):
    """Structural type of a kinetic-energy function.

    A conforming callable takes a momentum and, optionally, a position, and
    returns the kinetic energy as a scalar.
    """

    def __call__(
        self,
        momentum: PyTree,
        position: Optional[PyTree] = None,
    ) -> float:
        ...
class CheckTurning(Protocol):
    """Structural type of a trajectory turning criterion.

    A conforming callable takes the momenta at the left and right ends of a
    trajectory, the sum of the momenta, and optionally the positions at both
    ends, and returns whether the trajectory is turning.
    """

    def __call__(
        self,
        momentum_left: PyTree,
        momentum_right: PyTree,
        momentum_sum: PyTree,
        position_left: Optional[PyTree] = None,
        position_right: Optional[PyTree] = None,
    ) -> bool:
        ...
class Metric(NamedTuple):
    """Bundle of the three functions that make up a metric.

    The fields are, in order: a momentum sampler, a kinetic-energy function
    and a turning criterion.
    """

    # Called with a PRNG key and a position; returns a momentum.
    sample_momentum: Callable[[PRNGKey, PyTree], PyTree]
    # Called with a momentum (and optionally a position); returns the energy.
    kinetic_energy: KineticEnergy
    # Called with the end-point momenta and their sum; returns a boolean.
    check_turning: CheckTurning
# Accepted metric specifications: a ready-made ``Metric``, an array, or a
# callable that maps a position to an array (see ``default_metric``).
MetricTypes = Union[Metric, Array, Callable[[PyTree], Array]]
def default_metric(metric: MetricTypes) -> Metric:
    """Build a ``Metric`` from any of the supported metric specifications.

    Three kinds of input are accepted:

    - a ``Metric`` instance, which is returned unchanged;
    - a callable, taken to map a position to the mass matrix at that
      position, which is turned into a Riemannian metric;
    - anything else, taken to be an array holding the inverse mass matrix of
      a static metric, which is turned into a Euclidean metric.

    """
    # A complete metric needs no conversion.
    if isinstance(metric, Metric):
        return metric

    # Callables describe a position-dependent mass matrix; everything that
    # reaches the fallback is assumed to be a static inverse mass matrix.
    build_metric = gaussian_riemannian if callable(metric) else gaussian_euclidean
    return build_metric(metric)
def gaussian_euclidean(
    inverse_mass_matrix: Array,
) -> Metric:
    r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum
    :cite:p:`betancourt2013general`.

    The mass matrix is fixed, so the kinetic energy depends on the momentum
    only: :math:`K(p) = \frac{1}{2} p^T M^{-1} p` (Newtonian dynamics).

    Parameters
    ----------
    inverse_mass_matrix
        One-dimensional array for a diagonal mass matrix, or two-dimensional
        array for a dense one. It is applied to the *flattened* momentum, so
        the order of its entries must follow the order in which
        `ravel_pytree` flattens the position Pytree. In particular, JAX
        sorts dictionaries by key when flattening them, so variables appear
        in the order given by `sort(keys)`.

    Returns
    -------
    A ``Metric`` whose fields are:

    sample_momentum
        A function that generates a value for the momentum at random.
    kinetic_energy
        A function that returns the kinetic energy given the momentum.
    check_turning
        A function that determines whether a trajectory is turning back on
        itself given the values of the momentum along the trajectory.

    Raises
    ------
    ValueError
        If ``inverse_mass_matrix`` is neither one- nor two-dimensional.

    """
    ndim = jnp.ndim(inverse_mass_matrix)  # type: ignore[arg-type]
    if ndim not in (1, 2):
        raise ValueError(
            "The mass matrix has the wrong number of dimensions:"
            f" expected 1 or 2, got {ndim}."
        )

    if ndim == 1:
        # Diagonal case: everything is element-wise.
        mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
        matmul = jnp.multiply
    else:
        # Dense case. Factor the inverse mass matrix as L @ L.T and solve
        # L.T @ X = I, so that X is the inverse of L.T. X is upper triangular
        # and satisfies inv(X @ X.T) == inverse_mass_matrix.
        #
        # Taking the lower Cholesky factor of inv(inverse_mass_matrix)
        # directly is an alternative; that factor would be lower triangular.
        size = jnp.shape(inverse_mass_matrix)[0]  # type: ignore[arg-type]
        chol_inverse = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
        mass_matrix_sqrt = jscipy.linalg.solve_triangular(
            chol_inverse, jnp.identity(size), lower=True, trans=True
        )
        matmul = jnp.matmul

    def momentum_generator(rng_key: PRNGKey, position: PyTree) -> PyTree:
        """Draw Gaussian noise shaped like ``position``, scaled by the
        square root of the mass matrix."""
        return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

    def kinetic_energy(momentum: PyTree, position: Optional[PyTree] = None) -> float:
        """Return half the inner product of the velocity and the momentum."""
        del position  # The Euclidean kinetic energy ignores the position.
        flat_momentum, _ = ravel_pytree(momentum)
        velocity = matmul(inverse_mass_matrix, flat_momentum)
        return 0.5 * jnp.dot(velocity, flat_momentum)

    def is_turning(
        momentum_left: PyTree,
        momentum_right: PyTree,
        momentum_sum: PyTree,
        position_left: Optional[PyTree] = None,
        position_right: Optional[PyTree] = None,
    ) -> bool:
        """Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`.

        Parameters
        ----------
        momentum_left
            Momentum of the leftmost point of the trajectory.
        momentum_right
            Momentum of the rightmost point of the trajectory.
        momentum_sum
            Sum of the momenta along the trajectory.

        """
        del position_left, position_right  # Unused by the Euclidean criterion.
        flat_left, _ = ravel_pytree(momentum_left)
        flat_right, _ = ravel_pytree(momentum_right)
        flat_sum, _ = ravel_pytree(momentum_sum)

        # The sum of momenta minus the average of the two end-point momenta
        # (the plain sum, rho = flat_sum, is the variant left disabled).
        rho = flat_sum - (flat_right + flat_left) / 2

        # The trajectory turns as soon as the velocity at either end has a
        # non-positive projection on rho.
        left_turns = jnp.dot(matmul(inverse_mass_matrix, flat_left), rho) <= 0
        right_turns = jnp.dot(matmul(inverse_mass_matrix, flat_right), rho) <= 0
        return left_turns | right_turns

    return Metric(
        sample_momentum=momentum_generator,
        kinetic_energy=kinetic_energy,
        check_turning=is_turning,
    )
def gaussian_riemannian(
    mass_matrix_fn: Callable,
) -> Metric:
    r"""Hamiltonian dynamic with a position-dependent mass matrix and
    normally-distributed momentum.

    The kinetic energy is the negative log-density of a zero-mean Gaussian
    whose covariance is the mass matrix evaluated at the current position.

    Parameters
    ----------
    mass_matrix_fn
        Function that takes a position and returns the mass matrix at that
        position, either as a one-dimensional array (diagonal mass matrix)
        or as a two-dimensional array (dense mass matrix).

    Returns
    -------
    A ``Metric`` whose fields are:

    sample_momentum
        A function that generates a value for the momentum at random, given
        a PRNG key and the position.
    kinetic_energy
        A function that returns the kinetic energy given the momentum and
        the position; it raises ``ValueError`` when the position is omitted.
    check_turning
        A function that always raises ``NotImplementedError``: the U-turn
        criterion is not available for this metric yet.

    """

    def wrong_ndim_error(ndim: int) -> ValueError:
        # Single place where the dimensionality error is built, so that both
        # closures below report it with the same wording.
        return ValueError(
            "The mass matrix has the wrong number of dimensions:"
            f" expected 1 or 2, got {ndim}."
        )

    def momentum_generator(rng_key: PRNGKey, position: PyTree) -> PyTree:
        """Draw Gaussian noise shaped like ``position``, scaled by a square
        root of the mass matrix evaluated at ``position``."""
        mass_matrix = mass_matrix_fn(position)
        ndim = jnp.ndim(mass_matrix)
        if ndim == 1:
            mass_matrix_sqrt = jnp.sqrt(mass_matrix)
        elif ndim == 2:
            mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True)
        else:
            raise wrong_ndim_error(ndim)

        return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt)

    def kinetic_energy(momentum: PyTree, position: Optional[PyTree] = None) -> float:
        """Return the negative Gaussian log-density of the flattened momentum
        under the mass matrix evaluated at ``position``."""
        # The position is optional in the shared interface but mandatory
        # here, since the mass matrix is a function of it.
        if position is None:
            raise ValueError(
                "A Riemannian kinetic energy function must be called with the "
                "position specified; make sure to use a Riemannian-capable "
                "integrator like `implicit_midpoint`."
            )

        momentum, _ = ravel_pytree(momentum)
        mass_matrix = mass_matrix_fn(position)
        ndim = jnp.ndim(mass_matrix)
        if ndim == 1:
            # Diagonal case: independent normals with variance `mass_matrix`.
            return -jnp.sum(sp_stats.norm.logpdf(momentum, 0.0, jnp.sqrt(mass_matrix)))
        elif ndim == 2:
            return -sp_stats.multivariate_normal.logpdf(
                momentum, jnp.zeros_like(momentum), mass_matrix
            )
        else:
            raise wrong_ndim_error(ndim)

    def is_turning(
        momentum_left: PyTree,
        momentum_right: PyTree,
        momentum_sum: PyTree,
        position_left: Optional[PyTree] = None,
        position_right: Optional[PyTree] = None,
    ) -> bool:
        """Placeholder turning criterion; always raises ``NotImplementedError``."""
        del momentum_left, momentum_right, momentum_sum, position_left, position_right
        raise NotImplementedError(
            "NUTS sampling is not yet implemented for Riemannian manifolds"
        )

        # Here's a possible implementation of this function, but the NUTS
        # proposal will require some refactoring to work properly, since we need
        # to be able to access the coordinates at the left and right endpoints
        # to compute the mass matrix at those points.

        # m_left, _ = ravel_pytree(momentum_left)
        # m_right, _ = ravel_pytree(momentum_right)
        # m_sum, _ = ravel_pytree(momentum_sum)

        # mass_matrix_left = mass_matrix_fn(position_left)
        # mass_matrix_right = mass_matrix_fn(position_right)
        # velocity_left = jnp.linalg.solve(mass_matrix_left, m_left)
        # velocity_right = jnp.linalg.solve(mass_matrix_right, m_right)

        # # rho = m_sum
        # rho = m_sum - (m_right + m_left) / 2
        # turning_at_left = jnp.dot(velocity_left, rho) <= 0
        # turning_at_right = jnp.dot(velocity_right, rho) <= 0
        # return turning_at_left | turning_at_right

    return Metric(momentum_generator, kinetic_energy, is_turning)