diff --git a/lit_nlp/examples/models/t5.py b/lit_nlp/examples/models/t5.py
index b3a22cea..d21d0d05 100644
--- a/lit_nlp/examples/models/t5.py
+++ b/lit_nlp/examples/models/t5.py
@@ -46,6 +46,17 @@ class T5ModelConfig(object):
   token_top_k: int = 10
   output_attention: bool = False
 
+  @classmethod
+  def init_spec(cls) -> lit_types.Spec:
+    return {
+        "inference_batch_size": lit_types.Integer(default=4, required=False),
+        "beam_size": lit_types.Integer(default=4, required=False),
+        "max_gen_length": lit_types.Integer(default=50, required=False),
+        "num_to_generate": lit_types.Integer(default=1, required=False),
+        "token_top_k": lit_types.Integer(default=10, required=False),
+        "output_attention": lit_types.Boolean(default=False, required=False),
+    }
+
 
 def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
   """Validate that a given model looks like a T5 model.
@@ -112,6 +123,13 @@ def predict_minibatch(self, inputs):
         "output_text": m.decode("utf-8")
     } for m in model_outputs["outputs"].numpy()]
 
+  @classmethod
+  def init_spec(cls) -> lit_types.Spec:
+    return {
+        "saved_model_path": lit_types.String(),
+        **T5ModelConfig.init_spec(),
+    }
+
   def input_spec(self):
     return {
         "input_text": lit_types.TextSegment(),
@@ -314,6 +332,13 @@ def predict_minibatch(self, inputs):
     unbatched_outputs = utils.unbatch_preds(detached_outputs)
     return list(map(self._postprocess, unbatched_outputs))
 
+  @classmethod
+  def init_spec(cls) -> lit_types.Spec:
+    return {
+        "model_name": lit_types.String(default="t5-small", required=False),
+        **T5ModelConfig.init_spec(),
+    }
+
   def input_spec(self):
     return {
         "input_text": lit_types.TextSegment(),
diff --git a/lit_nlp/examples/t5_demo.py b/lit_nlp/examples/t5_demo.py
index 17c546b2..982d7306 100644
--- a/lit_nlp/examples/t5_demo.py
+++ b/lit_nlp/examples/t5_demo.py
@@ -21,6 +21,7 @@
 from absl import flags
 from absl import logging
 
+from lit_nlp import app as lit_app
 from lit_nlp import dev_server
 from lit_nlp import server_flags
 from lit_nlp.api import dataset as lit_dataset
@@ -148,7 +149,13 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
         model_name=model_name_or_path,
         num_to_generate=_NUM_TO_GEN.value,
         token_top_k=_TOKEN_TOP_K.value,
-        output_attention=False)
+        output_attention=False,
+    )
+
+  model_loaders: lit_app.ModelLoadersMap = {
+      "T5 Saved Model": (t5.T5SavedModel, t5.T5SavedModel.init_spec()),
+      "T5 HF Model": (t5.T5HFModel, t5.T5HFModel.init_spec()),
+  }
 
   ##
   # Load eval sets and model wrappers for each task.
@@ -156,7 +163,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
   # and post-processing code.
   models = {}
   datasets = {}
-
+  dataset_loaders: lit_app.DatasetLoadersMap = {}
   if "summarization" in _TASKS.value:
     for k, m in base_models.items():
       models[k + "_summarization"] = t5.SummarizationWrapper(m)
@@ -167,6 +174,11 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
     datasets["CNNDM"] = summarization.CNNDMData(
         split="validation", max_examples=_MAX_EXAMPLES.value)
 
+    dataset_loaders["CNN DailyMail"] = (
+        summarization.CNNDMData,
+        summarization.CNNDMData.init_spec(),
+    )
+
   if "mt" in _TASKS.value:
     for k, m in base_models.items():
       models[k + "_translation"] = t5.TranslationWrapper(m)
@@ -183,6 +195,8 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
     else:
       datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)
 
+    dataset_loaders["WMT 14"] = (mt.WMT14Data, mt.WMT14Data.init_spec())
+
   # Truncate datasets if --max_examples is set.
   for name in datasets:
     logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
@@ -207,7 +221,13 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
   # Actually start the LIT server, using the models, datasets, and other
   # components constructed above.
   lit_demo = dev_server.Server(
-      models, datasets, generators=generators, **server_flags.get_flags())
+      models,
+      datasets,
+      generators=generators,
+      model_loaders=model_loaders,
+      dataset_loaders=dataset_loaders,
+      **server_flags.get_flags(),
+  )
   return lit_demo.serve()
 
 