diff --git a/documentation/api.md b/documentation/api.md
index a6acfa7a..3dec3757 100644
--- a/documentation/api.md
+++ b/documentation/api.md
@@ -70,6 +70,26 @@ and [`Model`](#models) classes implement this, and provide metadata (see the
 For pre-built `demo.py` examples, check out
 /~https://github.com/PAIR-code/lit/tree/main/lit_nlp/examples
 
+### Validating Models and Data
+
+Datasets and models can optionally be validated by LIT to ensure that dataset
+examples and model output values match their respective specs.
+This can be very helpful when developing new model and dataset wrappers, to
+ensure correct behavior in LIT.
+
+At LIT server startup, the `validate` runtime flag can be used to enable
+validation.
+Setting the flag to `first` will validate the first example in each dataset
+for correctly typed values and run it through each compatible model to ensure
+that the model outputs are also correctly typed. Setting it to `sample` will
+validate a 5% sample of each dataset. Setting it to `all` will validate all
+examples in all datasets. By default, no validation is performed, so that
+startup stays fast.
+
+Additionally, if using LIT datasets and models outside of the LIT server,
+validation can be called directly through the
+[`validation`](../lit_nlp/lib/validation.py) module.
+
 ## Datasets
 
 Datasets ([`Dataset`](../lit_nlp/api/dataset.py)) are
diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py
index c5b11d24..9521ebd0 100644
--- a/lit_nlp/api/types.py
+++ b/lit_nlp/api/types.py
@@ -25,16 +25,20 @@ should be rendered.
 """
 
 import abc
+import math
+import numbers
 from typing import Any, NewType, Optional, Sequence, TypedDict, Union
 
 import attr
 from lit_nlp.api import dtypes
+import numpy as np
 
 JsonDict = dict[str, Any]
 Input = NewType("Input", JsonDict)
 ExampleId = NewType("ExampleId", str)
 ScoredTextCandidates = Sequence[tuple[str, Optional[float]]]
 TokenTopKPredsList = Sequence[ScoredTextCandidates]
+NumericTypes = numbers.Number
 
 
 class InputMetadata(TypedDict):
@@ -62,6 +66,46 @@ class LitType(metaclass=abc.ABCMeta):
   # TODO(lit-dev): Add defaults for all LitTypes
   default = None  # an optional default value for a given type.
 
+  def validate_input(self, value: Any, spec: "Spec", example: Input):
+    """Validate a field value from a dataset example against its spec.
+
+    Subclasses should override this method to validate the provided value and
+    raise a ValueError if the value is not valid.
+
+    Args:
+      value: The value to validate against the specific LitType.
+      spec: The spec of the dataset.
+      example: The entire example of which the value is a part.
+
+    Raises:
+      ValueError if validation fails.
+    """
+    pass
+
+  def validate_output(self, value: Any, output_spec: "Spec",
+                      output_dict: JsonDict, input_spec: "Spec",
+                      dataset_spec: "Spec", input_example: Input):
+    """Validate a model output value against its spec and input example.
+
+    Subclasses should override this method to validate the provided value and
+    raise a ValueError if the value is not valid.
+
+    Args:
+      value: The value to validate against the specific LitType.
+      output_spec: The output spec of the model.
+      output_dict: The entire model output for the example.
+      input_spec: The input spec of the model.
+      dataset_spec: The dataset spec.
+      input_example: The input example that produced this output value.
+
+    Raises:
+      ValueError if validation fails.
+    """
+    del output_spec, output_dict, dataset_spec
+    # If a subclass does not override this method, validate the value as an
+    # input to reuse the simpler per-field checks.
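+    # For example, Scalar overrides only validate_input, so a Scalar value in
+    # a model's output spec gets the same numeric check applied here.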
+ self.validate_input(value, input_spec, input_example) + def is_compatible(self, other): """Check equality, ignoring some fields.""" # We allow this class to be a subclass of the other. @@ -128,6 +172,10 @@ class StringLitType(LitType): """ default: str = "" + def validate_input(self, value, spec: Spec, example: Input): + if not isinstance(value, str): + raise ValueError(f"{value} is of type {type(value)}, expected str") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class TextSegment(StringLitType): @@ -138,7 +186,10 @@ class TextSegment(StringLitType): @attr.s(auto_attribs=True, frozen=True, kw_only=True) class ImageBytes(LitType): """An image, an encoded base64 ascii string (starts with 'data:image...').""" - pass + + def validate_input(self, value, spec: Spec, example: Input): + if not isinstance(value, str) or not value.startswith("data:image"): + raise ValueError(f"{value} is not an encoded image string.") @attr.s(auto_attribs=True, frozen=True, kw_only=True) @@ -147,6 +198,15 @@ class GeneratedText(TextSegment): # Name of a TextSegment field to evaluate against parent: Optional[str] = None + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if not isinstance(value, str): + raise ValueError(f"{value} is of type {type(value)}, expected str") + if self.parent and not isinstance(input_spec[self.parent], TextSegment): + raise ValueError(f"parent field {self.parent} is of type " + f"{type(self.parent)}, expected TextSegment") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class ListLitType(LitType): @@ -159,6 +219,17 @@ class _StringCandidateList(ListLitType): """A list of (text, score) tuples.""" default: ScoredTextCandidates = None + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if not isinstance(value, list): + raise ValueError(f"{value} is not a list") + + for v in value: + if not (isinstance(v, tuple) and isinstance(v[0], str) and + (v[1] is None or isinstance(v[1], NumericTypes))): + raise ValueError(f"{v} list item is not a (str, float) tuple)") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class GeneratedTextCandidates(_StringCandidateList): @@ -170,6 +241,16 @@ class GeneratedTextCandidates(_StringCandidateList): def top_text(value: ScoredTextCandidates) -> str: return value[0][0] if len(value) else "" + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + if self.parent and not isinstance(input_spec[self.parent], TextSegment): + raise ValueError(f"parent field {self.parent} is of type " + f"{type(input_spec[self.parent])}, expected TextSegment") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class ReferenceTexts(_StringCandidateList): @@ -194,6 +275,14 @@ class GeneratedURL(TextSegment): """A URL that was generated as part of a model prediction.""" align: Optional[str] = None # name of a field in the model output + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output(value, output_spec, output_dict, input_spec, + dataset_spec, input_example) + if self.align and self.align not in output_spec: + raise ValueError(f"aligned field {self.align} is not in 
output_spec") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class SearchQuery(TextSegment): @@ -206,6 +295,11 @@ class _StringList(ListLitType): """A list of strings.""" default: Sequence[str] = [] + def validate_input(self, value, spec: Spec, example: Input): + if not isinstance(value, list) or not all( + [isinstance(v, str) for v in value]): + raise ValueError(f"{value} is not a list of strings") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class Tokens(_StringList): @@ -229,6 +323,36 @@ class TokenTopKPreds(ListLitType): align: str = None # name of a Tokens field in the model output parent: Optional[str] = None + def _validate_scored_candidates(self, scored_candidates): + """Validates a list of scored candidates.""" + prev_val = math.inf + for scored_candidate in scored_candidates: + if not isinstance(scored_candidate, tuple): + raise ValueError(f"{scored_candidate} is not a tuple") + if not isinstance(scored_candidate[0], str): + raise ValueError(f"{scored_candidate} first element is not a str") + if scored_candidate[1] is not None: + if not isinstance(scored_candidate[1], NumericTypes): + raise ValueError(f"{scored_candidate} second element is not a num") + if prev_val < scored_candidate[1]: + raise ValueError( + "TokenTopKPreds candidates are not in descending order") + else: + prev_val = scored_candidate[1] + + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + + if not isinstance(value, list): + raise ValueError(f"{value} is not a list of scored text candidates") + for scored_candidates in value: + self._validate_scored_candidates(scored_candidates) + if self.align and not isinstance(output_spec[self.align], Tokens): + raise ValueError( + f"aligned field {self.align} is {type(output_spec[self.align])}, " + "expected Tokens") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class Scalar(LitType): @@ -238,6 +362,10 @@ class Scalar(LitType): default: float = 0 step: float = .01 + def validate_input(self, value, spec: Spec, example: Input): + if not isinstance(value, NumericTypes): + raise ValueError(f"{value} is of type {type(value)}, expected a number") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class RegressionScore(Scalar): @@ -245,6 +373,15 @@ class RegressionScore(Scalar): # name of a Scalar or RegressionScore field in input parent: Optional[str] = None + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if not isinstance(value, NumericTypes): + raise ValueError(f"{value} is of type {type(value)}, expected a number") + if self.parent and not isinstance(dataset_spec[self.parent], Scalar): + raise ValueError(f"parent field {self.parent} is of type " + f"{type(self.parent)}, expected Scalar") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class ReferenceScores(ListLitType): @@ -254,6 +391,23 @@ class ReferenceScores(ListLitType): # name of a TextSegment or ReferenceTexts field in the input parent: Optional[str] = None + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if isinstance(value, list): + if not all([isinstance(v, NumericTypes) for v in value]): + raise ValueError(f"{value} is of type {type(value)}, expected a list " + "of numbers") + elif not isinstance(value, np.ndarray) or not np.issubdtype( + value.dtype, np.number): + raise ValueError(f"{value} is 
of type {type(value)}, expected a list of " + "numbers") + if self.parent and not isinstance( + input_spec[self.parent], (TextSegment, ReferenceTexts)): + raise ValueError(f"parent field {self.parent} is of type " + f"{type(self.parent)}, expected TextSegment or " + "ReferenceTexts") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class CategoryLabel(StringLitType): @@ -262,15 +416,52 @@ class CategoryLabel(StringLitType): # If omitted, any value is accepted. vocab: Optional[Sequence[str]] = None # label names + def validate_input(self, value, spec: Spec, example: Input): + if not isinstance(value, str): + raise ValueError(f"{value} is of type {type(value)}, expected str") + if self.vocab and value not in list(self.vocab): + raise ValueError(f"{value} is not in provided vocab") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class _Tensor1D(LitType): +class _Tensor(LitType): """A tensor type.""" default: Sequence[float] = None - -@attr.s(auto_attribs=True, frozen=True, kw_only=True) -class MulticlassPreds(_Tensor1D): + def validate_input(self, value, spec: Spec, example: Input): + if isinstance(value, list): + if not all([isinstance(v, NumericTypes) for v in value]): + raise ValueError(f"{value} is not a list of numbers") + elif isinstance(value, np.ndarray): + if not np.issubdtype(value.dtype, np.number): + raise ValueError(f"{value} is not an array of numbers") + else: + raise ValueError(f"{value} is not a list or ndarray of numbers") + + def validate_ndim(self, value, ndim: Union[int, list[int]]): + """Validate the number of dimensions in a tensor. + + Args: + value: The tensor to validate. + ndim: Either a number of dimensions to validate that the value has, or + a list of dimensions any of which are valid for the value to have. + + Raises: + ValueError if validation fails. + """ + if isinstance(ndim, int): + ndim = [ndim] + if isinstance(value, np.ndarray): + if value.ndim not in ndim: + raise ValueError(f"{value} ndim is not one of {ndim}") + else: + if 1 not in ndim: + raise ValueError(f"{value} ndim is not 1. " + "Use a numpy array for multidimensional arrays") + + +@attr.s(auto_attribs=True, frozen=True, kw_only=True) +class MulticlassPreds(_Tensor): """Multiclass predicted probabilities, as [num_labels].""" # Vocabulary is required here for decoding model output. # Usually this will match the vocabulary in the corresponding label field. 
@@ -284,6 +475,21 @@ class MulticlassPreds(_Tensor1D): def num_labels(self): return len(self.vocab) + def validate_input(self, value, spec: Spec, example: Input): + super().validate_input(value, spec, example) + if self.null_idx is not None: + if self.null_idx < 0 or self.null_idx >= self.num_labels: + raise ValueError(f"null_idx {self.null_idx} is not in the vocab range") + + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + self.validate_input(value, output_spec, input_example) + if self.parent and not isinstance( + dataset_spec[self.parent], CategoryLabel): + raise ValueError(f"parent field {self.parent} is of type " + f"{type(self.parent)}, expected CategoryLabel") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class SequenceTags(_StringList): @@ -305,6 +511,15 @@ class SpanLabels(ListLitType): align: str # name of Tokens field parent: Optional[str] = None + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if not isinstance(value, list) or not all( + [isinstance(v, dtypes.SpanLabel) for v in value]): + raise ValueError(f"{value} is not a list of SpanLabels") + if not isinstance(output_spec[self.align], Tokens): + raise ValueError(f"{self.align} is not a Tokens field") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class EdgeLabels(ListLitType): @@ -319,6 +534,15 @@ class EdgeLabels(ListLitType): default: Sequence[dtypes.EdgeLabel] = None align: str # name of Tokens field + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if not isinstance(value, list) or not all( + [isinstance(v, dtypes.EdgeLabel) for v in value]): + raise ValueError(f"{value} is not a list of EdgeLabel") + if not isinstance(output_spec[self.align], Tokens): + raise ValueError(f"{self.align} is not a Tokens field") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class MultiSegmentAnnotations(ListLitType): @@ -338,19 +562,31 @@ class MultiSegmentAnnotations(ListLitType): exclusive: bool = False # if true, treat as candidate list background: bool = False # if true, don't emphasize in visualization + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + if not isinstance(value, list) or not all( + [isinstance(v, dtypes.AnnotationCluster) for v in value]): + raise ValueError(f"{value} is not a list of AnnotationCluster") ## # Model internals, for interpretation. @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class Embeddings(_Tensor1D): +class Embeddings(_Tensor): """Embeddings or model activations, as fixed-length [emb_dim].""" - pass + + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output(value, output_spec, output_dict, input_spec, + dataset_spec, input_example) + self.validate_ndim(value, 1) @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class _GradientsBase(_Tensor1D): +class _GradientsBase(_Tensor): """Shared gradient attributes.""" align: Optional[str] = None # name of a Tokens field grad_for: Optional[str] = None # name of Embeddings field @@ -358,44 +594,116 @@ class _GradientsBase(_Tensor1D): # for the gradients. 
grad_target_field_key: Optional[str] = None + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + if self.align is not None: + align_entry = (output_spec[self.align] if self.align in output_spec + else input_spec[self.align]) + if not isinstance(align_entry, (Tokens, ImageBytes)): + raise ValueError(f"{self.align} is not a Tokens or ImageBytes field") + if self.grad_for is not None and not isinstance( + output_spec[self.grad_for], (Embeddings, TokenEmbeddings)): + raise ValueError(f"{self.grad_for} is not a Embeddings field") + if (self.grad_target_field_key is not None and + self.grad_target_field_key not in input_spec): + raise ValueError(f"{self.grad_target_field_key} is not in input_spec") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class Gradients(_GradientsBase): """1D gradients with respect to embeddings.""" - pass + + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + self.validate_ndim(value, 1) @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class _InfluenceEncodings(_Tensor1D): +class _InfluenceEncodings(_Tensor): """A single vector of [enc_dim].""" grad_target: Optional[str] = None # class for computing gradients (string) + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + self.validate_ndim(value, 1) + @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class TokenEmbeddings(_Tensor1D): +class TokenEmbeddings(_Tensor): """Per-token embeddings, as [num_tokens, emb_dim].""" align: Optional[str] = None # name of a Tokens field + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + self.validate_ndim(value, 2) + if self.align is not None and not isinstance( + output_spec[self.align], Tokens): + raise ValueError(f"{self.align} is not a Tokens field") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class TokenGradients(_GradientsBase): """Gradients with respect to per-token inputs, as [num_tokens, emb_dim].""" - pass + + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + self.validate_ndim(value, 2) @attr.s(auto_attribs=True, frozen=True, kw_only=True) class ImageGradients(_GradientsBase): """Gradients with respect to per-pixel inputs, as a multidimensional array.""" - pass + + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + self.validate_ndim(value, [2, 3]) @attr.s(auto_attribs=True, frozen=True, kw_only=True) -class AttentionHeads(_Tensor1D): +class AttentionHeads(_Tensor): """One or more 
attention heads, as [num_heads, num_tokens, num_tokens].""" # input and output Tokens fields; for self-attention these can be the same align_in: str align_out: str + def validate_output(self, value, output_spec: Spec, output_dict: JsonDict, + input_spec: Spec, dataset_spec: Spec, + input_example: Input): + super().validate_output( + value, output_spec, output_dict, input_spec, dataset_spec, + input_example) + self.validate_ndim(value, 3) + if self.align_in is None or not isinstance( + output_spec[self.align_in], Tokens): + raise ValueError(f"{self.align_in} is not a Tokens field") + if self.align_out is None or not isinstance( + output_spec[self.align_out], Tokens): + raise ValueError(f"{self.align_out} is not a Tokens field") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class SubwordOffsets(ListLitType): @@ -498,6 +806,10 @@ class BooleanLitType(LitType): """Boolean value.""" default: bool = False + def validate_input(self, value, spec, example: Input): + if not isinstance(value, bool): + raise ValueError(f"{value} is not a boolean") + @attr.s(auto_attribs=True, frozen=True, kw_only=True) class CurveDataPoints(LitType): diff --git a/lit_nlp/api/types_test.py b/lit_nlp/api/types_test.py index 63995579..f387bd3d 100644 --- a/lit_nlp/api/types_test.py +++ b/lit_nlp/api/types_test.py @@ -1,7 +1,9 @@ """Tests for types.""" from absl.testing import absltest +from lit_nlp.api import dtypes from lit_nlp.api import types +import numpy as np class TypesTest(absltest.TestCase): @@ -28,6 +30,393 @@ def test_inherit_parent_custom_properties(self): self.assertTrue(hasattr(lit_type, "align")) self.assertFalse(hasattr(lit_type, "not_a_property")) + def test_tensor_ndim(self): + emb = types.Embeddings() + try: + emb.validate_ndim([1, 2, 3], 1) + emb.validate_ndim(np.array([1, 2, 3]), 1) + emb.validate_ndim(np.array([[1, 1], [2, 3]]), 2) + emb.validate_ndim(np.array([[1, 1], [2, 3]]), [2, 4]) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises(ValueError, emb.validate_ndim, [1, 2, 3], 2) + self.assertRaises(ValueError, emb.validate_ndim, np.array([[1, 1], [2, 3]]), + [1]) + + def test_type_validate_input(self): + spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + example = {} + scalar = types.Scalar() + text = types.TextSegment() + img = types.ImageBytes() + tok = types.Tokens() + emb = types.Embeddings() + bl = types.Boolean() + try: + scalar.validate_input(3.4, spec, example) + scalar.validate_input(3, spec, example) + scalar.validate_input(np.int64(2), spec, example) + text.validate_input("hi", spec, example) + img.validate_input("data:image/blah...", spec, example) + tok.validate_input(["a", "b"], spec, example) + emb.validate_input([1, 2], spec, example) + emb.validate_input(np.array([1, 2]), spec, example) + bl.validate_input(True, spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises(ValueError, scalar.validate_input, "hi", spec, example) + self.assertRaises(ValueError, img.validate_input, "hi", spec, example) + self.assertRaises(ValueError, text.validate_input, 4, spec, example) + self.assertRaises(ValueError, tok.validate_input, [1], spec, example) + self.assertRaises(ValueError, emb.validate_input, ["a"], spec, example) + self.assertRaises(ValueError, bl.validate_input, 4, spec, example) + + def test_type_validate_gentext_output(self): + ds_spec = { + "num": types.Scalar(), + "text": types.TextSegment(), + } + out_spec = { + "gentext": types.GeneratedText(parent="text"), + "cands": 
types.GeneratedTextCandidates(parent="text") + } + example = {"num": 1, "text": "hi"} + output = {"gentext": "test", "cands": [("hi", 4), ("bye", None)]} + + gentext = types.GeneratedText(parent="text") + gentextcands = types.GeneratedTextCandidates(parent="text") + try: + gentext.validate_output("hi", out_spec, output, ds_spec, ds_spec, example) + gentextcands.validate_output([("hi", 4), ("bye", None)], out_spec, output, + ds_spec, ds_spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + bad_gentext = types.GeneratedText(parent="num") + self.assertRaises(ValueError, bad_gentext.validate_output, "hi", out_spec, + output, ds_spec, ds_spec, example) + + self.assertRaises(ValueError, gentextcands.validate_output, + [("hi", "wrong"), ("bye", None)], out_spec, output, + ds_spec, ds_spec, example) + bad_gentextcands = types.GeneratedTextCandidates(parent="num") + self.assertRaises(ValueError, bad_gentextcands.validate_output, + [("hi", 4), ("bye", None)], out_spec, output, ds_spec, + ds_spec, example) + + def test_type_validate_genurl(self): + ds_spec = { + "text": types.TextSegment(), + } + out_spec = { + "genurl": types.GeneratedURL(align="cands"), + "cands": types.GeneratedTextCandidates(parent="text") + } + example = {"text": "hi"} + output = {"genurl": "https://blah", "cands": [("hi", 4), ("bye", None)]} + + genurl = types.GeneratedURL(align="cands") + try: + genurl.validate_output("https://blah", out_spec, output, ds_spec, ds_spec, + example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises(ValueError, genurl.validate_output, 4, + out_spec, output, ds_spec, ds_spec, example) + bad_genurl = types.GeneratedURL(align="wrong") + self.assertRaises(ValueError, bad_genurl.validate_output, "https://blah", + out_spec, output, ds_spec, ds_spec, example) + + def test_tokentopk(self): + ds_spec = { + "text": types.TextSegment(), + } + out_spec = { + "tokens": types.Tokens(), + "preds": types.TokenTopKPreds(align="tokens") + } + example = {"text": "hi"} + output = {"tokens": ["hi"], "preds": [[("one", .9), ("two", .4)]]} + + preds = types.TokenTopKPreds(align="tokens") + try: + preds.validate_output( + [[("one", .9), ("two", .4)]], out_spec, output, ds_spec, ds_spec, + example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises( + ValueError, preds.validate_output, + [[("one", .2), ("two", .4)]], out_spec, output, ds_spec, ds_spec, + example) + self.assertRaises( + ValueError, preds.validate_output, + [["one", "two"]], out_spec, output, ds_spec, ds_spec, example) + self.assertRaises( + ValueError, preds.validate_output, ["wrong"], out_spec, output, + ds_spec, ds_spec, example) + + bad_preds = types.TokenTopKPreds(align="preds") + self.assertRaises( + ValueError, bad_preds.validate_output, + [[("one", .9), ("two", .4)]], out_spec, output, ds_spec, ds_spec, + example) + + def test_regression(self): + ds_spec = { + "val": types.Scalar(), + "text": types.TextSegment(), + } + out_spec = { + "score": types.RegressionScore(parent="val"), + } + example = {"val": 2} + output = {"score": 1} + + score = types.RegressionScore(parent="val") + try: + score.validate_output(1, out_spec, output, ds_spec, ds_spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises(ValueError, score.validate_output, "wrong", + out_spec, output, ds_spec, ds_spec, example) + bad_score = types.RegressionScore(parent="text") + self.assertRaises(ValueError, bad_score.validate_output, 1, + out_spec, output, 
ds_spec, ds_spec, example) + + def test_reference(self): + ds_spec = { + "text": types.TextSegment(), + "val": types.Scalar(), + } + out_spec = { + "scores": types.ReferenceScores(parent="text"), + } + example = {"text": "hi"} + output = {"scores": [1, 2]} + + score = types.ReferenceScores(parent="text") + try: + score.validate_output([1, 2], out_spec, output, ds_spec, ds_spec, example) + score.validate_output(np.array([1, 2]), out_spec, output, ds_spec, + ds_spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises(ValueError, score.validate_output, ["a"], + out_spec, output, ds_spec, ds_spec, example) + bad_score = types.ReferenceScores(parent="val") + self.assertRaises(ValueError, bad_score.validate_output, [1], + out_spec, output, ds_spec, ds_spec, example) + + def test_multiclasspreds(self): + ds_spec = { + "label": types.CategoryLabel(), + "val": types.Scalar(), + } + out_spec = { + "scores": types.MulticlassPreds( + parent="label", null_idx=0, vocab=["a", "b"]), + } + example = {"label": "hi", "val": 1} + output = {"scores": [1, 2]} + + score = types.MulticlassPreds(parent="label", null_idx=0, vocab=["a", "b"]) + try: + score.validate_output([1, 2], out_spec, output, ds_spec, ds_spec, example) + score.validate_output(np.array([1, 2]), out_spec, output, ds_spec, + ds_spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises(ValueError, score.validate_output, ["a", "b"], + out_spec, output, ds_spec, ds_spec, example) + bad_score = types.MulticlassPreds( + parent="label", null_idx=2, vocab=["a", "b"]) + self.assertRaises(ValueError, bad_score.validate_output, [1, 2], + out_spec, output, ds_spec, ds_spec, example) + bad_score = types.MulticlassPreds( + parent="val", null_idx=0, vocab=["a", "b"]) + self.assertRaises(ValueError, bad_score.validate_output, [1, 2], + out_spec, output, ds_spec, ds_spec, example) + + def test_annotations(self): + ds_spec = { + "text": types.TextSegment(), + } + out_spec = { + "tokens": types.Tokens(), + "spans": types.SpanLabels(align="tokens"), + "edges": types.EdgeLabels(align="tokens"), + "annot": types.MultiSegmentAnnotations(), + } + example = {"text": "hi"} + output = {"tokens": ["hi"], "preds": [dtypes.SpanLabel(start=0, end=1)], + "edges": [dtypes.EdgeLabel(span1=(0, 0), span2=(1, 1), label=0)], + "annot": [dtypes.AnnotationCluster(label="hi", spans=[])]} + + spans = types.SpanLabels(align="tokens") + edges = types.EdgeLabels(align="tokens") + annot = types.MultiSegmentAnnotations() + try: + spans.validate_output( + [dtypes.SpanLabel(start=0, end=1)], out_spec, output, ds_spec, + ds_spec, example) + edges.validate_output( + [dtypes.EdgeLabel(span1=(0, 0), span2=(1, 1), label=0)], out_spec, + output, ds_spec, ds_spec, example) + annot.validate_output( + [dtypes.AnnotationCluster(label="hi", spans=[])], out_spec, + output, ds_spec, ds_spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises( + ValueError, spans.validate_output, [1], out_spec, output, ds_spec, + ds_spec, example) + self.assertRaises( + ValueError, edges.validate_output, [1], out_spec, output, ds_spec, + ds_spec, example) + self.assertRaises( + ValueError, annot.validate_output, [1], out_spec, output, ds_spec, + ds_spec, example) + + bad_spans = types.SpanLabels(align="edges") + bad_edges = types.EdgeLabels(align="spans") + self.assertRaises( + ValueError, bad_spans.validate_output, + [dtypes.SpanLabel(start=0, end=1)], out_spec, output, ds_spec, ds_spec, + example) + 
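+    # EdgeLabels aligned to a field that is not a Tokens field should also
+    # fail validation.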
self.assertRaises( + ValueError, bad_edges.validate_output, + [dtypes.EdgeLabel(span1=(0, 0), span2=(1, 1), label=0)], out_spec, + output, ds_spec, ds_spec, example) + + def test_gradients(self): + ds_spec = { + "text": types.TextSegment(), + "target": types.CategoryLabel() + } + out_spec = { + "tokens": types.Tokens(), + "embs": types.Embeddings(), + "grads": types.Gradients(align="tokens", grad_for="embs", + grad_target_field_key="target") + } + example = {"text": "hi", "target": "one"} + output = {"tokens": ["hi"], "embs": [.1, .2], "grads": [.1]} + + grads = types.Gradients(align="tokens", grad_for="embs", + grad_target_field_key="target") + embs = types.Embeddings() + try: + grads.validate_output([.1], out_spec, output, ds_spec, ds_spec, example) + embs.validate_output([.1, .2], out_spec, output, ds_spec, ds_spec, + example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises( + ValueError, grads.validate_output, ["bad"], out_spec, output, ds_spec, + ds_spec, example) + self.assertRaises( + ValueError, embs.validate_output, ["bad"], out_spec, output, ds_spec, + ds_spec, example) + + bad_grads = types.Gradients(align="text", grad_for="embs", + grad_target_field_key="target") + self.assertRaises( + ValueError, bad_grads.validate_output, [.1], out_spec, output, ds_spec, + ds_spec, example) + bad_grads = types.Gradients(align="tokens", grad_for="tokens", + grad_target_field_key="target") + self.assertRaises( + ValueError, bad_grads.validate_output, [.1], out_spec, output, ds_spec, + ds_spec, example) + bad_grads = types.Gradients(align="tokens", grad_for="embs", + grad_target_field_key="bad") + self.assertRaises( + ValueError, bad_grads.validate_output, [.1], out_spec, output, ds_spec, + ds_spec, example) + + def test_tokenembsgrads(self): + ds_spec = { + "text": types.TextSegment(), + "target": types.CategoryLabel() + } + out_spec = { + "tokens": types.Tokens(), + "embs": types.TokenEmbeddings(align="tokens"), + "grads": types.TokenGradients(align="tokens", grad_for="embs", + grad_target_field_key="target") + } + example = {"text": "hi", "target": "one"} + output = {"tokens": ["hi"], "embs": np.array([[.1], [.2]]), + "grads": np.array([[.1], [.2]])} + + grads = types.TokenGradients(align="tokens", grad_for="embs", + grad_target_field_key="target") + embs = types.TokenEmbeddings(align="tokens") + try: + grads.validate_output(np.array([[.1], [.2]]), out_spec, output, ds_spec, + ds_spec, example) + embs.validate_output(np.array([[.1], [.2]]), out_spec, output, ds_spec, + ds_spec, example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises( + ValueError, grads.validate_output, np.array([.1, .2]), out_spec, output, + ds_spec, ds_spec, example) + self.assertRaises( + ValueError, embs.validate_output, np.array([.1, .2]), out_spec, output, + ds_spec, ds_spec, example) + + bad_embs = types.TokenEmbeddings(align="grads") + self.assertRaises( + ValueError, bad_embs.validate_output, np.array([[.1], [.2]]), out_spec, + output, ds_spec, ds_spec, example) + + def test_attention(self): + ds_spec = { + "text": types.TextSegment(), + } + out_spec = { + "tokens": types.Tokens(), + "val": types.RegressionScore, + "attn": types.AttentionHeads(align_in="tokens", align_out="tokens"), + } + example = {"text": "hi"} + output = {"tokens": ["hi"], "attn": np.array([[[.1]], [[.2]]])} + + attn = types.AttentionHeads(align_in="tokens", align_out="tokens") + try: + attn.validate_output(np.array([[[.1]], [[.2]]]), out_spec, output, + ds_spec, ds_spec, 
example) + except ValueError: + self.fail("Raised unexpected error.") + + self.assertRaises( + ValueError, attn.validate_output, np.array([.1, .2]), out_spec, output, + ds_spec, ds_spec, example) + + bad_attn = types.AttentionHeads(align_in="tokens", align_out="val") + self.assertRaises( + ValueError, bad_attn.validate_output, np.array([[[.1]], [[.2]]]), + out_spec, output, ds_spec, ds_spec, example) + bad_attn = types.AttentionHeads(align_in="val", align_out="tokens") + self.assertRaises( + ValueError, bad_attn.validate_output, np.array([[[.1]], [[.2]]]), + out_spec, output, ds_spec, ds_spec, example) + if __name__ == "__main__": absltest.main() diff --git a/lit_nlp/app.py b/lit_nlp/app.py index 57653f02..dc3d98c9 100644 --- a/lit_nlp/app.py +++ b/lit_nlp/app.py @@ -16,6 +16,7 @@ import functools import glob +import math import os import random import time @@ -30,9 +31,11 @@ from lit_nlp.api import types from lit_nlp.components import core from lit_nlp.lib import caching +from lit_nlp.lib import flag_helpers from lit_nlp.lib import serialize from lit_nlp.lib import ui_state from lit_nlp.lib import utils +from lit_nlp.lib import validation from lit_nlp.lib import wsgi_app import tqdm @@ -323,6 +326,33 @@ def _push_ui_state(self, data, dataset_name: str, **unused_kw): self._datasets[dataset_name], dataset_name, **options) + def _validate(self, validate: Optional[flag_helpers.ValidationMode], + report_all: bool): + """Validate all datasets and models loaded for proper setup.""" + if validate is None or validate == flag_helpers.ValidationMode.OFF: + return + + datasets_to_validate = {} + for dataset in self._datasets: + if validate == flag_helpers.ValidationMode.ALL: + datasets_to_validate[dataset] = self._datasets[dataset] + elif validate == flag_helpers.ValidationMode.FIRST: + datasets_to_validate[dataset] = self._datasets[dataset].slice[:1] + elif validate == flag_helpers.ValidationMode.SAMPLE: + sample_size = math.ceil(len(self._datasets[dataset]) * 0.05) + datasets_to_validate[dataset] = self._datasets[dataset].sample( + sample_size) + for dataset in datasets_to_validate: + logging.info("Validating dataset '%s'", dataset) + validation.validate_dataset( + datasets_to_validate[dataset], report_all) + for model, model_info in self._info['models'].items(): + for dataset_name in model_info['datasets']: + logging.info("Validating model '%s' on dataset '%s'", model, + dataset_name) + validation.validate_model( + self._models[model], datasets_to_validate[dataset_name], report_all) + def _warm_start(self, rate: float, progress_indicator: Optional[ProgressIndicator] = None): @@ -430,6 +460,8 @@ def __init__( onboard_start_doc: Optional[str] = None, onboard_end_doc: Optional[str] = None, sync_state: bool = False, # notebook-only; not in server_flags + validate: Optional[flag_helpers.ValidationMode] = None, + report_all: bool = False, ): if client_root is None: raise ValueError('client_root must be set on application') @@ -493,6 +525,9 @@ def __init__( # Information on models, datasets, and other components. self._info = self._build_metadata() + # Validate datasets and models if specified. + self._validate(validate, report_all) + # Optionally, run models to pre-populate cache. if warm_projections: logging.info( diff --git a/lit_nlp/client/lib/lit_types.ts b/lit_nlp/client/lib/lit_types.ts index 16654108..8365893e 100644 --- a/lit_nlp/client/lib/lit_types.ts +++ b/lit_nlp/client/lib/lit_types.ts @@ -243,7 +243,7 @@ export class CategoryLabel extends StringLitType { /** * A tensor type. 
*/ -class _Tensor1D extends LitType { +class _Tensor extends LitType { override default: number[] = []; } @@ -252,7 +252,7 @@ class _Tensor1D extends LitType { * Multiclass predicted probabilities, as [num_labels]. */ @registered -export class MulticlassPreds extends _Tensor1D { +export class MulticlassPreds extends _Tensor { /** * Vocabulary is required here for decoding model output. * Usually this will match the vocabulary in the corresponding label field. @@ -335,13 +335,13 @@ export class MultiSegmentAnnotations extends ListLitType { * Embeddings or model activations, as fixed-length [emb_dim]. */ @registered -export class Embeddings extends _Tensor1D { +export class Embeddings extends _Tensor { } /** * Shared gradient attributes. */ -class _GradientsBase extends _Tensor1D { +class _GradientsBase extends _Tensor { /** Name of a Tokens field. */ align?: string = undefined; /** Name of Embeddings field. */ @@ -363,7 +363,7 @@ export class Gradients extends _GradientsBase { /** * A single vector of [enc_dim]. */ -class _InfluenceEncodings extends _Tensor1D { +class _InfluenceEncodings extends _Tensor { /** Class for computing gradients (string). */ grad_target?: string = undefined; } @@ -372,7 +372,7 @@ class _InfluenceEncodings extends _Tensor1D { * Per-token embeddings, as [num_tokens, emb_dim]. */ @registered -export class TokenEmbeddings extends _Tensor1D { +export class TokenEmbeddings extends _Tensor { /** Name of a Tokens field. */ align?: string = undefined; } @@ -395,7 +395,7 @@ export class ImageGradients extends _GradientsBase { * One or more attention heads, as [num_heads, num_tokens, num_tokens]. */ @registered -export class AttentionHeads extends _Tensor1D { +export class AttentionHeads extends _Tensor { // Input and output Tokens fields; for self-attention these can // be the same. align_in: string = ''; diff --git a/lit_nlp/examples/coref/model.py b/lit_nlp/examples/coref/model.py index d4830161..0e45a07f 100644 --- a/lit_nlp/examples/coref/model.py +++ b/lit_nlp/examples/coref/model.py @@ -131,5 +131,5 @@ def output_spec(self): lit_types.EdgeLabels(align='tokens'), 'pred_answer': lit_types.MulticlassPreds( - vocab=winogender.ANSWER_VOCAB, parent='answer'), + required=False, vocab=winogender.ANSWER_VOCAB, parent='answer'), } diff --git a/lit_nlp/lib/flag_helpers.py b/lit_nlp/lib/flag_helpers.py new file mode 100644 index 00000000..e278d048 --- /dev/null +++ b/lit_nlp/lib/flag_helpers.py @@ -0,0 +1,12 @@ +"""Data to support runtime flags.""" + +import enum + + +@enum.unique +class ValidationMode(enum.Enum): + """All the validation mode options.""" + OFF = 'off' # Do not validate datasets and model outputs. + FIRST = 'first' # Validate the first datapoint. + ALL = 'all' # Validate all datapoints. + SAMPLE = 'sample' # Validate a sample of 5% of datapoints. diff --git a/lit_nlp/lib/testing_utils.py b/lit_nlp/lib/testing_utils.py index 4e043b70..143666ea 100644 --- a/lit_nlp/lib/testing_utils.py +++ b/lit_nlp/lib/testing_utils.py @@ -197,3 +197,38 @@ def assert_deep_almost_equal(testcase, result, actual, places=4): testcase.fail('results and actual have different keys') for key in result: assert_deep_almost_equal(testcase, result[key], actual[key]) + + +class TestCustomOutputModel(lit_model.Model): + """Implements lit.Model interface for testing. + + This class allows user-specified outputs for testing return values. + """ + + def __init__(self, input_spec: lit_types.Spec, output_spec: lit_types.Spec, + results: List[JsonDict]): + """Set model internals. 
+ + Args: + input_spec: An input spec. + output_spec: An output spec. + results: Results to return. + """ + self._input_spec = input_spec + self._output_spec = output_spec + self._predict_counter = 0 + self._results = results + + # LIT API implementation + def input_spec(self): + return self._input_spec + + def output_spec(self): + return self._output_spec + + def predict_minibatch(self, inputs: List[JsonDict], **kw): + def predict_single(_): + output = self._results[self._predict_counter % len(self._results)] + self._predict_counter += 1 + return output + return map(predict_single, inputs) diff --git a/lit_nlp/lib/validation.py b/lit_nlp/lib/validation.py new file mode 100644 index 00000000..d2fa679a --- /dev/null +++ b/lit_nlp/lib/validation.py @@ -0,0 +1,57 @@ +"""Validators for datasets and models.""" + + +from typing import cast +from absl import logging +from lit_nlp.api import dataset +from lit_nlp.api import model +from lit_nlp.api import types + + +def validate_dataset(ds: dataset.Dataset, report_all: bool): + """Validate dataset entries against spec.""" + last_error = None + for ex in ds.examples: + for (key, entry) in ds.spec().items(): + if key not in ex or ex[key] is None: + if entry.required: + raise ValueError( + f'Required dataset feature {key} missing from datapoint') + else: + continue + try: + entry.validate_input(ex[key], ds.spec(), cast(types.Input, ex)) + except ValueError as e: + logging.exception('Failed validating input key %s', key) + if report_all: + last_error = e + else: + raise e + if last_error: + raise last_error + + +def validate_model(mod: model.Model, ds: dataset.Dataset, report_all: bool): + """Validate model usage on dataset against specs.""" + last_error = None + outputs = list(mod.predict(ds.examples)) + for ex, output in zip(ds.examples, outputs): + for (key, entry) in mod.output_spec().items(): + if key not in output or output[key] is None: + if entry.required: + raise ValueError( + f'Required model output {key} missing from prediction result') + else: + continue + try: + entry.validate_output( + output[key], mod.output_spec(), output, mod.input_spec(), ds.spec(), + cast(types.Input, ex)) + except ValueError as e: + logging.exception('Failed validating model output key %s', key) + if report_all: + last_error = e + else: + raise e + if last_error: + raise last_error diff --git a/lit_nlp/lib/validation_test.py b/lit_nlp/lib/validation_test.py new file mode 100644 index 00000000..1228af1b --- /dev/null +++ b/lit_nlp/lib/validation_test.py @@ -0,0 +1,147 @@ +"""Tests for validation.""" + + +from absl.testing import absltest +from lit_nlp.api import dataset +from lit_nlp.api import types +from lit_nlp.lib import testing_utils +from lit_nlp.lib import validation + + +class ValidationTest(absltest.TestCase): + + def test_validate_dataset(self): + spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + datapoints = [ + { + "score": 0, + "text": "a" + }, + { + "score": 0, + "text": "b" + }, + ] + ds = dataset.Dataset(spec, datapoints) + try: + validation.validate_dataset(ds, False) + except ValueError: + self.fail("Raised unexpected error.") + + def test_validate_dataset_fail_bad_scalar(self): + spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + datapoints = [ + { + "score": "bad", + "text": "a" + }, + { + "score": 0, + "text": "b" + }, + ] + ds = dataset.Dataset(spec, datapoints) + self.assertRaises(ValueError, validation.validate_dataset, ds, False) + self.assertRaises(ValueError, validation.validate_dataset, ds, 
True) + + def test_validate_dataset_validate_all(self): + spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + datapoints = [ + { + "score": 0, + "text": "a" + }, + { + "score": "bad", + "text": "b" + }, + ] + ds = dataset.Dataset(spec, datapoints) + self.assertRaises(ValueError, validation.validate_dataset, ds, False) + + def test_validate_model(self): + in_spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + out_spec = { + "res": types.RegressionScore(parent="score"), + } + datapoints = [ + { + "score": 0, + "text": "a" + }, + { + "score": 1, + "text": "b" + }, + ] + results = [{"res": 1}, {"res": 1}] + ds = dataset.Dataset(in_spec, datapoints) + model = testing_utils.TestCustomOutputModel(in_spec, out_spec, results) + try: + validation.validate_model(model, ds, True) + except ValueError: + self.fail("Raised unexpected error.") + + def test_validate_model_fail(self): + in_spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + out_spec = { + "res": types.RegressionScore(parent="score"), + } + datapoints = [ + { + "score": 0, + "text": "a" + }, + { + "score": 1, + "text": "b" + }, + ] + results = [{"res": "bad"}, {"res": 1}] + ds = dataset.Dataset(in_spec, datapoints) + model = testing_utils.TestCustomOutputModel(in_spec, out_spec, results) + self.assertRaises( + ValueError, validation.validate_model, model, ds, False) + + def test_validate_model_validate_all(self): + in_spec = { + "score": types.Scalar(), + "text": types.TextSegment(), + } + out_spec = { + "res": types.RegressionScore(parent="score"), + } + datapoints = [ + { + "score": 0, + "text": "a" + }, + { + "score": 1, + "text": "b" + }, + ] + results = [{"res": 1}, {"res": "bad"}] + ds = dataset.Dataset(in_spec, datapoints) + model = testing_utils.TestCustomOutputModel(in_spec, out_spec, results) + self.assertRaises( + ValueError, validation.validate_model, model, ds, False) + + +if __name__ == "__main__": + absltest.main() diff --git a/lit_nlp/server_config.py b/lit_nlp/server_config.py index 3243d4ed..db2d9d5e 100644 --- a/lit_nlp/server_config.py +++ b/lit_nlp/server_config.py @@ -70,6 +70,10 @@ # Whether the LIT instance is a development demo. config.development_demo = False +# Whether dataset and model validation will happen at startup. +config.validate = None +config.report_all = False + import os import pathlib config.client_root = os.path.join( diff --git a/lit_nlp/server_flags.py b/lit_nlp/server_flags.py index 74fed140..4defac0d 100644 --- a/lit_nlp/server_flags.py +++ b/lit_nlp/server_flags.py @@ -29,6 +29,7 @@ import pathlib from absl import flags +from lit_nlp.lib import flag_helpers FLAGS = flags.FLAGS @@ -80,6 +81,14 @@ flags.DEFINE_bool( 'development_demo', False, 'If true, signifies this LIT ' 'instance is a development demo.'), + flags.DEFINE_enum_class( + 'validate', None, flag_helpers.ValidationMode, + 'If not None or "off", will validate the datasets and model outputs ' + 'according to the value set. By default, validation is disabled.'), + flags.DEFINE_bool( + 'report_all', False, + 'If true, and validate is true, will report every issue in validation ' + 'as opposed to just the first.'), flags.DEFINE_string( 'client_root',
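As a companion to the `documentation/api.md` section above, here is a minimal
sketch of calling the validators directly, outside the LIT server. It uses only
the APIs touched in this change; the spec field names (`score`, `text`) and the
datapoints are illustrative.

```python
from lit_nlp.api import dataset
from lit_nlp.api import types
from lit_nlp.lib import validation

# Illustrative spec and datapoints; any LIT Dataset works here.
spec = {"score": types.Scalar(), "text": types.TextSegment()}
datapoints = [{"score": 0.5, "text": "hello"}, {"score": 1.2, "text": "world"}]
ds = dataset.Dataset(spec, datapoints)

# Raises ValueError on the first invalid value; with report_all=True, every
# failure is logged and the last error is raised at the end.
validation.validate_dataset(ds, False)

# Model outputs can be checked the same way, given a lit_model.Model instance:
# validation.validate_model(my_model, ds, False)
```

When running a LIT server instead, the same checks are enabled at startup with
the `validate` flag (`first`, `sample`, or `all`) and, optionally, `report_all`,
as described in the documentation change above.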