-
-
Notifications
You must be signed in to change notification settings - Fork 372
/
Copy pathbase_nuts.hpp
365 lines (292 loc) · 11.9 KB
/
base_nuts.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
#define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
#include <stan/callbacks/logger.hpp>
#include <stan/math/prim.hpp>
#include <stan/mcmc/hmc/base_hmc.hpp>
#include <stan/mcmc/hmc/hamiltonians/ps_point.hpp>
#include <algorithm>
#include <cmath>
#include <limits>
#include <string>
#include <vector>
namespace stan {
namespace mcmc {
/**
 * The No-U-Turn sampler (NUTS) with multinomial sampling.
 *
 * Each call to transition() resamples the momentum and then repeatedly
 * doubles the simulated trajectory, extending it forward or backward in
 * time with equal probability. Doubling stops when the newly built subtree
 * is invalid (divergent, or failing the no-U-turn criterion internally),
 * when the no-U-turn criterion fails for the merged trajectory, or when
 * max_depth_ doublings have been performed.
 *
 * Every leapfrog state carries the weight exp(H0 - H). Inside a subtree the
 * proposal is drawn multinomially according to these weights; when a new
 * subtree is merged into the trajectory its proposal replaces the current
 * sample with probability min(1, W_subtree / W_trajectory).
 *
 * @tparam Model model type
 * @tparam Hamiltonian Hamiltonian class template (two type parameters)
 * @tparam Integrator integrator class template (one type parameter)
 * @tparam BaseRNG random number generator type
 */
template <class Model, template <class, class> class Hamiltonian,
          template <class> class Integrator, class BaseRNG>
class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
 public:
  /**
   * Construct a sampler without specifying a metric.
   *
   * @param model model to sample from
   * @param rng random number generator
   */
  base_nuts(const Model& model, BaseRNG& rng)
      : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng) {}

  /**
   * specialized constructor for specified diag mass matrix
   *
   * @param model model to sample from
   * @param rng random number generator
   * @param inv_e_metric diagonal of the inverse metric
   */
  base_nuts(const Model& model, BaseRNG& rng, Eigen::VectorXd& inv_e_metric)
      : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng,
                                                          inv_e_metric) {}

  /**
   * specialized constructor for specified dense mass matrix
   *
   * @param model model to sample from
   * @param rng random number generator
   * @param inv_e_metric dense inverse metric
   */
  base_nuts(const Model& model, BaseRNG& rng, Eigen::MatrixXd& inv_e_metric)
      : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng,
                                                          inv_e_metric) {}

  ~base_nuts() {}

  /** Forward a dense inverse metric to the phase-space state z_. */
  void set_metric(const Eigen::MatrixXd& inv_e_metric) {
    this->z_.set_metric(inv_e_metric);
  }

  /** Forward a diagonal inverse metric to the phase-space state z_. */
  void set_metric(const Eigen::VectorXd& inv_e_metric) {
    this->z_.set_metric(inv_e_metric);
  }

  /**
   * Set the maximum tree depth. Non-positive values are ignored and the
   * previous maximum is kept.
   */
  void set_max_depth(int d) {
    if (d > 0)
      max_depth_ = d;
  }

  /**
   * Set the energy-error threshold: a leapfrog step whose H exceeds the
   * initial H0 by more than this amount is flagged divergent.
   */
  void set_max_delta(double d) { max_deltaH_ = d; }

  /** @return the maximum tree depth */
  int get_max_depth() const { return this->max_depth_; }

  /** @return the energy-error threshold used to flag divergences */
  double get_max_delta() const { return this->max_deltaH_; }

  /**
   * Perform one NUTS transition starting from the given sample.
   *
   * Calls sample_stepsize(), seeds the state from the sample's continuous
   * parameters, draws a fresh momentum, and builds the trajectory described
   * in the class documentation. Updates depth_, n_leapfrog_, divergent_ and
   * energy_, and leaves z_ at the selected state.
   *
   * @param init_sample sample whose continuous parameters seed the state
   * @param logger logger passed to the Hamiltonian and integrator
   * @return the selected state's position, its negated potential energy,
   *   and the mean of min(1, exp(H0 - H)) over every leapfrog step taken,
   *   including steps in subtrees that were rejected
   */
  sample transition(sample& init_sample, callbacks::logger& logger) {
    // Initialize the algorithm
    this->sample_stepsize();
    this->seed(init_sample.cont_params());
    this->hamiltonian_.sample_p(this->z_, this->rand_int_);
    this->hamiltonian_.init(this->z_, logger);

    ps_point z_fwd(this->z_);  // State at forward end of trajectory
    ps_point z_bck(z_fwd);     // State at backward end of trajectory
    ps_point z_sample(z_fwd);
    ps_point z_propose(z_fwd);

    // Momentum and sharp momentum at forward end of forward subtree
    Eigen::VectorXd p_fwd_fwd = this->z_.p;
    Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_);

    // Momentum and sharp momentum at backward end of forward subtree
    Eigen::VectorXd p_fwd_bck = this->z_.p;
    Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd;

    // Momentum and sharp momentum at forward end of backward subtree
    Eigen::VectorXd p_bck_fwd = this->z_.p;
    Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd;

    // Momentum and sharp momentum at backward end of backward subtree
    Eigen::VectorXd p_bck_bck = this->z_.p;
    Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd;

    // Integrated momenta along trajectory (initially just the start state)
    Eigen::VectorXd rho = this->z_.p;

    // Log sum of state weights (offset by H0) along trajectory
    double log_sum_weight = 0;  // log(exp(H0 - H0))
    double H0 = this->hamiltonian_.H(this->z_);
    int n_leapfrog = 0;
    double sum_metro_prob = 0;

    // Build a trajectory until the no-u-turn
    // criterion is no longer satisfied
    this->depth_ = 0;
    this->divergent_ = false;

    while (this->depth_ < this->max_depth_) {
      // Build a new subtree in a random direction
      Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size());
      Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size());

      bool valid_subtree = false;
      double log_sum_weight_subtree = -std::numeric_limits<double>::infinity();

      if (this->rand_uniform_() > 0.5) {
        // Extend the current trajectory forward. The existing trajectory
        // becomes the "backward subtree" of the merged trajectory.
        this->z_.ps_point::operator=(z_fwd);
        rho_bck = rho;
        p_bck_fwd = p_fwd_fwd;
        p_sharp_bck_fwd = p_sharp_fwd_fwd;

        valid_subtree = build_tree(
            this->depth_, z_propose, p_sharp_fwd_bck, p_sharp_fwd_fwd, rho_fwd,
            p_fwd_bck, p_fwd_fwd, H0, 1, n_leapfrog, log_sum_weight_subtree,
            sum_metro_prob, logger);
        z_fwd.ps_point::operator=(this->z_);
      } else {
        // Extend the current trajectory backwards. The existing trajectory
        // becomes the "forward subtree" of the merged trajectory.
        this->z_.ps_point::operator=(z_bck);
        rho_fwd = rho;
        p_fwd_bck = p_bck_bck;
        p_sharp_fwd_bck = p_sharp_bck_bck;

        valid_subtree = build_tree(
            this->depth_, z_propose, p_sharp_bck_fwd, p_sharp_bck_bck, rho_bck,
            p_bck_fwd, p_bck_bck, H0, -1, n_leapfrog, log_sum_weight_subtree,
            sum_metro_prob, logger);
        z_bck.ps_point::operator=(this->z_);
      }

      // An invalid subtree is discarded entirely: the current sample is kept
      // and depth_ is not incremented.
      if (!valid_subtree)
        break;

      // Sample from accepted subtree
      ++(this->depth_);

      // Replace the sample with probability min(1, W_subtree / W_old),
      // comparing against the weight before the subtree is merged in.
      if (log_sum_weight_subtree > log_sum_weight) {
        z_sample = z_propose;
      } else {
        double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight);
        if (this->rand_uniform_() < accept_prob)
          z_sample = z_propose;
      }

      log_sum_weight
          = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

      // Break when no-u-turn criterion is no longer satisfied
      rho = rho_bck + rho_fwd;

      // Demand satisfaction around merged subtrees
      bool persist_criterion
          = compute_criterion(p_sharp_bck_bck, p_sharp_fwd_fwd, rho);

      // Demand satisfaction between subtrees: backward subtree extended by
      // the first state of the forward subtree, and vice versa.
      Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck;
      persist_criterion
          &= compute_criterion(p_sharp_bck_bck, p_sharp_fwd_bck, rho_extended);

      rho_extended = rho_fwd + p_bck_fwd;
      persist_criterion
          &= compute_criterion(p_sharp_bck_fwd, p_sharp_fwd_fwd, rho_extended);

      if (!persist_criterion)
        break;
    }

    this->n_leapfrog_ = n_leapfrog;

    // Compute average acceptance probability across entire trajectory,
    // even over subtrees that may have been rejected
    double accept_prob = sum_metro_prob / static_cast<double>(n_leapfrog);

    this->z_.ps_point::operator=(z_sample);
    this->energy_ = this->hamiltonian_.H(this->z_);
    return sample(this->z_.q, -this->z_.V, accept_prob);
  }

  /**
   * Append the names of the per-transition diagnostics, in the same order
   * as the values appended by get_sampler_params().
   */
  void get_sampler_param_names(std::vector<std::string>& names) {
    names.push_back("stepsize__");
    names.push_back("treedepth__");
    names.push_back("n_leapfrog__");
    names.push_back("divergent__");
    names.push_back("energy__");
  }

  /**
   * Append the diagnostics of the most recent transition, in the same order
   * as the names appended by get_sampler_param_names().
   */
  void get_sampler_params(std::vector<double>& values) {
    values.push_back(this->epsilon_);
    values.push_back(this->depth_);
    values.push_back(this->n_leapfrog_);
    values.push_back(this->divergent_);
    values.push_back(this->energy_);
  }

  /**
   * No-U-turn criterion for a span of the trajectory: the summed momentum
   * rho must have a strictly positive inner product with the sharp momenta
   * at both ends of the span. The check is symmetric in the two ends.
   * Virtual so that derived samplers can substitute a different criterion.
   *
   * @param p_sharp_minus sharp momentum at one end of the span
   * @param p_sharp_plus sharp momentum at the other end of the span
   * @param rho summed momentum across the span
   * @return true if the criterion is still satisfied (no U-turn yet)
   */
  virtual bool compute_criterion(Eigen::VectorXd& p_sharp_minus,
                                 Eigen::VectorXd& p_sharp_plus,
                                 Eigen::VectorXd& rho) {
    return p_sharp_plus.dot(rho) > 0 && p_sharp_minus.dot(rho) > 0;
  }

  /**
   * Recursively build a subtree of 2^depth leapfrog steps, starting from
   * the current state z_ and integrating in the direction given by sign.
   * Building stops early, returning false, as soon as any part of the
   * subtree is divergent or fails the no-U-turn criterion; in that case the
   * output arguments may be only partially updated.
   *
   * On return z_ holds the last state integrated.
   *
   * @param depth Depth of the desired subtree
   * @param z_propose Output: state proposed from the subtree
   * @param p_sharp_beg Output: sharp momentum (dtau_dp) at the first state
   *   integrated
   * @param p_sharp_end Output: sharp momentum at the last state integrated
   * @param rho Accumulator: the summed momenta of the subtree's states are
   *   added to it
   * @param p_beg Output: momentum at the first state integrated
   * @param p_end Output: momentum at the last state integrated
   * @param H0 Hamiltonian of the initial state of the transition
   * @param sign Direction in time in which to build the subtree
   *   (+1 forward, -1 backward)
   * @param n_leapfrog Running count of leapfrog steps, incremented per step
   * @param log_sum_weight Accumulator: the log of the subtree's summed
   *   weights exp(H0 - H) is log-sum-exp'ed into it
   * @param sum_metro_prob Running sum of min(1, exp(H0 - H)) per step
   * @param logger Logger for messages
   * @return true if the subtree is valid (no divergence, and the no-U-turn
   *   criterion holds for it and for every sub-subtree)
   */
  bool build_tree(int depth, ps_point& z_propose, Eigen::VectorXd& p_sharp_beg,
                  Eigen::VectorXd& p_sharp_end, Eigen::VectorXd& rho,
                  Eigen::VectorXd& p_beg, Eigen::VectorXd& p_end, double H0,
                  double sign, int& n_leapfrog, double& log_sum_weight,
                  double& sum_metro_prob, callbacks::logger& logger) {
    // Base case: a single leapfrog step
    if (depth == 0) {
      this->integrator_.evolve(this->z_, this->hamiltonian_,
                               sign * this->epsilon_, logger);
      ++n_leapfrog;

      double h = this->hamiltonian_.H(this->z_);
      // Treat a NaN energy as infinitely bad so it is flagged divergent
      // and contributes zero weight.
      if (std::isnan(h))
        h = std::numeric_limits<double>::infinity();

      if ((h - H0) > this->max_deltaH_)
        this->divergent_ = true;

      log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);

      if (H0 - h > 0)
        sum_metro_prob += 1;
      else
        sum_metro_prob += std::exp(H0 - h);

      z_propose = this->z_;

      p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_);
      p_sharp_end = p_sharp_beg;

      rho += this->z_.p;
      p_beg = this->z_.p;
      p_end = p_beg;

      return !this->divergent_;
    }

    // General recursion

    // Build the initial subtree
    double log_sum_weight_init = -std::numeric_limits<double>::infinity();

    // Momentum and sharp momentum at end of the initial subtree
    Eigen::VectorXd p_init_end(this->z_.p.size());
    Eigen::VectorXd p_sharp_init_end(this->z_.p.size());

    Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size());

    bool valid_init
        = build_tree(depth - 1, z_propose, p_sharp_beg, p_sharp_init_end,
                     rho_init, p_beg, p_init_end, H0, sign, n_leapfrog,
                     log_sum_weight_init, sum_metro_prob, logger);

    if (!valid_init)
      return false;

    // Build the final subtree, continuing from where the initial one ended
    ps_point z_propose_final(this->z_);

    double log_sum_weight_final = -std::numeric_limits<double>::infinity();

    // Momentum and sharp momentum at beginning of the final subtree
    Eigen::VectorXd p_final_beg(this->z_.p.size());
    Eigen::VectorXd p_sharp_final_beg(this->z_.p.size());

    Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size());

    bool valid_final
        = build_tree(depth - 1, z_propose_final, p_sharp_final_beg, p_sharp_end,
                     rho_final, p_final_beg, p_end, H0, sign, n_leapfrog,
                     log_sum_weight_final, sum_metro_prob, logger);

    if (!valid_final)
      return false;

    // Multinomial sample from the final subtree: its proposal is chosen with
    // probability W_final / (W_init + W_final).
    double log_sum_weight_subtree
        = math::log_sum_exp(log_sum_weight_init, log_sum_weight_final);
    log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

    if (log_sum_weight_final > log_sum_weight_subtree) {
      z_propose = z_propose_final;
    } else {
      double accept_prob
          = std::exp(log_sum_weight_final - log_sum_weight_subtree);
      if (this->rand_uniform_() < accept_prob)
        z_propose = z_propose_final;
    }

    Eigen::VectorXd rho_subtree = rho_init + rho_final;
    rho += rho_subtree;

    // Demand satisfaction around merged subtrees
    bool persist_criterion
        = compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree);

    // Demand satisfaction between subtrees: initial subtree extended by the
    // first state of the final subtree, and vice versa.
    rho_subtree = rho_init + p_final_beg;
    persist_criterion
        &= compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree);

    rho_subtree = rho_final + p_init_end;
    persist_criterion
        &= compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree);

    return persist_criterion;
  }

  /// Depth reached by the most recent transition (number of accepted
  /// trajectory doublings).
  int depth_ = 0;
  /// Maximum number of doublings; a transition performs at most
  /// 2^max_depth_ - 1 leapfrog steps.
  int max_depth_ = 5;
  /// Energy-error threshold (H - H0) above which a step is divergent.
  double max_deltaH_ = 1000;
  /// Number of leapfrog steps taken by the most recent transition.
  int n_leapfrog_ = 0;
  /// Whether the most recent transition encountered a divergent step.
  bool divergent_ = false;
  /// Hamiltonian at the state selected by the most recent transition.
  double energy_ = 0;
};
} // namespace mcmc
} // namespace stan
#endif