diff --git a/qhbmlib/energy_infer.py b/qhbmlib/energy_infer.py
index d2807866..5d508c71 100644
--- a/qhbmlib/energy_infer.py
+++ b/qhbmlib/energy_infer.py
@@ -283,6 +283,48 @@ def grad_fn(*upstream, variables):
 
     return _inner_expectation()
 
+  def _log_partition(self):
+    """Default implementation wrapped by `self.log_partition`."""
+
+    @tf.custom_gradient
+    def _inner_log_partition():
+      """Wraps forward pass computation."""
+      result = self._log_partition_forward_pass()
+      # Adds variables in `self.energy` to `variables` argument of `grad_fn`.
+      _ = [tf.identity(x) for x in self.energy.trainable_variables]
+      grad_fn = self._log_partition_grad_generator()
+      return result, grad_fn
+
+    return _inner_log_partition()
+
+  @abc.abstractmethod
+  def _log_partition_forward_pass(self):
+    """Returns approximation to the log partition function."""
+    raise NotImplementedError()
+
+  def _log_partition_grad_generator(self):
+    """Returns default estimator for the log partition function derivative."""
+
+    def grad_fn(upstream, variables):
+      """See equation C2 in the appendix. TODO(#119)"""
+
+      def energy_grad(bitstrings):
+        """Calculates the derivative with respect to the current variables."""
+        with tf.GradientTape() as tape:
+          energies = self.energy(bitstrings)
+        jac = tape.jacobian(
+            energies,
+            variables,
+            unconnected_gradients=tf.UnconnectedGradients.ZERO)
+        return jac
+
+      energy_grad_expectation_list = self.expectation(energy_grad)
+      return tuple(), [
+          upstream * (-1.0 * ege) for ege in energy_grad_expectation_list
+      ]
+
+    return grad_fn
+
 
 class AnalyticEnergyInference(EnergyInference):
   """Uses an explicit categorical distribution to implement parent functions."""
@@ -340,7 +382,7 @@ def _entropy(self):
     """See base class docstring."""
     return self.distribution.entropy()
 
-  def _log_partition(self):
+  def _log_partition_forward_pass(self):
     """See base class docstring."""
     # TODO(#115)
     return tf.reduce_logsumexp(self.distribution.logits_parameter())
@@ -402,7 +444,7 @@ def _entropy(self):
     """
     return tf.reduce_sum(self.distribution.entropy())
 
-  def _log_partition(self):
+  def _log_partition_forward_pass(self):
     r"""Returns the exact log partition function.
 
     For a single spin of energy $\theta$, the partition function is
diff --git a/tests/energy_infer_test.py b/tests/energy_infer_test.py
index c3de4f83..9dfb8bfb 100644
--- a/tests/energy_infer_test.py
+++ b/tests/energy_infer_test.py
@@ -69,7 +69,7 @@ def _entropy(self):
     """Not implemented in this test class."""
     raise NotImplementedError()
 
-  def _log_partition(self):
+  def _log_partition_forward_pass(self):
     """Not implemented in this test class."""
     raise NotImplementedError()
@@ -101,6 +101,9 @@ def test_function(bitstrings):
 
     self.test_function = test_function
 
+    self.tf_random_seed = 4
+    self.close_rtol = 1e-2
+
   @test_util.eager_mode_toggle
   def test_expectation(self):
     """Confirms correct averaging over input function."""
@@ -503,9 +506,35 @@ def test_log_partition(self):
     actual_layer.infer(energy)
 
     log_partition_wrapper = tf.function(actual_layer.log_partition)
-    actual_log_partition = log_partition_wrapper()
+    with tf.GradientTape() as tape:
+      actual_log_partition = log_partition_wrapper()
     self.assertAllClose(actual_log_partition, expected_log_partition)
 
+    old_kernel = energy.post_process[0].kernel.read_value()
+    kernel_len = tf.shape(old_kernel)[0].numpy().tolist()
+    all_bitstrings = tf.constant([[0, 0], [0, 1], [1, 0], [1, 1]],
+                                 dtype=tf.int8)
+
+    def exact_log_partition(k, delta):
+      """Perturbs the kth variable and calculates the log partition."""
+      new_kernel = old_kernel + delta * tf.one_hot(k, kernel_len, 1.0, 0.0)
+      energy.set_weights([new_kernel])
+      delta_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
+      energy.set_weights([old_kernel])
+      return delta_log_partition
+
+    derivative_list = []
+    for k in range(kernel_len):
+      this_derivative = test_util.approximate_derivative(
+          functools.partial(exact_log_partition, k))
+      derivative_list.append(this_derivative.numpy())
+
+    expected_log_partition_grad = tf.constant([derivative_list])
+    actual_log_partition_grad = tape.gradient(actual_log_partition,
+                                              energy.trainable_variables)
+    self.assertAllClose(actual_log_partition_grad, expected_log_partition_grad,
+                        self.close_rtol)
+
   @test_util.eager_mode_toggle
   def test_entropy(self):
     """Confirms correct value of the entropy function."""
@@ -659,20 +688,47 @@ def test_samples_seeded(self):
 
   @test_util.eager_mode_toggle
   def test_log_partition(self):
-    """Confirms correct value of the log partition function."""
+    """Confirms correct value of the log partition function and derivative."""
     all_bitstrings = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                                   [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
                                  dtype=tf.int8)
-    energy = energy_model.BernoulliEnergy([5, 6, 7])
+    ebm_init = tf.keras.initializers.RandomUniform(
+        -2, -1, seed=self.tf_random_seed)
+    energy = energy_model.BernoulliEnergy([5, 6, 7], ebm_init)
     energy.build([None, energy.num_bits])
-    actual_layer = energy_infer.BernoulliEnergyInference(3, self.num_samples)
+    actual_layer = energy_infer.BernoulliEnergyInference(
+        3, self.num_samples, self.tfp_seed)
     actual_layer.infer(energy)
     expected_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
 
     log_partition_wrapper = tf.function(actual_layer.log_partition)
-    actual_log_partition = log_partition_wrapper()
+    with tf.GradientTape() as tape:
+      actual_log_partition = log_partition_wrapper()
     self.assertAllClose(actual_log_partition, expected_log_partition)
 
+    old_kernel = energy.post_process[0].kernel.read_value()
+    kernel_len = tf.shape(old_kernel)[0].numpy().tolist()
+
+    def exact_log_partition(k, delta):
+      """Perturbs the kth variable and calculates the log partition."""
+      new_kernel = old_kernel + delta * tf.one_hot(k, kernel_len, 1.0, 0.0)
+      energy.set_weights([new_kernel])
+      delta_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
+      energy.set_weights([old_kernel])
+      return delta_log_partition
+
+    derivative_list = []
+    for k in range(kernel_len):
+      this_derivative = test_util.approximate_derivative(
+          functools.partial(exact_log_partition, k))
+      derivative_list.append(this_derivative.numpy())
+
+    expected_log_partition_grad = tf.constant([derivative_list])
+    actual_log_partition_grad = tape.gradient(actual_log_partition,
+                                              energy.trainable_variables)
+    self.assertAllClose(actual_log_partition_grad, expected_log_partition_grad,
+                        self.close_rtol)
+
   @test_util.eager_mode_toggle
   def test_entropy(self):
     r"""Confirms that the entropy is S(p) = -\sum_x p(x)\ln(p(x)).