Skip to content

Commit

Permalink
Adds init spec support to T5 demo
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 506287801
  • Loading branch information
RyanMullins authored and LIT team committed Feb 1, 2023
1 parent 9eebe57 commit bcc6c09
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
25 changes: 25 additions & 0 deletions lit_nlp/examples/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ class T5ModelConfig(object):
token_top_k: int = 10
output_attention: bool = False

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Returns the spec of optional constructor arguments for this config.

  All fields are optional; the defaults here mirror the dataclass field
  defaults so a caller may override any subset of them.
  """
  # Table of integer-valued options and their defaults; building the spec
  # from it keeps each entry to one line and the defaults easy to scan.
  integer_defaults = {
      "inference_batch_size": 4,
      "beam_size": 4,
      "max_gen_length": 50,
      "num_to_generate": 1,
      "token_top_k": 10,
  }
  spec: lit_types.Spec = {
      name: lit_types.Integer(default=value, required=False)
      for name, value in integer_defaults.items()
  }
  spec["output_attention"] = lit_types.Boolean(default=False, required=False)
  return spec


def validate_t5_model(model: lit_model.Model) -> lit_model.Model:
"""Validate that a given model looks like a T5 model.
Expand Down Expand Up @@ -112,6 +123,13 @@ def predict_minibatch(self, inputs):
"output_text": m.decode("utf-8")
} for m in model_outputs["outputs"].numpy()]

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Returns the constructor-argument spec for loading a saved model.

  The model path is required; all shared T5 config options (batch size,
  beam size, etc.) come from T5ModelConfig and remain optional.
  """
  spec: lit_types.Spec = {"saved_model_path": lit_types.String()}
  spec.update(T5ModelConfig.init_spec())
  return spec

def input_spec(self):
return {
"input_text": lit_types.TextSegment(),
Expand Down Expand Up @@ -314,6 +332,13 @@ def predict_minibatch(self, inputs):
unbatched_outputs = utils.unbatch_preds(detached_outputs)
return list(map(self._postprocess, unbatched_outputs))

@classmethod
def init_spec(cls) -> lit_types.Spec:
  """Returns the constructor-argument spec for loading a HuggingFace model.

  The model name is optional and defaults to "t5-small"; all shared T5
  config options come from T5ModelConfig and remain optional as well.
  """
  spec: lit_types.Spec = {
      "model_name": lit_types.String(default="t5-small", required=False)
  }
  spec.update(T5ModelConfig.init_spec())
  return spec

def input_spec(self):
return {
"input_text": lit_types.TextSegment(),
Expand Down
26 changes: 23 additions & 3 deletions lit_nlp/examples/t5_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from absl import flags
from absl import logging

from lit_nlp import app as lit_app
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import dataset as lit_dataset
Expand Down Expand Up @@ -148,15 +149,21 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
model_name=model_name_or_path,
num_to_generate=_NUM_TO_GEN.value,
token_top_k=_TOKEN_TOP_K.value,
output_attention=False)
output_attention=False,
)

model_loaders: lit_app.ModelLoadersMap = {
"T5 Saved Model": (t5.T5SavedModel, t5.T5SavedModel.init_spec()),
"T5 HF Model": (t5.T5HFModel, t5.T5HFModel.init_spec()),
}

##
# Load eval sets and model wrappers for each task.
# Model wrappers share the same in-memory T5 model, but add task-specific pre-
# and post-processing code.
models = {}
datasets = {}

dataset_loaders: lit_app.DatasetLoadersMap = {}
if "summarization" in _TASKS.value:
for k, m in base_models.items():
models[k + "_summarization"] = t5.SummarizationWrapper(m)
Expand All @@ -167,6 +174,11 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
datasets["CNNDM"] = summarization.CNNDMData(
split="validation", max_examples=_MAX_EXAMPLES.value)

dataset_loaders["CNN DailyMail"] = (
summarization.CNNDMData,
summarization.CNNDMData.init_spec(),
)

if "mt" in _TASKS.value:
for k, m in base_models.items():
models[k + "_translation"] = t5.TranslationWrapper(m)
Expand All @@ -183,6 +195,8 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
else:
datasets["wmt14_enfr"] = mt.WMT14Data(version="fr-en", reverse=True)

dataset_loaders["WMT 14"] = (mt.WMT14Data, mt.WMT14Data.init_spec())

# Truncate datasets if --max_examples is set.
for name in datasets:
logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
Expand All @@ -207,7 +221,13 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
# Actually start the LIT server, using the models, datasets, and other
# components constructed above.
lit_demo = dev_server.Server(
models, datasets, generators=generators, **server_flags.get_flags())
models,
datasets,
generators=generators,
model_loaders=model_loaders,
dataset_loaders=dataset_loaders,
**server_flags.get_flags(),
)
return lit_demo.serve()


Expand Down

0 comments on commit bcc6c09

Please sign in to comment.