Skip to content
This repository has been archived by the owner on Aug 18, 2023. It is now read-only.

Automate inference and make models arguments to inference engines #181

Merged
merged 55 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
0eb7afa
Enable tracing of QuantumCircuit __add__ (#131)
zaqqwerty Jan 14, 2022
081d53a
Weighted average utility (#134)
zaqqwerty Jan 14, 2022
572ca41
Add new style QHBMs (#125)
zaqqwerty Jan 15, 2022
08a0814
Replace custom code with TF op (#141)
zaqqwerty Jan 19, 2022
a802d93
Enable EBM seed (#143)
zaqqwerty Jan 20, 2022
130d6b4
Add modular Hamiltonian expectation (#135)
zaqqwerty Jan 21, 2022
824c8e2
Energy expectation (#150)
zaqqwerty Jan 24, 2022
3b6724a
Update TF, TFP, and TFQ versions (#152)
zaqqwerty Jan 25, 2022
10b9efe
Utility function tests (#153)
zaqqwerty Jan 28, 2022
39a281b
Eager toggle decoration completion (#155)
zaqqwerty Jan 31, 2022
508234e
Merge branch 'google:main' into main
farice Feb 1, 2022
6d4c1d5
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 3, 2022
3b4419e
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 6, 2022
14a10c2
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 8, 2022
db2b402
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 10, 2022
42af703
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 12, 2022
858fa8e
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 13, 2022
37e4c8b
generalize QuantumInference expectation
zaqqwerty Feb 14, 2022
ca48730
removed reduced tests
zaqqwerty Feb 14, 2022
01d543b
inferece change
zaqqwerty Feb 14, 2022
9b114a6
revert
zaqqwerty Feb 14, 2022
e8d6418
genralize expectation
zaqqwerty Feb 14, 2022
3101bc5
circuit infer tests passing
zaqqwerty Feb 14, 2022
2a8c65a
update infer
zaqqwerty Feb 14, 2022
b538339
checkout energy files
zaqqwerty Feb 14, 2022
e912c99
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 14, 2022
bb53dff
Merge branch 'main' into 147_circuits_to_ham
zaqqwerty Feb 14, 2022
8f75cd5
format
zaqqwerty Feb 14, 2022
480b52c
circuit tests passing
zaqqwerty Feb 15, 2022
855e7fd
updated h infer test passing
zaqqwerty Feb 15, 2022
f40e2d2
circuit tests passing
zaqqwerty Feb 15, 2022
b8cb22c
inference tests passing
zaqqwerty Feb 15, 2022
0b9753d
format
zaqqwerty Feb 15, 2022
0e466cf
remove lint
zaqqwerty Feb 15, 2022
378f48b
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 15, 2022
7099edc
Merge branch 'main' into 147_circuits_to_ham
zaqqwerty Feb 15, 2022
b27c1b1
vqt tests passing
zaqqwerty Feb 15, 2022
f9b675c
format
zaqqwerty Feb 15, 2022
553dbfb
remove lint
zaqqwerty Feb 15, 2022
a6f2f75
start updating circuit infer
zaqqwerty Feb 15, 2022
4904110
infere
zaqqwerty Feb 15, 2022
85731bc
update infere
zaqqwerty Feb 15, 2022
e71735b
energy tests passing
zaqqwerty Feb 15, 2022
cb9bcd1
update infer
zaqqwerty Feb 15, 2022
2f7a2ea
Merge remote-tracking branch 'upstream/main' into main
zaqqwerty Feb 15, 2022
1b0bdd6
Merge branch 'main' into 115_new_infer_update
zaqqwerty Feb 15, 2022
c1603bd
move arg
zaqqwerty Feb 15, 2022
617c96b
hamuiltonian inference tests passing
zaqqwerty Feb 15, 2022
daa45d3
format
zaqqwerty Feb 15, 2022
bcad10d
circuit infer test
zaqqwerty Feb 15, 2022
4d1fd5c
remove lint
zaqqwerty Feb 15, 2022
b21cc0e
update vqt tets
zaqqwerty Feb 15, 2022
dc44570
format
zaqqwerty Feb 15, 2022
b5d1666
tolerance update
zaqqwerty Feb 15, 2022
0657d17
lint
zaqqwerty Feb 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions qhbmlib/circuit_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,23 @@ class QuantumInference(tf.keras.layers.Layer):
"""Methods for inference on QuantumCircuit objects."""

def __init__(self,
circuit: circuit_model.QuantumCircuit,
backend: Union[str, cirq.Sampler] = "noiseless",
differentiator: Union[None,
tfq.differentiators.Differentiator] = None,
name: Union[None, str] = None):
"""Initialize a QuantumInference layer.

Args:
circuit: The parameterized quantum circuit on which to do inference.
backend: Specifies what backend TFQ will use to compute expectation
values. `str` options are {"noisy", "noiseless"}; users may also specify
a preconfigured cirq execution object to use instead.
differentiator: Specifies how to take the derivative of a quantum circuit.
name: Identifier for this inference engine.
"""
circuit.build([])
self._circuit = circuit
self._differentiator = differentiator
self._backend = backend
self._sample_layer = tfq.layers.Sample(backend=backend)
Expand Down Expand Up @@ -79,17 +83,19 @@ def _expectation_function(circuits, symbol_names, symbol_values,
def backend(self):
return self._backend

@property
def circuit(self):
return self._circuit

@property
def differentiator(self):
return self._differentiator

def expectation(self, qnn: circuit_model.QuantumCircuit,
initial_states: tf.Tensor,
def expectation(self, initial_states: tf.Tensor,
observables: Union[tf.Tensor, hamiltonian_model.Hamiltonian]):
"""Returns the expectation values of the observables against the QNN.

Args:
qnn: The parameterized quantum circuit on which to do inference.
initial_states: Shape [batch_size, num_qubits] of dtype `tf.int8`.
Each entry is an initial state for the set of qubits. For each state,
`qnn` is applied and the pure state expectation value is calculated.
Expand All @@ -104,11 +110,11 @@ def expectation(self, qnn: circuit_model.QuantumCircuit,
transformed initial state.
"""
if isinstance(observables, tf.Tensor):
u = qnn
u = self.circuit
ops = observables
post_process = lambda x: x
elif isinstance(observables.energy, energy_model.PauliMixin):
u = qnn + observables.circuit_dagger
u = self.circuit + observables.circuit_dagger
ops = observables.operator_shards
post_process = lambda y: tf.map_fn(
lambda x: tf.expand_dims(
Expand All @@ -134,12 +140,10 @@ def expectation(self, qnn: circuit_model.QuantumCircuit,
)
return utils.expand_unique_results(post_process(expectations), idx)

def sample(self, qnn: circuit_model.QuantumCircuit, initial_states: tf.Tensor,
counts: tf.Tensor):
def sample(self, initial_states: tf.Tensor, counts: tf.Tensor):
"""Returns bitstring samples from the QNN.

Args:
qnn: The parameterized quantum circuit on which to do inference.
initial_states: Shape [batch_size, num_qubits] of dtype `tf.int8`.
These are the initial states of each qubit in the circuit.
counts: Shape [batch_size] of dtype `tf.int32` such that `counts[i]` is
Expand All @@ -150,16 +154,16 @@ def sample(self, qnn: circuit_model.QuantumCircuit, initial_states: tf.Tensor,
that `ragged_samples[i]` contains `counts[i]` bitstrings drawn from
`(qnn)|initial_states[i]>`.
"""
circuits = qnn(initial_states)
circuits = self.circuit(initial_states)
num_circuits = tf.shape(circuits)[0]
tiled_values = tf.tile(
tf.expand_dims(qnn.symbol_values, 0), [num_circuits, 1])
tf.expand_dims(self.circuit.symbol_values, 0), [num_circuits, 1])
num_samples_mask = tf.cast((tf.ragged.range(counts) + 1).to_tensor(),
tf.bool)
num_samples_mask = tf.map_fn(tf.random.shuffle, num_samples_mask)
samples = self._sample_layer(
circuits,
symbol_names=qnn.symbol_names,
symbol_names=self.circuit.symbol_names,
symbol_values=tiled_values,
repetitions=tf.expand_dims(tf.math.reduce_max(counts), 0))
return tf.ragged.boolean_mask(samples, num_samples_mask)
120 changes: 83 additions & 37 deletions qhbmlib/energy_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,35 @@ class EnergyInferenceBase(tf.keras.layers.Layer, abc.ABC):
"""

def __init__(self,
input_energy: energy_model.BitstringEnergy,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes an EnergyInferenceBase.

Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(name=name)
self._energy = input_energy
self._energy.build([None, self._energy.num_bits])

self._tracked_variables = input_energy.variables
if len(self._tracked_variables) == 0:
self._checkpoint = False
else:
self._tracked_variables_checkpoint = [
tf.Variable(v.read_value(), trainable=False)
for v in self._tracked_variables
]
self._checkpoint = True

if initial_seed is None:
self._update_seed = tf.Variable(True, trainable=False)
else:
Expand Down Expand Up @@ -104,22 +122,52 @@ def seed(self, initial_seed: Union[None, tf.Tensor]):
self._update_seed.assign(False)
self._seed.assign(tfp.random.sanitize_seed(initial_seed))

@property
def variables_updated(self):
"""Returns True if tracked variables do not have the checkpointed values."""
if self._checkpoint:
variables_not_equal_list = tf.nest.map_structure(
lambda v, vc: tf.math.reduce_any(tf.math.not_equal(v, vc)),
self._tracked_variables, self._tracked_variables_checkpoint)
return tf.math.reduce_any(tf.stack(variables_not_equal_list))
else:
return False

def _checkpoint_variables(self):
"""Checkpoints the currently tracked variables."""
if self._checkpoint:
tf.nest.map_structure(lambda v, vc: vc.assign(v), self._tracked_variables,
self._tracked_variables_checkpoint)

def _preface_inference(self):
"""Things all energy inference methods do before proceeding.

Called by `preface_inference` before the wrapped inference method.
Currently includes:
- run `self.infer` if this is the first call of a wrapped function
- run `self._ready_inference` if this is first call of a wrapped function
- change the seed if not set by the user during initialization
- run `self._ready_inference` if tracked energy parameters changed

Note: subclasses should take care to call the superclass method.
"""
if self._first_inference:
self.infer(self.energy)
self._checkpoint_variables()
self._ready_inference()
self._first_inference.assign(False)
if self._update_seed:
new_seed, _ = tfp.random.split_seed(self.seed)
self._seed.assign(new_seed)
if self.variables_updated:
self._checkpoint_variables()
self._ready_inference()

@abc.abstractmethod
def _ready_inference(self):
"""Performs computations common to all inference methods.

Contains inference code that must be run first if the variables of
`self.energy` have been updated since the last time inference was performed.
"""

@preface_inference
def call(self, inputs, *args, **kwargs):
Expand Down Expand Up @@ -181,37 +229,30 @@ def _sample(self, num_samples: int):
"""Default implementation wrapped by `self.sample`."""
raise NotImplementedError()

@abc.abstractmethod
def infer(self, energy: energy_model.BitstringEnergy):
"""Do the work to ready this layer for use.

This should be called each time the underlying model is updated.

Args:
energy: The parameterized energy function which defines this distribution
via the equations of an energy based model.
"""
raise NotImplementedError()


class EnergyInference(EnergyInferenceBase):
"""Provides some default method implementations."""

def __init__(self,
input_energy: energy_model.BitstringEnergy,
num_expectation_samples: int,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes an EnergyInference.

Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(initial_seed, name)
super().__init__(input_energy, initial_seed, name)
self.num_expectation_samples = num_expectation_samples

def _expectation(self, function):
Expand Down Expand Up @@ -330,7 +371,7 @@ class AnalyticEnergyInference(EnergyInference):
"""Uses an explicit categorical distribution to implement parent functions."""

def __init__(self,
num_bits: int,
input_energy: energy_model.BitstringEnergy,
num_expectation_samples: int,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
Expand All @@ -341,19 +382,23 @@ def __init__(self,
and other inference tasks.

Args:
num_bits: Number of bits on which this layer acts.
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(num_expectation_samples, initial_seed, name)
super().__init__(input_energy, num_expectation_samples, initial_seed, name)
self._all_bitstrings = tf.constant(
list(itertools.product([0, 1], repeat=num_bits)), dtype=tf.int8)
list(itertools.product([0, 1], repeat=input_energy.num_bits)),
dtype=tf.int8)
self._logits_variable = tf.Variable(
tf.zeros([tf.shape(self._all_bitstrings)[0]]), trainable=False)
-input_energy(self.all_bitstrings), trainable=False)
self._distribution = tfd.Categorical(logits=self._logits_variable)

@property
Expand All @@ -368,9 +413,13 @@ def all_energies(self):

@property
def distribution(self):
"""Categorical distribution set during last call to `self.infer`."""
"""Categorical distribution set during `self._ready_inference`."""
return self._distribution

def _ready_inference(self):
"""See base class docstring."""
self._logits_variable.assign(-self.all_energies)

def _call(self, inputs, *args, **kwargs):
"""See base class docstring."""
if inputs is None:
Expand All @@ -394,41 +443,43 @@ def _sample(self, num_samples: int):
self.distribution.sample(num_samples, seed=self.seed),
axis=0)

def infer(self, energy: energy_model.BitstringEnergy):
"""See base class docstring."""
self._energy = energy
self._logits_variable.assign(-1.0 * self.all_energies)


class BernoulliEnergyInference(EnergyInference):
"""Manages inference for a Bernoulli defined by spin energies."""

def __init__(self,
num_bits: int,
input_energy: energy_model.BernoulliEnergy,
num_expectation_samples: int,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes a BernoulliEnergyInference.

Args:
num_bits: Number of bits on which this layer acts.
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(num_expectation_samples, initial_seed, name)
self._logits_variable = tf.Variable(tf.zeros([num_bits]), trainable=False)
super().__init__(input_energy, num_expectation_samples, initial_seed, name)
self._logits_variable = tf.Variable(input_energy.logits, trainable=False)
self._distribution = tfd.Bernoulli(
logits=self._logits_variable, dtype=tf.int8)

@property
def distribution(self):
"""Bernoulli distribution set during last call to `self.infer`."""
"""Bernoulli distribution set during `self._ready_inference`."""
return self._distribution

def _ready_inference(self):
"""See base class docstring."""
self._logits_variable.assign(self.energy.logits)

def _call(self, inputs, *args, **kwargs):
"""See base class docstring."""
if inputs is None:
Expand All @@ -454,14 +505,9 @@ def _log_partition_forward_pass(self):
"""
thetas = 0.5 * self.energy.logits
single_log_partitions = tf.math.log(
tf.math.exp(thetas) + tf.math.exp(-1.0 * thetas))
tf.math.exp(thetas) + tf.math.exp(-thetas))
return tf.math.reduce_sum(single_log_partitions)

def _sample(self, num_samples: int):
"""See base class docstring"""
return self.distribution.sample(num_samples, seed=self.seed)

def infer(self, energy: energy_model.BitstringEnergy):
"""See base class docstring."""
self._energy = energy
self._logits_variable.assign(self.energy.logits)
Loading