diff --git a/lit_nlp/examples/is_eval/datasets.py b/lit_nlp/examples/is_eval/datasets.py
deleted file mode 100644
index 25c3d7ed..00000000
--- a/lit_nlp/examples/is_eval/datasets.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Text classification dataset for binary, single input data."""
-from lit_nlp.api import dataset as lit_dataset
-from lit_nlp.api import types as lit_types
-import pandas as pd
-
-
-class SingleInputClassificationFromTSV(lit_dataset.Dataset):
-  """TSV data loader for files having a single input text and a label.
-
-  Files must be in TSV format with 2 columns in this order:
-  1. Input text.
-  2. Numeric label.
-
-  Exported examples have 2 output keys: "sentence" and "label".
-  """
-
-  LABELS = ["0", "1"]
-
-  def __init__(self, path: str, name: str = ""):
-    """Initializes a dataset for the Input Salience Eval demo.
-
-    Args:
-      path: The path from which examples will be loaded.
-      name: Optionally, the name of the dataset. Used by ISEvalModel to
-        determine if the model is intended to be compatible with this dataset.
-    """
-    self._examples = self.load_datapoints(path)
-    self.name = name
-
-  def load_datapoints(self, path: str):
-    with open(path) as fd:
-      df = pd.read_csv(fd, sep="\t", header=None, names=["sentence", "label"])
-      # pylint: disable=g-complex-comprehension
-      return [{
-          "sentence": row["sentence"],
-          "label": self.LABELS[row["label"]],
-      } for _, row in df.iterrows()]
-      # pylint: enable=g-complex-comprehension
-
-  def spec(self) -> lit_types.Spec:
-    return {
-        "sentence": lit_types.TextSegment(),
-        "label": lit_types.CategoryLabel(vocab=self.LABELS),
-    }
diff --git a/lit_nlp/examples/is_eval/is_eval_demo.py b/lit_nlp/examples/is_eval/is_eval_demo.py
deleted file mode 100644
index 0621adef..00000000
--- a/lit_nlp/examples/is_eval/is_eval_demo.py
+++ /dev/null
@@ -1,157 +0,0 @@
-r"""Example demo loading a handful of IS eval models.
-
-To run:
-  blaze run -c opt --config=cuda examples/is_eval:is_eval_demo -- \
-    --port=5432
-"""
-import sys
-
-from absl import app
-from absl import flags
-from absl import logging
-
-from lit_nlp import dev_server
-from lit_nlp import server_flags
-from lit_nlp.api import layout
-from lit_nlp.examples.is_eval import datasets
-from lit_nlp.examples.is_eval import models as is_eval_models
-from lit_nlp.lib import file_cache
-
-# NOTE: additional flags defined in server_flags.py
-
-FLAGS = flags.FLAGS
-
-FLAGS.set_default("development_demo", True)
-FLAGS.set_default("page_title", "Input Salience Evaluation Demo")
-
-_DOC_STRING = (
-    "# Input Salience Evaluation Demo\nThis demo accompanies our "
-    "[paper](https://arxiv.org/abs/2211.05485) and "
-    "[blogpost](https://ai.googleblog.com/2022/12/will-you-find-these-shortcuts.html)"
-    " \"Will you find these shortcuts?\". We manually inserted one out of "
-    "three artificial data artifacts (shortcuts) into two datasets (SST2, "
-    "Toxicity). In the \"Explanations\" tab you can observe how different "
-    "input salience methods put different weights on the nonsense tokens "
-    "*zeroa*, *onea*, *synt*.")
-
-_MODELS = flags.DEFINE_list(
-    "models",
-    [
-        "sst2_single_token:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_single_token_bert.tar.gz",
-        "sst2_token_in_context:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_token_in_context_bert.tar.gz",
-        "sst2_ordered_pair:https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_simple_order_bert.tar.gz",
-        "toxicity_single_token:https://storage.googleapis.com/what-if-tool-resources/lit-models/toxicity_single_token_bert.tar.gz",
-        "toxicity_token_in_context:https://storage.googleapis.com/what-if-tool-resources/lit-models/toxicity_token_in_context_bert.tar.gz",
-        "toxicity_ordered_pair:https://storage.googleapis.com/what-if-tool-resources/lit-models/toxicity_simple_order_bert.tar.gz",
-    ],
-    "List of models to load, as <name>:<path>. "
-    "Path should be the output of saving a transformer model, e.g. "
-    "model.save_pretrained(path) and tokenizer.save_pretrained(path). Remote "
-    ".tar.gz files will be downloaded and cached locally.",
-)
-
-_MAX_EXAMPLES = flags.DEFINE_integer(
-    "max_examples", None, "Maximum number of examples to load into LIT. Set "
-    "--max_examples=200 for a quick start.")
-
-DATASETS = {
-    "sst2_single_token_dev_100_syn": "https://storage.googleapis.com/what-if-tool-resources/lit-data/sst2_single_token-dev.100syn.tsv",
-    "sst2_token_in_context_dev_100_syn": "https://storage.googleapis.com/what-if-tool-resources/lit-data/sst2_token_in_context-dev.100syn.tsv",
-    "sst2_ordered_pair_dev_100_syn": "https://storage.googleapis.com/what-if-tool-resources/lit-data/sst2_simple_order-dev.100syn.tsv",
-    "toxicity_single_token_dev_100_syn": "https://storage.googleapis.com/what-if-tool-resources/lit-data/toxicity_single_token-dev.100syn.tsv",
-    "toxicity_token_in_context_dev_100_syn": "https://storage.googleapis.com/what-if-tool-resources/lit-data/toxicity_token_in_context-dev.100syn.tsv",
-    "toxicity_ordered_pair_dev_100_syn": "https://storage.googleapis.com/what-if-tool-resources/lit-data/toxicity_simple_order-dev.100syn.tsv",
-}
-
-modules = layout.LitModuleName
-IS_EVAL_LAYOUT = layout.LitCanonicalLayout(
-    upper={
-        "Main": [
-            modules.DocumentationModule,
-            modules.EmbeddingsModule,
-            modules.DataTableModule,
-            modules.DatapointEditorModule,
-        ]
-    },
-    lower={
-        "Predictions": [
-            modules.ClassificationModule,
-            modules.SalienceMapModule,
-            modules.ScalarModule,
-        ],
-        "Salience Clustering": [modules.SalienceClusteringModule],
-        "Metrics": [
-            modules.MetricsModule,
-            modules.ConfusionMatrixModule,
-            modules.CurvesModule,
-            modules.ThresholderModule,
-        ],
-        "Counterfactuals": [
-            modules.GeneratorModule,
-        ],
-    },
-    description="Custom layout for evaluating input salience methods.")
-CUSTOM_LAYOUTS = layout.DEFAULT_LAYOUTS | {"is_eval": IS_EVAL_LAYOUT}
-# You can change this back via URL param, e.g. localhost:5432/?layout=default
-FLAGS.set_default("default_layout", "is_eval")
-
-
-def get_wsgi_app():
-  """Return WSGI app for container-hosted demos."""
-  FLAGS.set_default("server_type", "external")
-  FLAGS.set_default("demo_mode", True)
-  FLAGS.set_default("warm_start", 1.0)
-  FLAGS.set_default("max_examples", 1000)
-  # Parse flags without calling app.run(main), to avoid conflict with
-  # gunicorn command line flags.
-  unused = flags.FLAGS(sys.argv, known_only=True)
-  if unused:
-    logging.info("is_eval_demo:get_wsgi_app() called with unused args: %s",
-                 unused)
-  return main([])
-
-
-def main(_):
-  models = {}
-  loaded_datasets = {}
-
-  for model_string in _MODELS.value:
-    # Only split on the first ':', because path may be a URL
-    # containing 'https://'
-    name, path = model_string.split(":", 1)
-    logging.info("Loading model '%s' from '%s'", name, path)
-    # Normally path is a directory; if it's an archive file, download and
-    # extract to the transformers cache.
-    if path.endswith(".tar.gz"):
-      path = file_cache.cached_path(
-          path, extract_compressed_file=True)
-    # Load the model from disk.
-    models[name] = is_eval_models.ISEvalModel(
-        name, path, output_attention=False)
-
-  logging.info("Loading data for SST-2 task.")
-  for data_key, url in DATASETS.items():
-    path = file_cache.cached_path(url)
-    loaded_datasets[data_key] = datasets.SingleInputClassificationFromTSV(
-        path, data_key)
-
-  # Truncate datasets if --max_examples is set.
-  for name in loaded_datasets:
-    logging.info("Dataset: '%s' with %d examples", name,
-                 len(loaded_datasets[name]))
-    loaded_datasets[name] = loaded_datasets[name].shuffle().slice[:_MAX_EXAMPLES
                                                                   .value]
-    logging.info("  truncated to %d examples", len(loaded_datasets[name]))
-
-  # Start the LIT server. See server_flags.py for server options.
-  lit_demo = dev_server.Server(
-      models,
-      loaded_datasets,
-      layouts=CUSTOM_LAYOUTS,
-      onboard_end_doc=_DOC_STRING,
-      **server_flags.get_flags())
-  return lit_demo.serve()
-
-
-if __name__ == "__main__":
-  app.run(main)
diff --git a/lit_nlp/examples/is_eval/is_eval_trainer.py b/lit_nlp/examples/is_eval/is_eval_trainer.py
deleted file mode 100644
index a7a53da5..00000000
--- a/lit_nlp/examples/is_eval/is_eval_trainer.py
+++ /dev/null
@@ -1,138 +0,0 @@
-r"""Lightweight trainer script to fine-tune a model for IS eval.
-
-Usage:
-  python -m lit_nlp.examples.is_eval.is_eval_trainer \
-    --encoder_name=bert-base-uncased \
-    --train_path=/path/to/saved/model \
-    --train_data_path=/path/to/train/data \
-    --dev_data_path=/path/to/dev/data
-
-This will finetune a BERT model to reproduce findings of the paper ""Will You
-Find These Shortcuts?" A Protocol for Evaluating the Faithfulness of Input
-Salience Methods for Text Classification" [https://arxiv.org/abs/2111.07367].
-
-Please ensure that the model's vocabulary file includes all special shortcut
-tokens. When using the provided datasets of the LIT demo these are:
-"ZEROA", "ZEROB", "ONEA", "ONEB", "onea", "oneb", "zeroa", "zerob", "synt".
-
-This will train a BERT-base model [https://arxiv.org/abs/1810.04805]
-which gives validation accuracy in the low 90s on SST-2.
-
-Note: you don't have to use this trainer to use LIT; the classifier
-implementation is just a wrapper around HuggingFace Transformers, using
-AutoTokenizer, AutoConfig, and TFAutoModelForSequenceClassification, and can
-load anything compatible with those classes.
-""" - -from collections.abc import Sequence -import os - -from absl import app -from absl import flags -from absl import logging - -from lit_nlp.examples.is_eval import datasets -from lit_nlp.examples.models import glue_models -from lit_nlp.lib import serialize -import tensorflow as tf - -_ENCODER_NAME = flags.DEFINE_string( - "encoder_name", "bert-base-uncased", - "Model name or path to pretrained (base) encoder.") -_TRAIN_DATA_PATH = flags.DEFINE_string("train_data_path", None, "") -_DEV_DATA_PATH = flags.DEFINE_string("dev_data_path", None, "") -_TRAIN_PATH = flags.DEFINE_string("train_path", "/tmp/hf_demo", - "Path to save fine-tuned model.") - -_NUM_EPOCHS = flags.DEFINE_integer( - "num_epochs", 3, "Number of epochs to train for.", lower_bound=1) -_SAVE_INTERMEDIATES = flags.DEFINE_bool( - "save_intermediates", False, - "If true, save intermediate weights after each epoch.") - - -def history_to_dict(keras_history): - return { - "epochs": keras_history.epoch, - "history": keras_history.history, - "params": keras_history.params, - "optimizer_params": keras_history.model.optimizer.get_config(), - } - - -class EpochSaverCallback(tf.keras.callbacks.Callback): - """Save model at the beginning of training and after every epoch. - - Similar to tf.keras.callbacks.ModelCheckpoint, but this allows us to specify - a custom save fn to call, such as the HuggingFace model.save() which writes - .h5 files and config information. - """ - - def __init__(self, save_path_base: str, save_fn=None): - super().__init__() - self.save_path_base = save_path_base - self.save_fn = save_fn or self.model.save - - def on_train_begin(self, logs=None): - self.on_epoch_end(-1, logs=logs) # write epoch-0 - - def on_epoch_end(self, epoch, logs=None): - # Save path 1-indexed = # of completed epochs. - save_path = os.path.join(self.save_path_base, f"epoch-{epoch+1}") - self.save_fn(save_path) - - -def train_and_save(model, - train_data, - val_data, - train_path, - save_intermediates=False, - **train_kw): - """Run training and save model.""" - # Set up logging for TensorBoard. To view, run: - # tensorboard --log_dir=/tensorboard - keras_callbacks = [ - tf.keras.callbacks.TensorBoard( - log_dir=os.path.join(train_path, "tensorboard")) - ] - if save_intermediates: - keras_callbacks.append(EpochSaverCallback(train_path, save_fn=model.save)) - history = model.train( - train_data.examples, - validation_inputs=val_data.examples, - keras_callbacks=keras_callbacks, - **train_kw) - - # Save training history too, since this is human-readable and more concise - # than the TensorBoard log files. - with open(os.path.join(train_path, "train.history.json"), "w") as fd: - # Use LIT's custom JSON encoder to handle dicts containing NumPy data. 
-    fd.write(serialize.to_json(history_to_dict(history), simple=True, indent=2))
-
-  model.save(train_path)
-  logging.info("Saved model files: \n  %s",
-               "\n  ".join(os.listdir(train_path)))
-
-
-def main(argv: Sequence[str]) -> None:
-  if len(argv) > 1:
-    raise app.UsageError("Too many command-line arguments.")
-
-  model = glue_models.SST2Model(_ENCODER_NAME.value)
-  train_data = datasets.SingleInputClassificationFromTSV(_TRAIN_DATA_PATH.value)
-  dev_data = datasets.SingleInputClassificationFromTSV(_DEV_DATA_PATH.value)
-
-  train_and_save(
-      model,
-      train_data,
-      dev_data,
-      _TRAIN_PATH.value,
-      save_intermediates=_SAVE_INTERMEDIATES.value,
-      num_epochs=_NUM_EPOCHS.value,
-      learning_rate=1e-5,
-      batch_size=16,
-  )
-
-
-if __name__ == "__main__":
-  app.run(main)
diff --git a/lit_nlp/examples/is_eval/models.py b/lit_nlp/examples/is_eval/models.py
deleted file mode 100644
index d6765df7..00000000
--- a/lit_nlp/examples/is_eval/models.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Custom GLUE Model and ModelSpec for the Input Salience Evaluation demo."""
-from typing import cast
-from lit_nlp.api import dataset as lit_dataset
-from lit_nlp.examples.is_eval import datasets as is_eval_datasets
-from lit_nlp.examples.models import glue_models
-
-
-class ISEvalModel(glue_models.SST2Model):
-  """Custom GLUE model for the Input Salience Evaluation demo."""
-
-  def __init__(self, model_name: str, *args, **kw):
-    """Initializes a custom SST-2 model for the Input Salience Eval demo.
-
-    Args:
-      model_name: The model's name. Used to determine dataset compatibility.
-      *args: Additional positional args to pass to the SST2Model base class.
-      **kw: Additional keyword args to pass to the SST2Model base class.
-    """
-    super().__init__(*args, **kw)
-    self._model_name = model_name
-
-  def is_compatible_with_dataset(self, dataset: lit_dataset.Dataset) -> bool:
-    """Returns true if the model is compatible with the dataset.
-
-    The Input Salience Eval demo is somewhat unique in that each model and
-    dataset have compatible specs but the intention is to pair them for
-    specific tasks.
-
-    This class determines compatibility by:
-
-    1. Ensuring that the value of `model_name` is contained in the `default`
-       value of the `dataset_name` field in the provided `dataset_spec`.
-    2. Calling super().is_compatible_with_dataset() to check compatibility
-       using the base ModelSpec check.
-
-    Args:
-      dataset: The dataset for which compatibility will be determined.
-    """
-    if not isinstance(dataset,
-                      is_eval_datasets.SingleInputClassificationFromTSV):
-      return False
-
-    eval_dataset = cast(is_eval_datasets.SingleInputClassificationFromTSV,
-                        dataset)
-    if self._model_name in eval_dataset.name:
-      return super().is_compatible_with_dataset(dataset)
-    else:
-      return False
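
Note for anyone who depended on the deleted dataset loader: its core behavior is easy to reproduce without LIT. Below is a minimal standalone sketch of SingleInputClassificationFromTSV.load_datapoints(); the headerless two-column TSV layout (input text, then a numeric 0/1 label) comes from the deleted docstring, while the file name in the usage line is a hypothetical local copy of one of the datasets listed above.

    # Minimal sketch of the deleted loader's logic (assumes pandas is installed).
    import pandas as pd

    LABELS = ["0", "1"]  # numeric labels in the TSV are mapped to these strings

    def load_datapoints(path: str):
      # Headerless TSV: column 0 is the sentence, column 1 is a 0/1 label.
      df = pd.read_csv(path, sep="\t", header=None, names=["sentence", "label"])
      return [{"sentence": row["sentence"], "label": LABELS[row["label"]]}
              for _, row in df.iterrows()]

    # Hypothetical usage with a locally downloaded dataset file:
    examples = load_datapoints("sst2_single_token-dev.100syn.tsv")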