diff --git a/RELEASE.md b/RELEASE.md index e71076bed4..2a23c0503e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,5 +1,8 @@ # Unreleased +* Enable support for models with non-trainable variables in + `functional_model_from_keras`. + ## Breaking Changes * Updated `com_github_grpc_grpc` to version `1.50.0`. diff --git a/tensorflow_federated/python/learning/models/BUILD b/tensorflow_federated/python/learning/models/BUILD index c4ff95a20e..948011b87e 100644 --- a/tensorflow_federated/python/learning/models/BUILD +++ b/tensorflow_federated/python/learning/models/BUILD @@ -55,7 +55,6 @@ py_library( "//tensorflow_federated/python/learning/metrics:keras_finalizer", "//tensorflow_federated/python/learning/metrics:keras_utils", "//tensorflow_federated/python/learning/metrics:types", - "//tensorflow_federated/python/tensorflow_libs:variable_utils", ], ) diff --git a/tensorflow_federated/python/learning/models/functional.py b/tensorflow_federated/python/learning/models/functional.py index 6164711e84..98239411b6 100644 --- a/tensorflow_federated/python/learning/models/functional.py +++ b/tensorflow_federated/python/learning/models/functional.py @@ -38,7 +38,6 @@ from tensorflow_federated.python.learning.metrics import keras_utils from tensorflow_federated.python.learning.metrics import types from tensorflow_federated.python.learning.models import variable -from tensorflow_federated.python.tensorflow_libs import variable_utils Weight = Union[np.ndarray, int, float] @@ -477,15 +476,6 @@ def functional_model_from_keras( 'incompatible with `tff.learning.models.FunctionalModel`. Consider ' 'using group normalization instead.' ) - if keras_model.non_trainable_variables: - raise KerasFunctionalModelError( - 'Received a Keras model with non-trainable variables. Keras models' - ' with non-trainable variables are currently not supported by' - ' FunctionalModel. Most training algorithms (e.g. Federated' - ' Averaging) will not aggregate them, and they are not updated' - ' locally by the optimizer. We can relax this in the future if we' - ' have APIs that support updating non-trainable variables.' - ) elif not callable(keras_model): raise ValueError( '`keras_model` must be a `tf.keras.Model` or a no-arg ' @@ -508,42 +498,43 @@ def functional_model_from_keras( # also setup ops to inject the current model weights, because the cloned model # will be re-initialized from scratch. with tf.Graph().as_default() as g: - with variable_utils.record_variable_creation_scope() as captured_variables: - if isinstance(keras_model, tf.keras.Model): - try: - cloned_model = tf.keras.models.clone_model(keras_model) - except RuntimeError as e: - raise KerasFunctionalModelError( - 'Encountered a error converting the Keras model. Often this ' - 'occurs when the `tf.keras.Model` has a layer that receives ' - 'inputs from other layers directly (e.g. shared embeddings).' - 'To avoid the problem, wrap the `tf.keras.Model` construction in ' - 'a no-arg callable (e.g. lambda) and pass that callable to ' - '`functional_model_from_keras`' - ) from e - if len(cloned_model.variables) != len(keras_model.variables): - raise KerasFunctionalModelError( - 'The input Keras model is likely sharing variables across layers ' - 'which is unsupported. Cloning the model will duplicate these ' - 'variables and result in unexpected training gradients.' - ) - else: - cloned_model = keras_model() - - # Ensure our cloned model has the same weights as the current model. - # We'll feed in the current model waits into the palceholders for - # assignmnet in a session below. - def assign_placeholder(v): - p = tf.compat.v1.placeholder(dtype=v.dtype) - return v.assign(p), p + if isinstance(keras_model, tf.keras.Model): + try: + cloned_model = tf.keras.models.clone_model(keras_model) + except RuntimeError as e: + raise KerasFunctionalModelError( + 'Encountered a error converting the Keras model. Often this ' + 'occurs when the `tf.keras.Model` has a layer that receives ' + 'inputs from other layers directly (e.g. shared embeddings).' + 'To avoid the problem, wrap the `tf.keras.Model` construction in ' + 'a no-arg callable (e.g. lambda) and pass that callable to ' + '`functional_model_from_keras`' + ) from e + if len(cloned_model.variables) != len(keras_model.variables): + raise KerasFunctionalModelError( + 'The input Keras model is likely sharing variables across layers ' + 'which is unsupported. Cloning the model will duplicate these ' + 'variables and result in unexpected training gradients.' + ) + else: + cloned_model = keras_model() + captured_variables = cloned_model.variables + captured_trainable_variables = cloned_model.trainable_variables + captured_nontrainable_variables = cloned_model.non_trainable_variables + + # Ensure our cloned model has the same weights as the current model. + # We'll feed in the current model waits into the placeholders for + # assignmnet in a session below. + def assign_placeholder(v): + p = tf.compat.v1.placeholder(dtype=v.dtype) + return v.assign(p), p + + assign_ops, placeholders = zip( + *(assign_placeholder(v) for v in cloned_model.variables) + ) - assign_ops, placeholders = zip( - *(assign_placeholder(v) for v in cloned_model.variables) - ) - trainable_variables = tuple(v for v in captured_variables if v.trainable) - non_trainable_variables = tuple( - v for v in captured_variables if not v.trainable - ) + trainable_variables = tuple(v for v in captured_trainable_variables) + non_trainable_variables = tuple(v for v in captured_nontrainable_variables) # Here we get the initial weights from the incoming keras model in the order # they are constructed; and also ensure that the values are set to the diff --git a/tensorflow_federated/python/learning/models/functional_test.py b/tensorflow_federated/python/learning/models/functional_test.py index 2a89fff09d..be29662cc3 100644 --- a/tensorflow_federated/python/learning/models/functional_test.py +++ b/tensorflow_federated/python/learning/models/functional_test.py @@ -899,23 +899,24 @@ def train(): self.assertGreater(initial_loss, 2.0) self.assertLess(final_loss, 0.2) - def test_keras_model_with_non_trainable_variables_fails(self): + def test_keras_model_with_non_trainable_variables(self): inputs = tf.keras.layers.Input(shape=[1]) d = tf.keras.layers.Dense(1) d.trainable = False outputs = d(inputs) keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) - with self.assertRaisesRegex( - functional.KerasFunctionalModelError, 'non-trainable variables' - ): - functional.functional_model_from_keras( - keras_model, - tf.keras.losses.MeanSquaredError(), - input_spec=( - tf.TensorSpec(shape=[None, 1]), - tf.TensorSpec(shape=[None, 1]), - ), - ) + functional_model = functional.functional_model_from_keras( + keras_model, + tf.keras.losses.MeanSquaredError(), + input_spec=( + tf.TensorSpec(shape=[None, 1]), + tf.TensorSpec(shape=[None, 1]), + ), + ) + self.assertEmpty(functional_model.initial_weights[0]) + # We expect there to be two non-trainable variables: the kernel and bias + # of the dense layer. + self.assertLen(functional_model.initial_weights[1], 2) def test_keras_model_with_batch_normalization_fails(self): model = tf.keras.models.Sequential([