Skip to content

Commit

Permalink
Fix compatibility for earlier versions of Keras (#1690)
Browse files Browse the repository at this point in the history
* Fix compatibility for Keras3.1.0

* Address comments
  • Loading branch information
james77777778 authored Jul 11, 2024
1 parent 4858a22 commit 360da5b
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 17 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ jobs:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
version: [latest]
include:
- backend: torch
version: 3.1
runs-on: ubuntu-latest
env:
KERAS_BACKEND: ${{ matrix.backend }}
Expand All @@ -42,6 +46,11 @@ jobs:
run: |
pip install -r requirements.txt --progress-bar off
pip install --no-deps -e "." --progress-bar off
- name: Pin Keras version
if: ${{ matrix.version == '3.1'}}
run: |
pip uninstall -y keras
pip install keras==3.1.0 --progress-bar off
- name: Test with pytest
run: |
pytest keras_nlp/
Expand Down
20 changes: 15 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import keras
from keras import ops
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.keras_utils import assert_quantization_support


@keras_nlp_export("keras_nlp.layers.ReversibleEmbedding")
Expand Down Expand Up @@ -107,7 +109,10 @@ def __init__(

def build(self, inputs_shape=None):
super().build(inputs_shape)
if not self.tie_weights and self.quantization_mode != "int8":
if (
not self.tie_weights
and getattr(self, "quantization_mode", None) != "int8"
):
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
Expand Down Expand Up @@ -142,11 +147,15 @@ def save_own_variables(self, store):
if not self.built:
return
super().save_own_variables(store)
# Before Keras 3.2, the reverse weight is saved in the super() call.
# From Keras 3.2 onward, the reverse weight must be saved manually.
if parse(keras.version()) < parse("3.2.0"):
return
target_variables = []
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
target_variables.append(self.reverse_embeddings)
if self.quantization_mode == "int8":
if getattr(self, "quantization_mode", None) == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(target_variables, start=len(store)):
store[str(i)] = variable
Expand All @@ -158,7 +167,7 @@ def load_own_variables(self, store):
if not self.tie_weights:
# Last weights in the stores are the reverse embedding weights.
target_variables = [self.reverse_embeddings]
if self.quantization_mode == "int8":
if getattr(self, "quantization_mode", None) == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(
target_variables, start=len(store) - len(target_variables)
Expand Down Expand Up @@ -226,10 +235,11 @@ def _int8_call(self, inputs, reverse=False):

return super()._int8_call(inputs)

def quantize(self, mode):
def quantize(self, mode, type_check=True):
import gc

if type(self) is not ReversibleEmbedding:
assert_quantization_support()
if type_check and type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
"method implemented."
Expand Down
7 changes: 7 additions & 0 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ReversibleEmbedding,
)
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.keras_utils import has_quantization_support


class ReversibleEmbeddingTest(TestCase):
Expand Down Expand Up @@ -103,6 +104,9 @@ def test_reverse_dtype(self):
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
)
Expand Down Expand Up @@ -151,6 +155,9 @@ def test_quantize_int8(self, tie_weights):
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
if not has_quantization_support():
self.skipTest("This version of Keras doesn't support quantization.")

self.run_layer_test(
cls=ReversibleEmbedding,
init_kwargs={
Expand Down
35 changes: 25 additions & 10 deletions keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.keras_utils import assert_quantization_support
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
Expand Down Expand Up @@ -75,7 +76,14 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
self.dtype_policy = keras.dtype_policies.get(dtype)
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
if hasattr(keras.dtype_policies, "get"):
self.dtype_policy = keras.dtype_policies.get(dtype)
else:
if isinstance(dtype, keras.dtype_policies.DTypePolicy):
dtype = dtype.name
dtype = dtype or keras.config.dtype_policy().name
self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
Expand All @@ -100,6 +108,10 @@ def token_embedding(self):
def token_embedding(self, value):
self._token_embedding = value

def quantize(self, mode, **kwargs):
    """Quantize the model after verifying quantization is supported.

    Args:
        mode: Quantization mode, forwarded unchanged to the parent
            class's `quantize()`.
        **kwargs: Extra keyword arguments, forwarded unchanged to the
            parent class's `quantize()`.

    Returns:
        Whatever the parent class's `quantize()` returns.
    """
    # Run the support check first so an unsupported Keras version is
    # reported before any work is delegated to the parent class.
    assert_quantization_support()
    result = super().quantize(mode, **kwargs)
    return result

def get_config(self):
# Don't chain to super here. `get_config()` for functional models is
# a nested layer config and cannot be passed to Backbone constructors.
Expand All @@ -109,15 +121,18 @@ def get_config(self):
}

# Add quantization support by utilizing `DTypePolicyMap`
if isinstance(self.dtype_policy, keras.dtype_policies.DTypePolicyMap):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
if hasattr(keras.dtype_policies, "DTypePolicyMap"):
if isinstance(
self.dtype_policy, keras.dtype_policies.DTypePolicyMap
):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
return config

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion keras_nlp/src/models/pali_gemma/pali_gemma_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,14 @@ def __init__(
classifier_activation
)
self.image_sequence_length = int((image_size / patch_size) ** 2)
self.dtype_policy = keras.dtype_policies.get(dtype)
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
if hasattr(keras.dtype_policies, "get"):
self.dtype_policy = keras.dtype_policies.get(dtype)
else:
if isinstance(dtype, keras.dtype_policies.DTypePolicy):
dtype = dtype.name
dtype = dtype or keras.config.dtype_policy().name
self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

def get_config(self):
config = super().get_config()
Expand Down
3 changes: 2 additions & 1 deletion keras_nlp/src/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from keras_nlp.src import layers as keras_nlp_layers
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
from keras_nlp.src.utils.keras_utils import has_quantization_support
from keras_nlp.src.utils.tensor_utils import is_float_dtype


Expand Down Expand Up @@ -445,7 +446,7 @@ def run_backbone_test(
self.run_precision_test(cls, init_kwargs, input_data)

# Check quantization.
if run_quantization_check:
if run_quantization_check and has_quantization_support():
self.run_quantization_test(backbone, cls, init_kwargs, input_data)

def run_task_test(
Expand Down
13 changes: 13 additions & 0 deletions keras_nlp/src/utils/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import keras
from absl import logging
from packaging.version import parse

from keras_nlp.src.utils.tensor_utils import is_tensor_type

Expand Down Expand Up @@ -102,3 +103,15 @@ def print_msg(message, line_break=True):
@keras.saving.register_keras_serializable(package="keras_nlp")
def gelu_approximate(x):
    """Apply the Keras GELU activation with `approximate=True`.

    The function is registered as a Keras serializable object under the
    `keras_nlp` package.

    Args:
        x: Input passed straight through to `keras.activations.gelu`.

    Returns:
        The result of `keras.activations.gelu(x, approximate=True)`.
    """
    activated = keras.activations.gelu(x, approximate=True)
    return activated


def has_quantization_support():
    """Report whether the installed Keras is recent enough for quantization.

    Returns:
        `True` when `keras.version()` parses to version 3.4.0 or later,
        `False` otherwise.
    """
    # Compare parsed version objects, not raw strings, and return the
    # comparison directly instead of routing it through a conditional.
    return parse(keras.version()) >= parse("3.4.0")


def assert_quantization_support():
    """Raise if the installed Keras does not support quantization.

    Raises:
        ValueError: When `has_quantization_support()` returns a falsy
            value. The message includes the installed Keras version.
    """
    if has_quantization_support():
        return
    installed_version = keras.version()
    raise ValueError(
        "Quantization API requires Keras >= 3.4.0 to function "
        f"correctly. Received: '{installed_version}'"
    )

0 comments on commit 360da5b

Please sign in to comment.