Log partition estimator and derivative (#180)
zaqqwerty authored Feb 15, 2022
1 parent 2fbc816 commit 5c0c96d
Showing 2 changed files with 106 additions and 8 deletions.
46 changes: 44 additions & 2 deletions qhbmlib/energy_infer.py
@@ -283,6 +283,48 @@ def grad_fn(*upstream, variables):

    return _inner_expectation()

  def _log_partition(self):
    """Default implementation wrapped by `self.log_partition`."""

    @tf.custom_gradient
    def _inner_log_partition():
      """Wraps forward pass computation."""
      result = self._log_partition_forward_pass()
      # Adds variables in `self.energy` to `variables` argument of `grad_fn`.
      _ = [tf.identity(x) for x in self.energy.trainable_variables]
      grad_fn = self._log_partition_grad_generator()
      return result, grad_fn

    return _inner_log_partition()
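For reference, here is a standalone sketch of the `tf.custom_gradient` pattern used above (not part of the commit): a hypothetical single-spin energy whose log partition function and derivative are known in closed form. The `tf.identity` read is what registers the variable so TensorFlow forwards it to `grad_fn`.

import tensorflow as tf

theta = tf.Variable(1.5)

@tf.custom_gradient
def toy_log_partition():
  _ = tf.identity(theta)  # registers `theta` with the custom gradient machinery
  result = tf.math.log(2.0 * tf.math.cosh(theta))  # log Z for one +/-1 spin

  def grad_fn(upstream, variables):
    # d/dtheta log(2 cosh(theta)) = tanh(theta)
    return tuple(), [upstream * tf.math.tanh(v) for v in variables]

  return result, grad_fn

with tf.GradientTape() as tape:
  log_z = toy_log_partition()
print(tape.gradient(log_z, [theta]))  # approximately [tanh(1.5)] ~ [0.905]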

  @abc.abstractmethod
  def _log_partition_forward_pass(self):
    """Returns approximation to the log partition function."""
    raise NotImplementedError()

  def _log_partition_grad_generator(self):
    """Returns default estimator for the log partition function derivative."""

    def grad_fn(upstream, variables):
      """See equation C2 in the appendix. TODO(#119)"""

      def energy_grad(bitstrings):
        """Calculates the derivative with respect to the current variables."""
        with tf.GradientTape() as tape:
          energies = self.energy(bitstrings)
        jac = tape.jacobian(
            energies,
            variables,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)
        return jac

      energy_grad_expectation_list = self.expectation(energy_grad)
      return tuple(), [
          upstream * (-1.0 * ege) for ege in energy_grad_expectation_list
      ]

    return grad_fn
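The estimator above implements the standard identity for the log partition derivative (presumably what equation C2 states; the appendix itself is not shown here). Differentiating $\ln Z_\theta = \ln\sum_x e^{-E_\theta(x)}$ gives

$$\frac{\partial \ln Z_\theta}{\partial \theta} = \frac{-\sum_x e^{-E_\theta(x)}\,\partial_\theta E_\theta(x)}{Z_\theta} = -\,\mathbb{E}_{x\sim p_\theta}\!\left[\frac{\partial E_\theta(x)}{\partial\theta}\right],$$

which is why `grad_fn` returns `upstream` times the negated expectation of the energy gradient.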


class AnalyticEnergyInference(EnergyInference):
"""Uses an explicit categorical distribution to implement parent functions."""
@@ -340,7 +382,7 @@ def _entropy(self):
"""See base class docstring."""
return self.distribution.entropy()

def _log_partition(self):
def _log_partition_forward_pass(self):
"""See base class docstring."""
# TODO(#115)
return tf.reduce_logsumexp(self.distribution.logits_parameter())
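A quick standalone sanity check of this forward pass (a sketch, assuming the categorical logits follow the convention `logits = -E(x)` over all bitstrings):

import tensorflow as tf

energies = tf.constant([0.3, -1.2, 0.7, 2.0])  # hypothetical energies, one per bitstring
logits = -1.0 * energies

log_z_stable = tf.reduce_logsumexp(logits)
log_z_direct = tf.math.log(tf.reduce_sum(tf.exp(-energies)))
tf.debugging.assert_near(log_z_stable, log_z_direct)  # same value, computed stably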
@@ -402,7 +444,7 @@ def _entropy(self):
"""
return tf.reduce_sum(self.distribution.entropy())

def _log_partition(self):
def _log_partition_forward_pass(self):
r"""Returns the exact log partition function.
For a single spin of energy $\theta$, the partition function is
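The docstring is truncated in this view; for a $\pm 1$-valued spin contributing energy $\pm\theta$ it plausibly continues along the lines of

$$Z_\theta = e^{\theta} + e^{-\theta} = 2\cosh\theta, \qquad \ln Z = \sum_i \ln\!\left(2\cosh\theta_i\right),$$

the closed form a factorized Bernoulli distribution admits.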
68 changes: 62 additions & 6 deletions tests/energy_infer_test.py
@@ -69,7 +69,7 @@ def _entropy(self):
"""Not implemented in this test class."""
raise NotImplementedError()

def _log_partition(self):
def _log_partition_forward_pass(self):
"""Not implemented in this test class."""
raise NotImplementedError()

@@ -101,6 +101,9 @@ def test_function(bitstrings):

    self.test_function = test_function

    self.tf_random_seed = 4
    self.close_rtol = 1e-2

  @test_util.eager_mode_toggle
  def test_expectation(self):
    """Confirms correct averaging over input function."""
@@ -503,9 +506,35 @@ def test_log_partition(self):
    actual_layer.infer(energy)

    log_partition_wrapper = tf.function(actual_layer.log_partition)
    actual_log_partition = log_partition_wrapper()
    with tf.GradientTape() as tape:
      actual_log_partition = log_partition_wrapper()
    self.assertAllClose(actual_log_partition, expected_log_partition)

    old_kernel = energy.post_process[0].kernel.read_value()
    kernel_len = tf.shape(old_kernel)[0].numpy().tolist()
    all_bitstrings = tf.constant([[0, 0], [0, 1], [1, 0], [1, 1]],
                                 dtype=tf.int8)

    def exact_log_partition(k, delta):
      """Perturbs the kth variable and calculates the log partition."""
      new_kernel = old_kernel + delta * tf.one_hot(k, kernel_len, 1.0, 0.0)
      energy.set_weights([new_kernel])
      delta_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
      energy.set_weights([old_kernel])
      return delta_log_partition

    derivative_list = []
    for k in range(kernel_len):
      this_derivative = test_util.approximate_derivative(
          functools.partial(exact_log_partition, k))
      derivative_list.append(this_derivative.numpy())

    expected_log_partition_grad = tf.constant([derivative_list])
    actual_log_partition_grad = tape.gradient(actual_log_partition,
                                              energy.trainable_variables)
    self.assertAllClose(actual_log_partition_grad, expected_log_partition_grad,
                        self.close_rtol)
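`test_util.approximate_derivative` is the repository's finite-difference helper; a central-difference estimator of roughly this shape (hypothetical signature, shown only to make the check above concrete) would behave equivalently:

def central_difference(f, delta=1e-3):
  """Two-point central-difference estimate of df/d(delta) at delta = 0.

  Here `f` maps a scalar perturbation to a scalar, as `exact_log_partition`
  does once `functools.partial` fixes the variable index `k`.
  """
  return (f(delta) - f(-delta)) / (2.0 * delta)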

  @test_util.eager_mode_toggle
  def test_entropy(self):
    """Confirms correct value of the entropy function."""
@@ -659,20 +688,47 @@ def test_samples_seeded(self):

  @test_util.eager_mode_toggle
  def test_log_partition(self):
    """Confirms correct value of the log partition function."""
    """Confirms correct value of the log partition function and derivative."""
    all_bitstrings = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                                  [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
                                 dtype=tf.int8)
    energy = energy_model.BernoulliEnergy([5, 6, 7])
    ebm_init = tf.keras.initializers.RandomUniform(
        -2, -1, seed=self.tf_random_seed)
    energy = energy_model.BernoulliEnergy([5, 6, 7], ebm_init)
    energy.build([None, energy.num_bits])
    actual_layer = energy_infer.BernoulliEnergyInference(3, self.num_samples)
    actual_layer = energy_infer.BernoulliEnergyInference(
        3, self.num_samples, self.tfp_seed)
    actual_layer.infer(energy)
    expected_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))

    log_partition_wrapper = tf.function(actual_layer.log_partition)
    actual_log_partition = log_partition_wrapper()
    with tf.GradientTape() as tape:
      actual_log_partition = log_partition_wrapper()
    self.assertAllClose(actual_log_partition, expected_log_partition)

    old_kernel = energy.post_process[0].kernel.read_value()
    kernel_len = tf.shape(old_kernel)[0].numpy().tolist()

    def exact_log_partition(k, delta):
      """Perturbs the kth variable and calculates the log partition."""
      new_kernel = old_kernel + delta * tf.one_hot(k, kernel_len, 1.0, 0.0)
      energy.set_weights([new_kernel])
      delta_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
      energy.set_weights([old_kernel])
      return delta_log_partition

    derivative_list = []
    for k in range(kernel_len):
      this_derivative = test_util.approximate_derivative(
          functools.partial(exact_log_partition, k))
      derivative_list.append(this_derivative.numpy())

    expected_log_partition_grad = tf.constant([derivative_list])
    actual_log_partition_grad = tape.gradient(actual_log_partition,
                                              energy.trainable_variables)
    self.assertAllClose(actual_log_partition_grad, expected_log_partition_grad,
                        self.close_rtol)

  @test_util.eager_mode_toggle
  def test_entropy(self):
    r"""Confirms that the entropy is S(p) = -\sum_x p(x)\ln(p(x)).
