diff --git a/lit_nlp/api/model.py b/lit_nlp/api/model.py
index e8b29363..ce034a20 100644
--- a/lit_nlp/api/model.py
+++ b/lit_nlp/api/model.py
@@ -49,6 +49,13 @@ def maybe_copy_np(arr):
   # If this is not a view of another array.
   if arr.base is None:
     return arr
+  # Tensorflow provides a bridge to share memory between tensorflow and numpy
+  # arrays. This looks like a view into an array, but the base is a
+  # tensorflow_wrapper, not an array, so the view heuristics below don't work.
+  # We can check for this case by checking if arr.base has the ndim attribute.
+  # /~https://github.com/tensorflow/tensorflow/blob/6ed79e8429730c33dc894175da7a1849a8e3e57f/tensorflow/python/lib/core/ndarray_tensor_bridge.cc#L90
+  if not hasattr(arr.base, 'ndim'):
+    return np.copy(arr)
   # Heuristic to check if we should 'detach' this array from the parent blob.
   # We want to know if this array is a view that might leak memory.
   # The simplest check is if arr.base is larger than arr, but we don't want to
diff --git a/lit_nlp/examples/lm_salience_demo.py b/lit_nlp/examples/lm_salience_demo.py
index 10091280..05f6c1a4 100644
--- a/lit_nlp/examples/lm_salience_demo.py
+++ b/lit_nlp/examples/lm_salience_demo.py
@@ -1,33 +1,43 @@
 r"""Demo for sequence salience with a left-to-right language model.
 
-To use with Gemma models, install the latest versions of Keras and KerasNLP:
+To use with the Gemma, Llama, or Mistral models, install the latest versions of
+Keras, KerasNLP, and/or HuggingFace Transformers:
 
-  pip install keras>=3.0.5 keras-nlp>=0.8.0
+  pip install keras>=3.1.0 keras-nlp>=0.9.0 transformers>=4.38.0
+
+To run with the default configuration (Gemma on TensorFlow via Keras):
 
-To run:
   blaze run -c opt examples:lm_salience_demo -- \
     --models=gemma_instruct_2b_en:gemma_instruct_2b_en \
     --port=8890 --alsologtostderr
 
-We strongly recommend a GPU or other accelerator to run this demo, although for
-testing, the smaller GPT-2 models run well on CPU. To use tensorflow weights of
-GPT2, set the flag values as below:
---models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz
---hf_framework=tensorflow
-
-We also support pytorch weights for GPT-2 model, simply set the flag values:
---models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2-pt.tar.gz
---hf_framework=pytorch
-
-A few more examples of the flag setup for other supported models (GPU required):
-Llama2: --hf_framework=pytorch
-  --models=llama2:meta-llama/Llama-2-7b-hf
-Mistral: --hf_framework=pytorch
-  --models=mistral:mistralai/Mistral-7B-v0.1
-
-By default this include a small set of sample prompts, but you can load your
-own examples using the --datasets flag or through the "Configure" menu in the
-UI.
+MODELS:
+
+We strongly recommend a GPU or other accelerator to run this server with LLMs.
+The table below shows the model names and presets for common models. Use these
+to parameterize the --models flag with comma-separated `{model}:{preset}`
+strings, and remember that the number of models loaded will be limited by the
+memory available on your accelerator.
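
Before the preset table, a brief note on the `maybe_copy_np()` change at the top
of this diff: the new `hasattr` check guards the view heuristic against arrays
whose base is TensorFlow's ndarray bridge rather than another ndarray. The
sketch below illustrates the overall decision flow; the function body and the
2x size threshold are simplified stand-ins for the real logic in
lit_nlp/api/model.py, not a copy of it.

```python
import numpy as np

def maybe_copy_np_sketch(arr: np.ndarray) -> np.ndarray:
  if arr.base is None:
    return arr             # Owns its own memory; nothing to detach.
  if not hasattr(arr.base, "ndim"):
    return np.copy(arr)    # Base is a foreign buffer (e.g. the TF bridge).
  if arr.base.size > 2 * arr.size:
    return np.copy(arr)    # Small view pinning a large parent blob; detach it.
  return arr

blob = np.zeros((1000, 16))
view = blob[:2]                                # View that would keep `blob` alive.
assert maybe_copy_np_sketch(view).base is None  # The copy owns its data.
```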
+ +| Model | dl_framework | dl_backend=tensorflow Preset | dl_backend=torch Preset | +| ------- | ------------ | ---------------------------- | ------------------------------------ | +| Gemma | kerasnlp | gemma_1.1_instruct_7b_en | gemma_1.1_instruct_7b_en | +| Gemma | transformers | Unavailable | google/gemma-1.1-7b-it | +| Llama 2 | kerasnlp | llama2_instruct_7b_en | llama2_instruct_7b_en | +| Llama 2 | transformers | Unavailable | meta-llama/Llama-2-7b-hf | +| Mistral | kerasnlp | mistral_instruct_7b_en | mistral_instruct_7b_en | +| Mistral | transformers | Unavailable | mistralai/Mistral-7B-Instruct-v0.2 | + +Additional model presets can be found at the following locations, though +compatibility with the LIT model wrappers is not guaranteed: + +* KerasNLP: https://keras.io/api/keras_nlp/models/ +* HuggingFace Transformers: https://huggingface.co/models + +DATASETS: + +By default this includes a small set of sample prompts. You can load your own +examples using the --datasets flag or through the "Configure" menu in the UI. """ from collections.abc import Sequence @@ -37,31 +47,15 @@ import sys from typing import Optional -# TODO(b/327281789): remove once keras 3 is the default. -# Temporary; need to set this before importing keras_nlp -os.environ["FORCE_KERAS_3"] = "True" - -# pylint: disable=g-import-not-at-top from absl import app from absl import flags from absl import logging -import keras -from keras_nlp import models as keras_models from lit_nlp import dev_server from lit_nlp import server_flags from lit_nlp.api import layout from lit_nlp.examples.datasets import lm as lm_data -from lit_nlp.examples.models import instrumented_keras_lms as lit_keras -from lit_nlp.examples.models import pretrained_lms from lit_nlp.lib import file_cache -# pytype: disable=import-error -try: - import torch -except (ModuleNotFoundError, ImportError): - logging.warning("PyTorch is not available.") -# pytype: enable=import-error - # NOTE: additional flags defined in server_flags.py FLAGS = flags.FLAGS @@ -70,16 +64,14 @@ _MODELS = flags.DEFINE_list( "models", - [ - "gemma_instruct_2b_en:gemma_instruct_2b_en", - "gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz", - ], - "Models to load, as :. Currently supports Gemma (Keras NLP) and" - "HuggingFace models. For HuggingFace models, GPT2, Llama, Mistral have been" - "verified to work with this demo. Thereotically, supported decoder models" - "in `transformers.AutoModelForCausalLM` should work, but adjustments might" - "be needed on their tokenizers (e.g. need to define custom pad_token when" - "eos_token is not available to use as pad_token).", + ["gemma_instruct_2b_en:gemma_instruct_2b_en"], + "Models to load, as :. Path can be a URL, a local file path, or" + " the name of a preset for the configured Deep Learning framework (either" + " KerasNLP or HuggingFace Transformers; see --dl_framework for more). This" + " demo is tested with Gemma, GPT2, Llama, and Mistral on all supported" + " --dl_framework values. 
Other models should work, but adjustments might be" + " needed on their tokenizers (e.g., to define custom pad_token" + " when eos_token is not available to use as pad_token).", ) _DATASETS = flags.DEFINE_list( @@ -99,19 +91,31 @@ ), ) -_HF_FRAMEWORK = flags.DEFINE_enum( - "hf_framework", +_DL_BACKEND = flags.DEFINE_enum( + "dl_backend", "tensorflow", - ["tensorflow", "pytorch"], - "Deep learning framework for the HuggingFace model.", + ["jax", "torch", "tensorflow"], + "The deep learning backend framework that the model runs on. All models" + " loaded by this server will use the same backend, incompatibilities will" + " result in errors.", +) + +_DL_FRAMEWORK = flags.DEFINE_enum( + "dl_framework", + "kerasnlp", + ["kerasnlp", "transformers"], + "The deep learning framework that loads and runs the model on the backend." + " This server will attempt to load all models specified by the --models" + " flag with the configured framework, incompatibilities will result in" + " errors.", ) _PRECISION = flags.DEFINE_enum( "precision", "bfloat16", ["bfloat16", "float32"], - "Floating point precision for the HuggingFace (PyTorch) and Keras models," - "only `bfloat16` and `float32` are supported for now.", + "Floating point precision for the models, only `bfloat16` and `float32` are" + " supported for now.", ) @@ -201,20 +205,28 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: raise app.UsageError("Too many command-line arguments.") # Set Keras backend and floating-point precision. - os.environ["KERAS_BACKEND"] = "tensorflow" - if hasattr(keras, "config") and hasattr(keras.config, "set_floatx"): + if _DL_FRAMEWORK.value == "kerasnlp": + # NOTE: Keras and KerasNLP require that certain environment variables are + # set before they are imported. + # TODO(b/327281789): Remove FORCE_KERAS_3 once Keras 3 is the default. + os.environ["FORCE_KERAS_3"] = "True" + os.environ["KERAS_BACKEND"] = _DL_BACKEND.value + + # NOTE: Imported here and not at the top of the file to avoid + # initialization issues with the environment variables above. We should also + # import keras before any other Keras-related modules (e.g., KerasNLP or the + # LIT wrappers) to limit the potenital for improperly configured backends. + import keras # pylint: disable=g-import-not-at-top + keras.config.set_floatx(_PRECISION.value) - else: - # TODO(b/327281789): remove once we can guarantee Keras 3. - logging.warn( - "keras.config.set_floatx() not available; using default precision." - ) + elif _DL_BACKEND.value == "torch": + # NOTE: Keras sets precision for all backends with set_floatx(), but for + # HuggingFace Transformers with PyTorch we need to set it explicitly. + import torch # pylint: disable=g-import-not-at-top # pytype: disable=import-error - if _HF_FRAMEWORK.value == "pytorch": - if _PRECISION.value == "bfloat16": - torch.set_default_dtype(torch.bfloat16) - else: - torch.set_default_dtype(torch.float32) + torch.set_default_dtype( + torch.bfloat16 if _PRECISION.value == "bfloat16" else torch.float32 + ) plaintextPrompts = functools.partial( # pylint: disable=invalid-name lm_data.PlaintextSents, field_name="prompt" @@ -259,38 +271,37 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]: # Load models, according to the --models flag. 
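
The setup above has a strict ordering requirement that is easy to get wrong:
the environment variables must be set before `keras` is imported, and precision
is applied afterwards. A minimal, self-contained sketch follows, with the flag
values hard-coded as stand-ins for --dl_backend and --precision (an assumption
for illustration); the loading loop below then builds the model wrappers.

```python
import os

# Must be set before `import keras`; mirrors the kerasnlp branch above.
os.environ["FORCE_KERAS_3"] = "True"
os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" / "torch"

import keras  # Deliberately imported after the environment variables above.

keras.config.set_floatx("bfloat16")
print(keras.backend.backend(), keras.config.floatx())  # tensorflow bfloat16
```

For the transformers-on-torch path, the equivalent precision step is
`torch.set_default_dtype(torch.bfloat16)`, as in the elif branch above.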
  models = {}
  for model_string in _MODELS.value:
-    # Only split on the first ':', because path may be a URL
-    # containing 'https://'
+    # Only split on the first ':' as path may be a URL containing 'https://'
     model_name, path = model_string.split(":", 1)
     logging.info("Loading model '%s' from '%s'", model_name, path)
-    if model_name.startswith("gemma"):
-      path = file_cache.cached_path(
-          path,
-          extract_compressed_file=path.endswith(".tar.gz"),
-          copy_directories=True,
-      )
+
+    path = file_cache.cached_path(
+        path,
+        extract_compressed_file=path.endswith(".tar.gz"),
+        copy_directories=True,
+    )
+
+    if _DL_FRAMEWORK.value == "kerasnlp":
+      # pylint: disable=g-import-not-at-top
+      from keras_nlp import models as keras_models
+      from lit_nlp.examples.models import instrumented_keras_lms as lit_keras
+      # pylint: enable=g-import-not-at-top
       # Load the weights once for the underlying Keras model.
-      gemma_keras_model = keras_models.GemmaCausalLM.from_preset(path)  # pytype: disable=module-attr
-      models = models | lit_keras.initialize_model_group_for_salience(
-          model_name, gemma_keras_model, max_length=512, batch_size=4
+      model = keras_models.CausalLM.from_preset(path)
+      models |= lit_keras.initialize_model_group_for_salience(
+          model_name, model, max_length=512, batch_size=4
       )
       # Disable embeddings from the generation model.
       # TODO(lit-dev): re-enable embeddings if we can figure out why UMAP was
       # crashing? Maybe need n > 2 examples.
       models[model_name].output_embeddings = False
     else:
+      # NOTE: (Style Deviation) Imported here to limit unnecessary imports.
+      from lit_nlp.examples.models import pretrained_lms  # pylint: disable=g-import-not-at-top
       # Assuming a valid decoder model name supported by
       # `transformers.AutoModelForCausalLM` is provided to "path".
-      models[model_name] = pretrained_lms.HFGenerativeModel(
-          path, framework=_HF_FRAMEWORK.value, max_new_tokens=512
-      )
-      # Salience wrapper, using same underlying Keras models so as not to
-      # load the weights twice.
-      models[f"_{model_name}_salience"] = (
-          pretrained_lms.HFSalienceModel.from_loaded(models[model_name])
-      )
-      models[f"_{model_name}_tokenizer"] = (
-          pretrained_lms.HFTokenizerModel.from_loaded(models[model_name])
+      models |= pretrained_lms.initialize_model_group_for_salience(
+          model_name, path, framework=_DL_BACKEND.value, max_new_tokens=512
       )
 
   for name in datasets:
diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py
index 453487bd..44236311 100644
--- a/lit_nlp/examples/models/instrumented_keras_lms.py
+++ b/lit_nlp/examples/models/instrumented_keras_lms.py
@@ -6,11 +6,31 @@
 from typing import Sequence
 
 from absl import logging
+import keras.backend
+import keras.ops
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
 from lit_nlp.lib import utils as lit_utils
 import numpy as np
-import tensorflow as tf
+
+
+# pylint: disable=g-import-not-at-top
+# pytype: disable=import-error
+# NOTE: The Keras backend must be set before loading the Keras library. You can
+# set the backend using the KERAS_BACKEND environment variable or your
+# ~/.keras/keras.json configuration file. For more information, see:
+# https://keras.io/getting_started/#configuring-your-backend
+if keras.backend.backend() == "tensorflow":
+  import tensorflow as tf
+elif keras.backend.backend() == "jax":
+  # TODO(lit-dev): Update imports once a solution to JAX salience is decided.
+ pass +elif keras.backend.backend() == "torch": + import torch +else: + raise ValueError(f"Unsupported backend: {keras.backend.backend()}") +# pytype: enable=import-error +# pylint: enable=g-import-not-at-top _DEFAULT_MAX_LENGTH = 1024 @@ -54,7 +74,7 @@ def __init__( and manipulate activations between layers. We use this for salience, below. Args: - model: pre-loaded Keras LM using the TF backend + model: pre-loaded Keras LM max_length: max sequence length dynamic_sequence_length: if true, will trim padding to the length of the longest sequence in a batch. Recommended for CPU and GPU usage, but may @@ -90,7 +110,7 @@ def encode_inputs(self, texts: Sequence[str]): texts: list of input strings Returns: - encoded_inputs compatible with model.score() or other functions + A dict[str, Tensor] compatible with model.score(), etc. functions. """ # First: pack to max_length encoded_inputs = self.model.preprocessor.generate_preprocess( @@ -101,20 +121,26 @@ def encode_inputs(self, texts: Sequence[str]): # Trim to the maximum length needed to contain any non-padding tokens. mask = encoded_inputs["padding_mask"] + + if keras.backend.backend() == "tensorflow": + max_indices = [tf.reduce_max(tf.where(row)) for row in mask] + elif keras.backend.backend() == "torch": + max_indices = [torch.max(torch.where(row)[0]) for row in mask] + else: + raise ValueError(f"Unsupported backend: {keras.backend.backend()}") # Find position of last 'True' in each row. seq_ends: Sequence[int] = [ - 1 + tf.reduce_max(tf.where(mask[i])).numpy().tolist() - for i in range(mask.shape[0]) + keras.ops.convert_to_numpy(i).tolist() + 1 for i in max_indices ] - trimmed_length = max(seq_ends) + longest_sequence = max(seq_ends) # TODO(lit-dev): remove this line, or make it logging.debug ? logging.info( "Trimming batch to trimmed_length = %d based on sequence ends %s", - trimmed_length, + longest_sequence, seq_ends, ) # Actually trim the input tensors. 
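
The trimming above cuts per-batch compute by slicing every input tensor to the
longest real sequence. A small numpy illustration (toy mask values, not LIT
code) of how `seq_ends` and `longest_sequence` are derived; the return
statement that follows then applies the same slice to every entry of
`encoded_inputs`.

```python
import numpy as np

padding_mask = np.array([
    [1, 1, 1, 0, 0, 0],   # 3 real tokens
    [1, 1, 1, 1, 1, 0],   # 5 real tokens
])
# Position of the last 'True' in each row, plus one.
seq_ends = [int(np.max(np.where(row)[0])) + 1 for row in padding_mask]
longest_sequence = max(seq_ends)         # 5
trimmed_mask = padding_mask[:, :longest_sequence]
print(seq_ends, trimmed_mask.shape)      # [3, 5] (2, 5)
```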
- return {k: v[:, :trimmed_length] for k, v in encoded_inputs.items()} + return {k: v[:, :longest_sequence] for k, v in encoded_inputs.items()} @classmethod def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw): @@ -156,23 +182,32 @@ def __init__(self, *args, output_embeddings=True, **kw): def embed_texts(self, texts: Sequence[str]): processed_inputs = self.encode_inputs(texts) - # [batch_size, num_tokens, emb_dim] + # [batch_size, num_tokens, emb_dim] embs = self.embedder(processed_inputs["token_ids"]) - # [batch_size, num_tokens] + # [batch_size, num_tokens] mask = processed_inputs["padding_mask"] return embs, mask def embed_and_mean_pool(self, texts: Sequence[str]): """Return a single vector for each text.""" embs, mask = self.embed_texts(texts) - # [batch_size, num_tokens, 1] - mask = tf.expand_dims(tf.cast(mask, dtype=embs.dtype), axis=2) - # [batch_size, 1, emb_dim] - pooled_embs = tf.reduce_sum( - mask * embs, axis=1, keepdims=True - ) / tf.reduce_sum(mask, axis=1, keepdims=True) - # [batch_size, emb_dim] - return tf.squeeze(pooled_embs, axis=1) + # [batch_size, num_tokens, 1] + cast_mask = keras.ops.cast(mask, dtype=embs.dtype) + + if keras.backend.backend() == "tensorflow": + expanded_mask = tf.expand_dims(cast_mask, axis=2) + pooled_embs = tf.reduce_sum( + expanded_mask * embs, axis=1, keepdims=True + ) / tf.reduce_sum(expanded_mask, axis=1, keepdims=True) + return tf.squeeze(pooled_embs, axis=1) + elif keras.backend.backend() == "torch": + expanded_mask = torch.unsqueeze(cast_mask, dim=2) + pooled_embs = torch.sum( + expanded_mask * embs, dim=1, keepdim=True + ) / torch.sum(expanded_mask, dim=1, keepdim=True) + return torch.squeeze(pooled_embs, dim=1) + else: + raise ValueError(f"Unsupported backend: {keras.backend.backend()}") def predict_minibatch( self, @@ -200,11 +235,9 @@ def predict_minibatch( # Or just embed full_response. response_embeddings = self.embed_and_mean_pool(responses) - for i in range(len(inputs)): - outputs[i][FieldNames.PROMPT_EMBEDDINGS] = prompt_embeddings[i].numpy() - outputs[i][FieldNames.RESPONSE_EMBEDDINGS] = response_embeddings[ - i - ].numpy() + for o, p, r in zip(outputs, prompt_embeddings, response_embeddings): + o[FieldNames.PROMPT_EMBEDDINGS] = keras.ops.convert_to_numpy(p) + o[FieldNames.RESPONSE_EMBEDDINGS] = keras.ops.convert_to_numpy(r) return outputs @@ -244,18 +277,22 @@ def __init__(self, *args, **kw): ) def _pred(self, input_ids, padding_mask, target_masks): - """Predict a batch of tokenized text.""" - # [batch_size, num_tokens]; ignore the last one in each row. - target_ids = tf.roll(input_ids, shift=-1, axis=1) + """Predict a batch of tokenized text. + + Args: + input_ids: A Tensor with shape [batch_size, num_tokens] + padding_mask: A Tensor with shape [batch_size, num_tokens] + target_masks: A Numpy Array with shape [batch_size, num_tokens] + Returns: + Batched outputs for post-processing. + """ ## # Process target masks - # It doesn't make sense to interpret the first token, since it is not ever # predicted. But we need to ensure that the mask[0] is zero, so it doesn't # cause problems when 'rolled' to the last position below. 
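
The mask bookkeeping here is subtle enough to warrant a toy walkthrough. The
numpy sketch below (made-up mask values) mirrors the steps the following lines
implement: zero the first position, pad to the sequence length, then roll by -1
so the loss mask lines up with `target_ids`, which are the input ids shifted
left by one.

```python
import numpy as np

seq_len = 6
target_masks = [np.array([1, 1, 1, 0]), np.array([1, 0, 1, 1, 1])]

modified = [[0] + list(m[1:]) for m in target_masks]             # zero position 0
padded = np.stack([np.pad(m, (0, seq_len - len(m))) for m in modified])
loss_mask = np.roll(padded, shift=-1, axis=1).astype(bool)       # align to targets
print(loss_mask.astype(int))
# [[1 1 0 0 0 0]
#  [0 1 1 1 0 0]]
```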
- modified_masks = [[0] + list(mask[1:]) for mask in target_masks] - seq_len = target_ids.shape[1] + seq_len = keras.ops.shape(input_ids)[1] pad_fn = functools.partial( lit_utils.pad1d, min_len=seq_len, @@ -263,15 +300,46 @@ def _pred(self, input_ids, padding_mask, target_masks): pad_val=0, pad_left=False, ) - padded_target_masks = np.stack( + + modified_masks = [[0] + list(mask[1:]) for mask in target_masks] + stacked_padded_masks = keras.ops.stack( [pad_fn(mask) for mask in modified_masks], axis=0, ) + # Shift masks back so they align with the target_ids generated in the + # backend-specific prediction functions. + rolled_masks = keras.ops.roll(stacked_padded_masks, shift=-1, axis=1) + loss_mask = keras.ops.convert_to_tensor(rolled_masks, dtype="bool") + + pred_kw_args = { + "input_ids": input_ids, + "padding_mask": padding_mask, + "loss_mask": loss_mask, + } + if keras.backend.backend() == "tensorflow": + grad_l2, grad_dot_input = self._pred_tf(**pred_kw_args) + elif keras.backend.backend() == "jax": + grad_l2, grad_dot_input = self._pred_jax(**pred_kw_args) + elif keras.backend.backend() == "torch": + grad_l2, grad_dot_input = self._pred_torch(**pred_kw_args) + else: + raise ValueError(f"Unsupported backend: {keras.backend.backend()}") + + batched_outputs = { + "input_ids": input_ids, + "padding_mask": padding_mask, + # Gradients are already aligned to input tokens. + FieldNames.GRAD_NORM: grad_l2, + FieldNames.GRAD_DOT_INPUT: grad_dot_input, + # Shift token loss to align with (input) tokens. + # FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1), + } - padded_target_masks = tf.constant(padded_target_masks, dtype=tf.bool) - # Shift masks back so they align with target_ids. - loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1) + return batched_outputs + def _pred_tf(self, input_ids, padding_mask, loss_mask): + # [batch_size, num_tokens]; ignore the last one in each row. + target_ids = tf.roll(input_ids, shift=-1, axis=1) embeddings = None with tf.GradientTape(watch_accessed_variables=False) as tape: @@ -283,7 +351,7 @@ def layer_intercept_fn(x, i): tape.watch(embeddings) return x - # [batch_size, num_tokens] + # [batch_size, num_tokens] per_token_loss = self.model.score( token_ids=input_ids, padding_mask=padding_mask, @@ -291,26 +359,62 @@ def layer_intercept_fn(x, i): layer_intercept_fn=layer_intercept_fn, target_ids=target_ids, ) - masked_loss = per_token_loss * tf.cast(loss_mask, per_token_loss.dtype) + masked_loss = per_token_loss * keras.ops.cast( + loss_mask, per_token_loss.dtype + ) - # [batch_size, num_tokens, hdim] + # [batch_size, num_tokens, hdim] grads = tape.gradient(masked_loss, embeddings) - # [batch_size, num_tokens] + # [batch_size, num_tokens] grad_l2 = tf.norm(grads, axis=2) - # [batch_size, num_tokens] + # [batch_size, num_tokens] grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2) + return grad_l2, grad_dot_input + + # TODO(b/333373960): Implement salience computation for JAX. + def _pred_jax(self, input_ids, padding_mask, loss_mask): + # NOTE: JAX computes gradients automatically w.r.t function inputs and + # outputs. The score function takes token_ids as its input but salience is + # computed w.r.t. the embeddings, thus JAX cannot differentiate the loss + # w.r.t. the embeddings and taking gradients w.r.t. the token_ids is not + # equivalent. For now, we raise an error if using JAX. 
+ raise NotImplementedError("JAX backend not supported for salience.") + + def _pred_torch(self, input_ids, padding_mask, loss_mask): + target_ids = torch.roll(input_ids, shifts=-1, dims=1) + embeddings = None - batched_outputs = { - "input_ids": input_ids, - "padding_mask": padding_mask, - # Gradients are already aligned to input tokens. - FieldNames.GRAD_NORM: grad_l2, - FieldNames.GRAD_DOT_INPUT: grad_dot_input, - # Shift token loss to align with (input) tokens. - # FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1), - } + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings + embeddings = x + return x + + per_token_loss = self.model.score( + token_ids=input_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) - return batched_outputs + if embeddings is None: + raise ValueError("Embeddings are None after scoring.") + + masked_loss = per_token_loss * keras.ops.cast( + loss_mask, per_token_loss.dtype + ) + + # [batch_size, num_tokens, hdim] + grads = torch.autograd.grad( + masked_loss, embeddings, grad_outputs=torch.ones_like(masked_loss) + )[0] + embeddings = embeddings.detach() + # [batch_size, num_tokens] + grad_l2 = torch.norm(grads, dim=2) + # [batch_size, num_tokens] + grad_dot_input = torch.sum(grads * embeddings, dim=2) + return grad_l2, grad_dot_input def _postprocess(self, preds): """Post-process single-example preds. Operates on numpy arrays.""" @@ -340,7 +444,9 @@ def predict_minibatch(self, inputs): # Get the predictions. batched_outputs = self._pred(sequence_ids, padding_mask, target_masks) # Convert to numpy for post-processing. - detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + detached_outputs = { + k: keras.ops.convert_to_numpy(v) for k, v in batched_outputs.items() + } # Split up batched outputs, then post-process each example. unbatched_outputs = lit_utils.unbatch_preds(detached_outputs) return map(self._postprocess, unbatched_outputs) @@ -388,7 +494,9 @@ def predict_minibatch(self, inputs): "padding_mask": preprocessed_texts["padding_mask"], } # Convert to numpy for post-processing. - detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()} + detached_outputs = { + k: keras.ops.convert_to_numpy(v) for k, v in batched_outputs.items() + } # Split up batched outputs, then post-process each example. unbatched_outputs = lit_utils.unbatch_preds(detached_outputs) return map(self._postprocess, unbatched_outputs) diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index ebd5c5bc..56d3e1a9 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -36,7 +36,7 @@ _DEFAULT_MAX_LENGTH = 1024 -_PYTORCH = "pytorch" +_PYTORCH = "torch" _TENSORFLOW = "tensorflow" # HuggingFace uses two letter abbreviations for pytorch and tensorflow. _HF_PYTORCH = "pt" @@ -400,7 +400,7 @@ def __init__( model_name_or_path: gpt2, gpt2-medium, gpt2-large, distilgpt2, meta-llama/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, etc. batch_size: the number of items to process per `predict_minibatch` call. - framework: the deep learning framework, only "tensorflow" and "pytorch" + framework: the deep learning framework, only "tensorflow" and "torch" are supported. model: an initialized transformer model. tokenizer: an initialized tokenizer. 
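
Returning to `_pred_torch` in instrumented_keras_lms.py above: the gradient
plumbing can be seen in isolation with a toy tensor. The quadratic "loss" below
is a stand-in for the masked per-token LM loss, not the model's `score()`
output; the shapes and the `grad_outputs=ones` trick match the real code.

```python
import torch

embeddings = torch.randn(2, 5, 8, requires_grad=True)       # [batch, tokens, dim]
per_token_loss = (embeddings ** 2).sum(dim=2)                # [batch, tokens]
grads = torch.autograd.grad(
    per_token_loss,
    embeddings,
    grad_outputs=torch.ones_like(per_token_loss),            # sum over tokens
)[0]                                                         # [batch, tokens, dim]
grad_l2 = torch.norm(grads, dim=2)                           # [batch, tokens]
grad_dot_input = (grads * embeddings.detach()).sum(dim=2)    # [batch, tokens]
print(grad_l2.shape, grad_dot_input.shape)
```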
@@ -866,3 +866,19 @@ def output_spec(self) -> lit_types.Spec: return { "tokens": lit_types.Tokens(parent=""), # all tokens } + + +def initialize_model_group_for_salience( + name, *args, max_new_tokens=512, **kw +) -> dict[str, lit_model.Model]: + """Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'.""" + generation_model = HFGenerativeModel( + *args, **kw, max_new_tokens=max_new_tokens + ) + salience_model = HFSalienceModel.from_loaded(generation_model) + tokenizer_model = HFTokenizerModel.from_loaded(generation_model) + return { + name: generation_model, + f"_{name}_salience": salience_model, + f"_{name}_tokenizer": tokenizer_model, + } diff --git a/lit_nlp/examples/models/pretrained_lms_int_test.py b/lit_nlp/examples/models/pretrained_lms_int_test.py index 12cad0ce..229be85b 100644 --- a/lit_nlp/examples/models/pretrained_lms_int_test.py +++ b/lit_nlp/examples/models/pretrained_lms_int_test.py @@ -39,12 +39,12 @@ class GPT2Generation(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name="tensorflow", - framework="tensorflow", + framework=pretrained_lms.MLFramework.TF.value, model_path="https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz", ), dict( testcase_name="pytorch", - framework="pytorch", + framework=pretrained_lms.MLFramework.PT.value, model_path="https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2-pt.tar.gz", ), ) diff --git a/pyproject.toml b/pyproject.toml index fe8c4efe..218ed676 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,32 +14,32 @@ license = { file = "LICENSE" } requires-python = ">=3.10" # LINT.IfChange dependencies = [ - "absl-py==1.4.0", - "annoy==1.17.3", - "attrs==22.1.0", - "etils[epath]==1.3.0", - "filelock==3.12.2", - "google-cloud-translate==3.11.1", - "ipython==8.14.0", - "Levenshtein==0.21.1", - "matplotlib==3.6.1", - "ml-collections==0.1.1", - "numpy==1.24.1", - "pandas==1.5.3", - "Pillow==10.0.0", - "portpicker==1.5.2", - "requests==2.31.0", - "rouge-score==0.1.2", - "sacrebleu==2.3.1", - "saliency==0.1.3", - "scikit-learn==1.0.2", - "scipy==1.10.1", + "absl-py>=1.4.0", + "annoy>=1.17.3", + "attrs>=22.1.0", + "etils[epath]>=1.7.0", + "filelock>=3.12.3", + "google-cloud-translate>=3.11.1", + "ipython>=7.34.0", + "Levenshtein>=0.21.1", + "matplotlib>=3.7.1", + "ml-collections>=0.1.1", + "numpy>=1.24.1", + "pandas>=2.0.3", + "Pillow>=10.0.0", + "portpicker>=1.5.2", + "requests>=2.31.0", + "rouge-score>=0.1.2", + "sacrebleu>=2.3.1", + "saliency>=0.1.3", + "scikit-learn>=1.0.2", + "scipy>=1.10.1", "shap==0.42.0", - "six==1.16.0", - "termcolor==2.3.0", - "tqdm==4.64.0", - "umap-learn==0.5.1", - "werkzeug==2.3.6", + "six>=1.16.0", + "termcolor>=2.3.0", + "tqdm>=4.64.0", + "umap-learn>=0.5.1", + "werkzeug>=2.2.3", ] # LINT.ThenChange(./requirements_core.txt) classifiers = [ @@ -83,14 +83,14 @@ examples = [ "tensorflow==2.10.0", "tensorflow-datasets==4.8.0", "tensorflow-text==2.10.0", - "torch==2.0.1", - "transformers==4.27.1", + "torch>=2.0.0", + "transformers>=4.27.1", ] # LINT.ThenChange(./requirements_examples.txt) # LINT.IfChange test = [ "lime==0.2.0.1", - "pytest==7.4.0", + "pytest>=7.4.0,<8.0.0", ] # LINT.ThenChange(./requirements_test.txt) diff --git a/requirements_core.txt b/requirements_core.txt index b65f6b7d..2ceb6ed4 100644 --- a/requirements_core.txt +++ b/requirements_core.txt @@ -13,30 +13,30 @@ # limitations under the License. 
# ============================================================================== # LINT.IfChange -absl-py==1.4.0 -annoy==1.17.3 -attrs==22.1.0 -etils[epath]==1.3.0 -filelock==3.12.2 -google-cloud-translate==3.11.1 -ipython==8.14.0 -Levenshtein==0.21.1 -matplotlib==3.6.1 -ml-collections==0.1.1 -numpy==1.24.1 -pandas==1.5.3 -Pillow==10.0.0 -portpicker==1.5.2 -requests==2.31.0 -rouge-score==0.1.2 -sacrebleu==2.3.1 -saliency==0.1.3 -scikit-learn==1.0.2 -scipy==1.10.1 +absl-py>=1.4.0 +annoy>=1.17.3 +attrs>=22.1.0 +etils[epath]>=1.7.0 +filelock>=3.12.3 +google-cloud-translate>=3.11.1 +ipython>=7.34.0 +Levenshtein>=0.21.1 +matplotlib>=3.7.1 +ml-collections>=0.1.1 +numpy>=1.24.1 +pandas>=2.0.3 +Pillow>=10.0.0 +portpicker>=1.5.2 +requests>=2.31.0 +rouge-score>=0.1.2 +sacrebleu>=2.3.1 +saliency>=0.1.3 +scikit-learn>=1.0.2 +scipy>=1.10.1 shap==0.42.0 -six==1.16.0 -termcolor==2.3.0 -tqdm==4.64.0 -umap-learn==0.5.1 -werkzeug==2.3.6 +six>=1.16.0 +termcolor>=2.3.0 +tqdm>=4.64.0 +umap-learn>=0.5.1 +werkzeug>=2.2.3 # LINT.ThenChange(./pyproject.toml) diff --git a/requirements_examples.txt b/requirements_examples.txt index c19c4e1b..7778c6d7 100644 --- a/requirements_examples.txt +++ b/requirements_examples.txt @@ -18,6 +18,6 @@ sentencepiece==0.1.99 tensorflow==2.10.0 tensorflow-datasets==4.8.0 tensorflow-text==2.10.0 -torch==2.0.1 -transformers==4.27.1 +torch>=2.0.0 +transformers>=4.27.1 # LINT.ThenChange(./pyproject.toml) diff --git a/requirements_test.txt b/requirements_test.txt index 2628ea55..0f6cc418 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -14,5 +14,5 @@ # ============================================================================== # LINT.IfChange lime==0.2.0.1 -pytest==7.4.0 +pytest>=7.4.0,<8.0.0 # LINT.ThenChange(./pyproject.toml)
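
Since the pins above moved from exact (`==`) to minimum (`>=`) versions, a
quick way to confirm an existing environment still satisfies them is a sketch
like the following. The package list is abbreviated, and `packaging` is assumed
to be available (it ships with pip-based environments).

```python
from importlib import metadata
from packaging.version import Version

MINIMUMS = {"absl-py": "1.4.0", "numpy": "1.24.1", "pandas": "2.0.3", "werkzeug": "2.2.3"}

for pkg, floor in MINIMUMS.items():
  try:
    installed = metadata.version(pkg)
  except metadata.PackageNotFoundError:
    print(f"{pkg}: not installed")
    continue
  status = "ok" if Version(installed) >= Version(floor) else f"needs >= {floor}"
  print(f"{pkg} {installed}: {status}")
```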