Skip to content
This repository has been archived by the owner on Aug 18, 2023. It is now read-only.

Log partition estimator and derivative #180

Merged
merged 57 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
0eb7afa
Enable tracing of QuantumCircuit __add__ (#131)
zaqqwerty Jan 14, 2022
081d53a
Weighted average utility (#134)
zaqqwerty Jan 14, 2022
572ca41
Add new style QHBMs (#125)
zaqqwerty Jan 15, 2022
08a0814
Replace custom code with TF op (#141)
zaqqwerty Jan 19, 2022
a802d93
Enable EBM seed (#143)
zaqqwerty Jan 20, 2022
130d6b4
Add modular Hamiltonian expectation (#135)
zaqqwerty Jan 21, 2022
824c8e2
Energy expectation (#150)
zaqqwerty Jan 24, 2022
3b6724a
Update TF, TFP, and TFQ versions (#152)
zaqqwerty Jan 25, 2022
10b9efe
Utility function tests (#153)
zaqqwerty Jan 28, 2022
39a281b
Eager toggle decoration completion (#155)
zaqqwerty Jan 31, 2022
508234e
Merge branch 'google:main' into main
farice Feb 1, 2022
6d4c1d5
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 3, 2022
3b4419e
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 6, 2022
14a10c2
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 8, 2022
db2b402
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 10, 2022
ff0383d
need to update unique function
zaqqwerty Feb 11, 2022
b5e5135
unique tests passing
zaqqwerty Feb 11, 2022
62141cf
update call
zaqqwerty Feb 11, 2022
1365d77
revert circuit
zaqqwerty Feb 11, 2022
00023d9
update signature everywhere
zaqqwerty Feb 11, 2022
d03f230
correction
zaqqwerty Feb 11, 2022
b8bcb5d
update tests with inverse
zaqqwerty Feb 11, 2022
a123933
format
zaqqwerty Feb 11, 2022
8d3d7e5
add more tests
zaqqwerty Feb 11, 2022
a48dee9
format
zaqqwerty Feb 11, 2022
d788a02
remove lint
zaqqwerty Feb 11, 2022
92cb271
update expectation signature
zaqqwerty Feb 12, 2022
e1eae3f
updates
zaqqwerty Feb 12, 2022
42af703
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 12, 2022
e900b0f
Merge branch 'main' into 158_move_sample_arg
zaqqwerty Feb 12, 2022
5cd5697
tests passing
zaqqwerty Feb 13, 2022
c99fe28
update hamiltonian inference
zaqqwerty Feb 13, 2022
858fa8e
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 13, 2022
d457293
Merge branch 'main' into 158_move_sample_arg
zaqqwerty Feb 13, 2022
b11ed8c
format
zaqqwerty Feb 13, 2022
d14eb62
add num_samples arg
zaqqwerty Feb 13, 2022
8946a16
format
zaqqwerty Feb 13, 2022
9b6dd9d
hamiltonian infer tests passing
zaqqwerty Feb 13, 2022
96f6252
update
zaqqwerty Feb 13, 2022
cfa632f
format
zaqqwerty Feb 13, 2022
94c3e68
update docstring
zaqqwerty Feb 13, 2022
76548ed
update docstring
zaqqwerty Feb 13, 2022
59bef9a
fix loss with squeeze
zaqqwerty Feb 13, 2022
65be26f
format
zaqqwerty Feb 13, 2022
7942337
add initial log partition code
zaqqwerty Feb 14, 2022
be0367d
Merge branch '158_move_sample_arg' into 140_new_partition
zaqqwerty Feb 14, 2022
97554ed
update subclass methods
zaqqwerty Feb 14, 2022
1e43467
add bernoulli test
zaqqwerty Feb 14, 2022
e720c46
partition function tests passing
zaqqwerty Feb 14, 2022
6689cd2
default test passing
zaqqwerty Feb 14, 2022
efd0c16
format
zaqqwerty Feb 14, 2022
cef941f
remove lint
zaqqwerty Feb 14, 2022
e912c99
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 14, 2022
66e8318
Merge branch 'main' into 140_new_partition
zaqqwerty Feb 14, 2022
3feb6e9
remove forward pass
zaqqwerty Feb 15, 2022
b05934b
remove lint
zaqqwerty Feb 15, 2022
96c6e44
add error
zaqqwerty Feb 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions qhbmlib/energy_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,47 @@ def grad_fn(*upstream, variables):

return _inner_expectation()

def _log_partition(self):
  """Default implementation wrapped by `self.log_partition`.

  Wraps the subclass-provided forward pass in a `tf.custom_gradient` so
  that the returned value carries the derivative estimator produced by
  `self._log_partition_grad_generator` instead of the default autodiff
  gradient.

  Returns:
    Scalar tensor approximating the log partition function of
    `self.energy`, as computed by `self._log_partition_forward_pass`.
  """

  @tf.custom_gradient
  def _inner_log_partition():
    """Wraps forward pass computation."""
    result = self._log_partition_forward_pass()
    # Adds variables in `self.energy` to `variables` argument of `grad_fn`.
    # `tf.custom_gradient` only forwards variables that are touched inside
    # the wrapped function, so each trainable variable is read via
    # `tf.identity`; the results themselves are discarded.
    _ = [tf.identity(x) for x in self.energy.trainable_variables]
    grad_fn = self._log_partition_grad_generator()
    return result, grad_fn

  return _inner_log_partition()

@abc.abstractmethod
def _log_partition_forward_pass(self):
  """Returns approximation to the log partition function.

  Subclasses implement the actual estimate. Gradients are attached by the
  `_log_partition` wrapper via `tf.custom_gradient`, so this method does
  not need to be differentiable on its own.

  Returns:
    Scalar tensor estimating the log partition function of `self.energy`.
  """

def _log_partition_grad_generator(self):
  """Returns default estimator for the log partition function derivative.

  The returned function is suitable as the `grad_fn` of the
  `tf.custom_gradient` in `_log_partition`: given the upstream gradient
  and the variables registered there, it estimates the derivative of the
  log partition function as minus the expectation (under the inferred
  distribution, via `self.expectation`) of the energy's jacobian.
  """

  def grad_fn(upstream, variables):
    """See equation C2 in the appendix. TODO(#119)"""

    def energy_grad(bitstrings):
      """Calculates the derivative with respect to the current variables."""
      with tf.GradientTape() as tape:
        energies = self.energy(bitstrings)
      # Jacobian of the per-bitstring energies with respect to each
      # variable; variables not connected to the energy graph contribute
      # zeros rather than `None`.
      jac = tape.jacobian(
          energies,
          variables,
          unconnected_gradients=tf.UnconnectedGradients.ZERO)
      return jac

    energy_grad_expectation_list = self.expectation(energy_grad)
    # The wrapped forward function takes no inputs, hence the empty tuple;
    # each registered variable receives -upstream * <dE/d(variable)>.
    return tuple(), [
        upstream * (-1.0 * ege) for ege in energy_grad_expectation_list
    ]

  return grad_fn


class AnalyticEnergyInference(EnergyInference):
"""Uses an explicit categorical distribution to implement parent functions."""
Expand Down Expand Up @@ -340,7 +381,7 @@ def _entropy(self):
"""See base class docstring."""
return self.distribution.entropy()

def _log_partition(self):
def _log_partition_forward_pass(self):
  """See base class docstring.

  Computes the log partition function as a logsumexp over the logits of
  the explicit categorical distribution; presumably the logits enumerate
  the negated energies of all bitstrings — confirm against
  `self.distribution` construction.
  """
  # TODO(#115)
  return tf.reduce_logsumexp(self.distribution.logits_parameter())
Expand Down Expand Up @@ -402,7 +443,7 @@ def _entropy(self):
"""
return tf.reduce_sum(self.distribution.entropy())

def _log_partition(self):
def _log_partition_forward_pass(self):
r"""Returns the exact log partition function.

For a single spin of energy $\theta$, the partition function is
Expand Down
68 changes: 62 additions & 6 deletions tests/energy_infer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _entropy(self):
"""Not implemented in this test class."""
raise NotImplementedError()

def _log_partition(self):
def _log_partition_forward_pass(self):
"""Not implemented in this test class."""
raise NotImplementedError()

Expand Down Expand Up @@ -101,6 +101,9 @@ def test_function(bitstrings):

self.test_function = test_function

self.tf_random_seed = 4
self.close_rtol = 1e-2

@test_util.eager_mode_toggle
def test_expectation(self):
"""Confirms correct averaging over input function."""
Expand Down Expand Up @@ -503,9 +506,35 @@ def test_log_partition(self):
actual_layer.infer(energy)

log_partition_wrapper = tf.function(actual_layer.log_partition)
actual_log_partition = log_partition_wrapper()
with tf.GradientTape() as tape:
actual_log_partition = log_partition_wrapper()
self.assertAllClose(actual_log_partition, expected_log_partition)

old_kernel = energy.post_process[0].kernel.read_value()
kernel_len = tf.shape(old_kernel)[0].numpy().tolist()
all_bitstrings = tf.constant([[0, 0], [0, 1], [1, 0], [1, 1]],
dtype=tf.int8)

def exact_log_partition(k, delta):
"""Perturbs the kth variable and calculates the log partition."""
new_kernel = old_kernel + delta * tf.one_hot(k, kernel_len, 1.0, 0.0)
energy.set_weights([new_kernel])
delta_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
energy.set_weights([old_kernel])
return delta_log_partition

derivative_list = []
for k in range(kernel_len):
this_derivative = test_util.approximate_derivative(
functools.partial(exact_log_partition, k))
derivative_list.append(this_derivative.numpy())

expected_log_partition_grad = tf.constant([derivative_list])
actual_log_partition_grad = tape.gradient(actual_log_partition,
energy.trainable_variables)
self.assertAllClose(actual_log_partition_grad, expected_log_partition_grad,
self.close_rtol)

@test_util.eager_mode_toggle
def test_entropy(self):
"""Confirms correct value of the entropy function."""
Expand Down Expand Up @@ -659,20 +688,47 @@ def test_samples_seeded(self):

@test_util.eager_mode_toggle
def test_log_partition(self):
  """Confirms correct value of the log partition function and derivative."""
  # Exhaustive enumeration of all 3-bit strings, so logsumexp over the
  # negated energies is the exact log partition function.
  all_bitstrings = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                                [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
                               dtype=tf.int8)
  # Seeded initializer keeps the energy weights (and hence the expected
  # values below) deterministic across runs.
  ebm_init = tf.keras.initializers.RandomUniform(
      -2, -1, seed=self.tf_random_seed)
  energy = energy_model.BernoulliEnergy([5, 6, 7], ebm_init)
  energy.build([None, energy.num_bits])
  actual_layer = energy_infer.BernoulliEnergyInference(
      3, self.num_samples, self.tfp_seed)
  actual_layer.infer(energy)
  expected_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))

  # Wrap in tf.function so the test also exercises traced execution; the
  # tape records the custom gradient attached by `log_partition`.
  log_partition_wrapper = tf.function(actual_layer.log_partition)
  with tf.GradientTape() as tape:
    actual_log_partition = log_partition_wrapper()
  self.assertAllClose(actual_log_partition, expected_log_partition)

  old_kernel = energy.post_process[0].kernel.read_value()
  kernel_len = tf.shape(old_kernel)[0].numpy().tolist()

  def exact_log_partition(k, delta):
    """Perturbs the kth variable and calculates the log partition."""
    new_kernel = old_kernel + delta * tf.one_hot(k, kernel_len, 1.0, 0.0)
    energy.set_weights([new_kernel])
    delta_log_partition = tf.reduce_logsumexp(-1.0 * energy(all_bitstrings))
    # Restore the unperturbed kernel so later evaluations are unaffected.
    energy.set_weights([old_kernel])
    return delta_log_partition

  # Finite-difference derivative of the exact log partition with respect
  # to each kernel entry, used as the reference gradient.
  derivative_list = []
  for k in range(kernel_len):
    this_derivative = test_util.approximate_derivative(
        functools.partial(exact_log_partition, k))
    derivative_list.append(this_derivative.numpy())

  expected_log_partition_grad = tf.constant([derivative_list])
  actual_log_partition_grad = tape.gradient(actual_log_partition,
                                            energy.trainable_variables)
  # Loose rtol since the actual gradient is a sampled estimator.
  self.assertAllClose(actual_log_partition_grad, expected_log_partition_grad,
                      self.close_rtol)

@test_util.eager_mode_toggle
def test_entropy(self):
r"""Confirms that the entropy is S(p) = -\sum_x p(x)\ln(p(x)).
Expand Down