From 51985dc63771b16597a8bf3b4a4405214dba1324 Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Fri, 2 Apr 2021 18:55:14 +0000 Subject: [PATCH 001/213] Adding Stanza wrapper and demo --- lit_nlp/examples/models/stanza_models.py | 107 +++++++++++++++++++++++ lit_nlp/examples/stanza_demo.py | 83 ++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 lit_nlp/examples/models/stanza_models.py create mode 100644 lit_nlp/examples/stanza_demo.py diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py new file mode 100644 index 00000000..4e491ce0 --- /dev/null +++ b/lit_nlp/examples/models/stanza_models.py @@ -0,0 +1,107 @@ +# Lint as: python3 +"""Wrapper for Stanza model""" + +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.api import dtypes + +SpanLabel = dtypes.SpanLabel +EdgeLabel = dtypes.EdgeLabel + + +class StanzaTagger(lit_model.Model): + def __init__(self, model, tasks): + self.model = model + self.sequence_tasks = tasks["sequence"] + self.span_tasks = tasks["span"] + self.edge_tasks = tasks["edge"] + + self._input_spec = { + "sentence": lit_types.TextSegment(), + } + + self._output_spec = { + "tokens": lit_types.Tokens(), + } + + # Output spec based on specified tasks + for task in self.sequence_tasks: + self._output_spec[task] = lit_types.SequenceTags(align="tokens") + for task in self.span_tasks: + self._output_spec[task] = lit_types.SpanLabels(align="tokens") + for task in self.edge_tasks: + self._output_spec[task] = lit_types.EdgeLabels(align="tokens") + + def _predict(self, ex): + """ + Predicts all specified tasks for an individual example + :param ex (dict): + This should be a dict with a single entry with: + key = "sentence" + value (str) = a single string for prediction + :return (list): + This list contains dicts for each prediction tasks with: + key = task name + value (list) = predictions + """ + doc = self.model(ex["sentence"]) + prediction 
= {} + for sentence in doc.sentences: + prediction["tokens"] = [word.text for word in sentence.words] + + # Process each sequence task + for task in self.sequence_tasks: + prediction[task] = [word.to_dict()[task] for word in sentence.words] + + # Process each span task + for task in self.span_tasks: + # Mention is currently the only span task + if task == "mention": + prediction[task] = [] + for entity in sentence.entities: + # Stanza indexes start/end of entities on char. LIT needs them as token indexes + start, end = entity_char_to_token(entity, sentence) + span_label = SpanLabel(start=start, end=end, label=entity.type) + prediction[task].append(span_label) + + # Process each edge task + for task in self.edge_tasks: + # Deps is currently the only edge task + if task == "deps": + prediction[task] = [] + for relation in sentence.dependencies: + label = relation[1] + span1 = relation[2].id + span2 = relation[2].id if label == "root" else relation[0].id + edge_label = EdgeLabel( + (span1 - 1, span1), (span2 - 1, span2), label + ) + prediction[task].append(edge_label) + + return prediction + + def predict_minibatch(self, inputs, config=None): + return [self._predict(ex) for ex in inputs] + + def input_spec(self): + return self._input_spec + + def output_spec(self): + return self._output_spec + + +def entity_char_to_token(entity, sentence): + """ + Takes Stanza entity and sentence objects and returns the start and end tokens for the entity + :param entity: Stanza entity + :param sentence: Stanza sentence + :return (int, int): Returns the start and end locations indexed by tokens + """ + start_token, end_token = None, None + for i, v in enumerate(sentence.words): + x = v.misc.split("|") + if "start_char=" + str(entity.start_char) in x: + start_token = i + if "end_char=" + str(entity.end_char) in x: + end_token = i + 1 + return start_token, end_token diff --git a/lit_nlp/examples/stanza_demo.py b/lit_nlp/examples/stanza_demo.py new file mode 100644 index 
00000000..2dfa4708 --- /dev/null +++ b/lit_nlp/examples/stanza_demo.py @@ -0,0 +1,83 @@ +# Lint at: python3 +"""Example demo loading Stanza models. +To run with the demo: + python -m lit_nlp.examples.stanza_demo --port=5432 +Then navigate to localhost:5432 to access the demo UI. +""" +from absl import app +from absl import flags + +import lit_nlp.api.dataset as lit_dataset +import lit_nlp.api.types as lit_types +from lit_nlp.examples.datasets import glue +from lit_nlp.examples.models import stanza_models +from lit_nlp import dev_server +from lit_nlp import server_flags +from lit_nlp.components import scrambler +from lit_nlp.components import word_replacer + +import stanza + +FLAGS = flags.FLAGS + +flags.DEFINE_list( + "sequence_tasks", + ["upos", "xpos", "lemma"], + "Sequence tasks to load and use for prediction. Defaults to all sequence tasks", +) + +flags.DEFINE_list( + "span_tasks", + ["mention"], + "Span tasks to load and use for prediction. Only mentions are included in this demo", +) + +flags.DEFINE_list( + "edge_tasks", + ["deps"], + "Span tasks to load and use for prediction. Only deps are included in this demo", +) + +flags.DEFINE_string("language", "en", "Language to load for Stanza model.") + +flags.DEFINE_integer( + "max_examples", None, "Maximum number of examples to load into LIT." 
+) + + +def main(_): + # Set Tasks + tasks = { + "sequence": FLAGS.sequence_tasks, + "span": FLAGS.span_tasks, + "edge": FLAGS.edge_tasks, + } + + # Get the correct model for the language + stanza.download(FLAGS.language) + pretrained_model = stanza.Pipeline(FLAGS.language) + models = { + "stanza": stanza_models.StanzaTagger(pretrained_model, tasks), + } + + # Datasets for LIT demo + datasets = { + "SST2": glue.SST2Data(split="validation").slice[: FLAGS.max_examples], + "blank": lit_dataset.Dataset({"text": lit_types.TextSegment()}, []), + } + + # Add generators + generators = { + "scrambler": scrambler.Scrambler(), + "word_replacer": word_replacer.WordReplacer(), + } + + # Start the LIT server. See server_flags.py for server options. + lit_demo = dev_server.Server( + models, datasets, generators, **server_flags.get_flags() + ) + lit_demo.serve() + + +if __name__ == "__main__": + app.run(main) From 106b30d362f9b102213559e392e77c13f36a039b Mon Sep 17 00:00:00 2001 From: Ellen Jiang Date: Mon, 5 Apr 2021 15:01:48 -0700 Subject: [PATCH 002/213] Truncates the menu item text if longer than 25 characters. PiperOrigin-RevId: 366881472 --- lit_nlp/client/elements/menu.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lit_nlp/client/elements/menu.ts b/lit_nlp/client/elements/menu.ts index 19ab2af7..c22dadb2 100644 --- a/lit_nlp/client/elements/menu.ts +++ b/lit_nlp/client/elements/menu.ts @@ -29,6 +29,8 @@ import {styles} from './menu.css'; type ClickCallback = () => void; +const MAX_TEXT_LENGTH = 25; + /** Holds the properties for an item in the menu. */ export interface MenuItem { itemText: string; // menu item text @@ -161,6 +163,9 @@ export class LitMenu extends LitElement { const itemTextClass = classMap({'item-text': true, 'text-disabled': item.disabled}); + // TODO(b/184549342): Consider rewriting component without Material menu + // due to styling issues (e.g. with setting max width with + // text-overflow:ellipsis in CSS). 
// clang-format off return html`
${hasSubmenu ? 'arrow_right' : 'check'} - ${item.itemText} + + ${item.itemText.slice(0, MAX_TEXT_LENGTH) + + ((item.itemText.length > MAX_TEXT_LENGTH) ? '...': '')} +
`; // clang-format on From 9a0f6778da439455f7cda7ffde50397df73443b6 Mon Sep 17 00:00:00 2001 From: Ellen Jiang Date: Thu, 8 Apr 2021 08:55:19 -0700 Subject: [PATCH 003/213] Changes the table filtering to use regex search. PiperOrigin-RevId: 367439391 --- lit_nlp/client/elements/table.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lit_nlp/client/elements/table.ts b/lit_nlp/client/elements/table.ts index 9106dfca..d475a782 100644 --- a/lit_nlp/client/elements/table.ts +++ b/lit_nlp/client/elements/table.ts @@ -178,8 +178,8 @@ export class DataTable extends ReactiveElement { const col = item[index]; if (typeof col === 'string') { - // TODO(b/158299036) Change this to regexp search. - isShownByTextFilter = isShownByTextFilter && col.includes(value); + isShownByTextFilter = + isShownByTextFilter && col.search(new RegExp(value)) !== -1; } else if (typeof col === 'number') { // TODO(b/158299036) Support syntax like 1-3,6 for numbers. isShownByTextFilter = isShownByTextFilter && value === '' ? From b1060e83309b6bec8dbfccfbc554b68c57ebad18 Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Fri, 9 Apr 2021 18:55:58 +0000 Subject: [PATCH 004/213] Minor updates to comments/docstrings. Handle multiple sentences. --- lit_nlp/examples/models/stanza_models.py | 82 +++++++++++++++++------- lit_nlp/examples/stanza_demo.py | 17 ++++- 2 files changed, 75 insertions(+), 24 deletions(-) diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py index 4e491ce0..881d3ee1 100644 --- a/lit_nlp/examples/models/stanza_models.py +++ b/lit_nlp/examples/models/stanza_models.py @@ -1,3 +1,17 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== # Lint as: python3 """Wrapper for Stanza model""" @@ -10,8 +24,19 @@ class StanzaTagger(lit_model.Model): + """Stanza Model wrapper""" + def __init__(self, model, tasks): + """Initialize with Stanza model and a dictionary of tasks. + + Args: + model: A Stanza model + tasks: A dictionary of tasks, grouped by task type. + Keys are the grouping, which should be one of ('sequence', 'span', 'edge'). + Values are a list of stanza task names as strings. + """ self.model = model + # Store lists of task name strings by grouping self.sequence_tasks = tasks["sequence"] self.span_tasks = tasks["span"] self.edge_tasks = tasks["edge"] @@ -33,50 +58,58 @@ def __init__(self, model, tasks): self._output_spec[task] = lit_types.EdgeLabels(align="tokens") def _predict(self, ex): - """ - Predicts all specified tasks for an individual example - :param ex (dict): - This should be a dict with a single entry with: - key = "sentence" - value (str) = a single string for prediction - :return (list): - This list contains dicts for each prediction tasks with: - key = task name - value (list) = predictions + """Predicts all specified tasks for an individual example. + + Args: + ex (dict): This should be a dict with a single entry. + key = "sentence" + value (str) = a single string for prediction + Returns: + A list containing dicts for each prediction tasks with: + key = task name + value (list) = predictions + Raises: + ValueError: Invalid task name. 
""" doc = self.model(ex["sentence"]) - prediction = {} + prediction = {task: [] for task in self._output_spec} for sentence in doc.sentences: - prediction["tokens"] = [word.text for word in sentence.words] + # Get offset value to align task to tokens for multiple sentences + offset = len(prediction['tokens']) + prediction["tokens"].extend([word.text for word in sentence.words]) # Process each sequence task for task in self.sequence_tasks: - prediction[task] = [word.to_dict()[task] for word in sentence.words] + prediction[task].extend([word.to_dict()[task] for word in sentence.words]) # Process each span task + print(sentence.entities) for task in self.span_tasks: # Mention is currently the only span task if task == "mention": - prediction[task] = [] for entity in sentence.entities: # Stanza indexes start/end of entities on char. LIT needs them as token indexes start, end = entity_char_to_token(entity, sentence) - span_label = SpanLabel(start=start, end=end, label=entity.type) + span_label = SpanLabel(start=start+offset, end=end+offset, label=entity.type) prediction[task].append(span_label) + else: + raise ValueError(f"Invalid span task: '{task}'") # Process each edge task for task in self.edge_tasks: # Deps is currently the only edge task if task == "deps": - prediction[task] = [] for relation in sentence.dependencies: label = relation[1] - span1 = relation[2].id - span2 = relation[2].id if label == "root" else relation[0].id + span1 = relation[2].id + offset + span2 = relation[2].id + offset if label == "root" else relation[0].id + offset + # Relation lists have a root value at index 0, so subtract 1 to align them to tokens edge_label = EdgeLabel( (span1 - 1, span1), (span2 - 1, span2), label ) prediction[task].append(edge_label) + else: + raise ValueError(f"Invalid edge task: '{task}'") return prediction @@ -91,14 +124,17 @@ def output_spec(self): def entity_char_to_token(entity, sentence): - """ - Takes Stanza entity and sentence objects and returns the start 
and end tokens for the entity - :param entity: Stanza entity - :param sentence: Stanza sentence - :return (int, int): Returns the start and end locations indexed by tokens + """Takes Stanza entity and sentence objects and returns the start and end tokens for the entity + + Args: + entity: Stanza entity object + sentence: Stanza sentence object + Returns: + Returns the token index of start and end locations for the entity """ start_token, end_token = None, None for i, v in enumerate(sentence.words): + # Misc is a string of values, separated by |, that contains start and end chars x = v.misc.split("|") if "start_char=" + str(entity.start_char) in x: start_token = i diff --git a/lit_nlp/examples/stanza_demo.py b/lit_nlp/examples/stanza_demo.py index 2dfa4708..1eec82a9 100644 --- a/lit_nlp/examples/stanza_demo.py +++ b/lit_nlp/examples/stanza_demo.py @@ -1,3 +1,17 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== # Lint at: python3 """Example demo loading Stanza models. 
To run with the demo: @@ -46,7 +60,8 @@ def main(_): - # Set Tasks + # Set Tasks as a dictionary with task groups as + # keys and values as lists of strings of Stanza task names tasks = { "sequence": FLAGS.sequence_tasks, "span": FLAGS.span_tasks, From 9978971f2bc91b97f2970a197dbdefd5cb87567f Mon Sep 17 00:00:00 2001 From: James Wexler Date: Fri, 9 Apr 2021 11:56:25 -0700 Subject: [PATCH 005/213] Fix scalars xaxis when all values are identical PiperOrigin-RevId: 367675908 --- lit_nlp/client/modules/scalar_module.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lit_nlp/client/modules/scalar_module.ts b/lit_nlp/client/modules/scalar_module.ts index b1ca66ff..252cada0 100644 --- a/lit_nlp/client/modules/scalar_module.ts +++ b/lit_nlp/client/modules/scalar_module.ts @@ -374,6 +374,12 @@ export class ScalarModule extends LitModule { if (outputSpec != null && isLitSubtype(outputSpec[key], 'Scalar')) { const scalarValues = this.preds.map((pred) => pred[key]); scoreRange = [Math.min(...scalarValues), Math.max(...scalarValues)]; + // If the range is 0 (all values are identical), then artificially increase + // the range so that an X-axis is properly displayed. + if (scoreRange[0] === scoreRange[1]) { + scoreRange[0] = scoreRange[0] - .1; + scoreRange[1] = scoreRange[1] + .1; + } } return d3.scaleLinear().domain(scoreRange).range([ From 3575f2ba75ad91a8efae5b0367daad0f94599982 Mon Sep 17 00:00:00 2001 From: Ankur Taly Date: Fri, 9 Apr 2021 14:40:41 -0700 Subject: [PATCH 006/213] Update hotflip algorithm to only report minimal hotflips A hotflip is considered minimal if no strict subset of the applied token flips succeeds in flipping the prediction. 
PiperOrigin-RevId: 367706701 --- lit_nlp/components/hotflip.py | 134 +++++++++++++++++++++------------- 1 file changed, 82 insertions(+), 52 deletions(-) diff --git a/lit_nlp/components/hotflip.py b/lit_nlp/components/hotflip.py index f0b33c74..2d77652b 100644 --- a/lit_nlp/components/hotflip.py +++ b/lit_nlp/components/hotflip.py @@ -26,6 +26,7 @@ """ import copy +import itertools from typing import List, Text, Optional, Type, cast from absl import logging @@ -39,6 +40,11 @@ JsonDict = types.JsonDict Spec = types.Spec +NUM_EXAMPLES_KEY = "Number of examples" +NUM_EXAMPLES_DEFAULT = 5 +MAX_FLIPS_KEY = "Maximum number of token flips" +MAX_FLIPS_DEFAULT = 3 + class HotFlip(lit_components.Generator): """HotFlip generator. @@ -47,6 +53,9 @@ class HotFlip(lit_components.Generator): tokens in the input sentence in order to to obtain a different prediction from the input sentence. + A hotflip is considered minimal if no strict subset of the applied token flips + succeeds in flipping the prediction. + This generator is currently only supported on classification models. 
""" @@ -66,17 +75,37 @@ def find_fields( assert isinstance(output_spec[align_field], align_typ) return fields + def config_spec(self) -> types.Spec: + return { + NUM_EXAMPLES_KEY: types.TextSegment(default=str(NUM_EXAMPLES_DEFAULT)), + MAX_FLIPS_KEY: types.TextSegment(default=str(MAX_FLIPS_DEFAULT)), + } + + def _subset_exists(self, cand_set, sets): + """Checks whether a subset of 'cand_set' exists in 'sets'.""" + for s in sets: + if s.issubset(cand_set): + return True + return False + + def _gen_tokens_to_flip(self, ntokens, max_flips): + for i in range(min(ntokens, max_flips)): + for s in itertools.combinations(range(ntokens), i+1): + yield s + def generate(self, example: JsonDict, model: lit_model.Model, dataset: lit_dataset.Dataset, - config: Optional[JsonDict] = None, - num_examples: int = 1) -> List[JsonDict]: - """Use gradient to find/substitute the token with largest impact on loss.""" - # TODO(lit-team): This function is quite long. Consider breaking it - # into small functions. + config: Optional[JsonDict] = None) -> List[JsonDict]: + """Identify minimal sets of token flips that alter the prediction.""" del dataset # Unused. + num_examples = int( + config[NUM_EXAMPLES_KEY]) if config else NUM_EXAMPLES_DEFAULT + max_flips = int( + config[MAX_FLIPS_KEY]) if config else MAX_FLIPS_DEFAULT + assert model is not None, "Please provide a model for this generator." logging.info(r"W3lc0m3 t0 H0tFl1p \o/") logging.info("Original example: %r", example) @@ -122,12 +151,9 @@ def generate(self, orig_probabilities = orig_output[pred_key] orig_prediction = np.argmax(orig_probabilities) - # Perform a flip in each sequence for which we have gradients (separately). - # Each sequence may give rise to multiple new examples, depending on how - # many words we flip. - # TODO(lit-team): make configurable how many new examples are desired. # TODO(lit-team): use only 1 sequence as input (configurable in UI). 
- new_examples = [] + successful_counterfactuals = [] + successful_positions = [] for grad_field in grad_fields: # Get the tokens and their gradient vectors. token_field = output_spec[grad_field].align # pytype: disable=attribute-error @@ -137,67 +163,71 @@ def generate(self, types.Tokens) assert len(token_emb_fields) == 1, "Found multiple token embeddings" token_embs = orig_output[token_emb_fields[0]] - - # Identify the token with the largest gradient attribution, - # defined as the dot product between the token embedding and gradient - # of the output wrt the embedding. assert token_embs.shape[0] == grads.shape[0] - token_grad_attrs = np.sum(token_embs * grads, axis=-1) - # Get a list of indices of input tokens, sorted by gradient attribution, - # highest first. We will flip tokens in this order. - sorted_by_grad_attrs = np.argsort(token_grad_attrs)[::-1] - - for i in range(min(num_examples, len(tokens))): - token_id = sorted_by_grad_attrs[i] - logging.info("Selected token: %s (pos=%d) with gradient attribution %f", - tokens[token_id], token_id, token_grad_attrs[token_id]) - token_grad = grads[token_id] - - # Take dot product with all word embeddings. Get smallest value. - # (We are look for a replacement token that will lower the score - # the current class, thereby increasing the chances of a label - # flip.) - # TODO(lit-team): Can add criteria to the winner e.g. cosine distance. 
- scores = np.dot(embed, token_grad) - winner = np.argmin(scores) - logging.info("Replacing [%s] (pos=%d) with option %d: [%s] (id=%d)", - tokens[token_id], token_id, i, inv_vocab[winner], winner) + + # We take a dot product of each input token gradient (grads) with the + # embedding table (embed) + # TODO(ataly): Only consider tokens that have the same part-of-speech + # tag as the original token (and a certain cosine similarity with the + # original token) + replacement_token_ids = np.argmin( + (np.expand_dims(embed, 1) @ grads.T).squeeze(1), axis=0) + + replacement_tokens = [inv_vocab[id] for id in replacement_token_ids] + logging.info("Replacement tokens: %s", replacement_tokens) + + # Consider all combinations of tokens up to length max_flips. + # We will iterate through this list (in topologically sorted order) + # and at each iteration, replace the selected tokens with corresponding + # replacement tokens and check if the prediction flips. + # TODO(ataly): Sort token sets of the same cardinality in decreasing + # order of gradient (i.e., we wish to prioritize flipping tokens that + # have the largest impact on the prediction.) + for token_positions in self._gen_tokens_to_flip(len(tokens), max_flips): + if len(successful_counterfactuals) >= num_examples: + return successful_counterfactuals + # If a subset of the set of tokens have already been successful in + # obtaining a flip, we continue. This ensure that we only consider + # sets of token flips that are minimal. + if self._subset_exists(set(token_positions), successful_positions): + continue + + logging.info("Selected tokens to flip: %s (positions=%s) with: %s", + [tokens[i] for i in token_positions], token_positions, + [replacement_tokens[i] for i in token_positions]) # Create a new input to the model. # TODO(iftenney, bastings): enforce somewhere that this field has the # same name in the input and output specs. 
input_token_field = token_field input_text_field = input_spec[input_token_field].parent # pytype: disable=attribute-error - new_example = copy.deepcopy(example) + counterfactual = copy.deepcopy(example) modified_tokens = copy.copy(tokens) - modified_tokens[token_id] = inv_vocab[winner] - new_example[input_token_field] = modified_tokens + for j in token_positions: + modified_tokens[j] = replacement_tokens[j] + counterfactual[input_token_field] = modified_tokens # TODO(iftenney, bastings): call a model-provided detokenizer here? # Though in general tokenization isn't invertible and it's possible for # HotFlip to produce wordpiece sequences that don't correspond to any # input string. - new_example[input_text_field] = " ".join(modified_tokens) + counterfactual[input_text_field] = " ".join(modified_tokens) # Predict a new label for this example. - new_output = list(model.predict([new_example]))[0] + counterfactual_output = list(model.predict([counterfactual]))[0] # Update label if multi-class prediction. # TODO(lit-dev): provide a general system for handling labels on # generated examples. - probabilities = new_output[pred_key] - new_prediction = np.argmax(probabilities) + probabilities = counterfactual_output[pred_key] + counterfactual_prediction = np.argmax(probabilities) label_key = cast(types.MulticlassPreds, output_spec[pred_key]).parent label_names = cast(types.MulticlassPreds, output_spec[pred_key]).vocab - new_label = label_names[new_prediction] - new_example[label_key] = new_label - logging.info("Updated example with new label: %s", new_label) + counterfactual_label = label_names[counterfactual_prediction] + counterfactual[label_key] = counterfactual_label + logging.info("Updated example with new label: %s", counterfactual_label) - if new_prediction != orig_prediction: + if counterfactual_prediction != orig_prediction: # Hotflip found - new_examples.append(new_example) - else: - # We make new_example as our base example and continue with more - # token flips. 
- example = new_example - tokens = modified_tokens - return new_examples + successful_counterfactuals.append(counterfactual) + successful_positions.append(set(token_positions)) + return successful_counterfactuals From be3eeaf854076ffa9a8158ebd107a5fb063c4c5c Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Mon, 12 Apr 2021 13:07:45 +0000 Subject: [PATCH 007/213] Remove a print statement --- lit_nlp/examples/models/stanza_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py index 881d3ee1..33ea832e 100644 --- a/lit_nlp/examples/models/stanza_models.py +++ b/lit_nlp/examples/models/stanza_models.py @@ -83,7 +83,6 @@ def _predict(self, ex): prediction[task].extend([word.to_dict()[task] for word in sentence.words]) # Process each span task - print(sentence.entities) for task in self.span_tasks: # Mention is currently the only span task if task == "mention": From dd7c2247c248ed7438fea902d1feef12e73ffc8a Mon Sep 17 00:00:00 2001 From: James Wexler Date: Mon, 12 Apr 2021 15:14:50 -0700 Subject: [PATCH 008/213] Allow specifying tokens to not flip in hotflip generator. 
PiperOrigin-RevId: 368094807 --- .../client/elements/interpreter_controls.ts | 32 +++++++++++++++++-- lit_nlp/components/hotflip.py | 27 ++++++++++------ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/lit_nlp/client/elements/interpreter_controls.ts b/lit_nlp/client/elements/interpreter_controls.ts index 9c67911e..99f827f3 100644 --- a/lit_nlp/client/elements/interpreter_controls.ts +++ b/lit_nlp/client/elements/interpreter_controls.ts @@ -37,7 +37,7 @@ export class InterpreterControls extends ReactiveElement { @observable @property({type: String}) name = ''; @observable @property({type: String}) description = ''; @observable @property({type: Boolean}) bordered = false; - @observable settings: {[name: string]: string|string[]} = {}; + @observable settings: {[name: string]: string|number|boolean|string[]} = {}; @property({type: Boolean, reflect: true}) opened = false; static get styles() { @@ -139,7 +139,8 @@ export class InterpreterControls extends ReactiveElement { item => item !== option); } }; - const isSelected = this.settings[name].indexOf(option) !== -1; + const isSelected = (this.settings[name] as string[]).indexOf( + option) !== -1; return html` @@ -193,9 +194,34 @@ export class InterpreterControls extends ReactiveElement {
${this.settings[name]}
`; } + else if (isLitSubtype(controlType, ['Boolean'])) { + // Render a checkbox. + const toggleVal = () => { + const val = !!this.settings[name]; + this.settings[name] = !val; + }; + // clang-format off + return html` + + + `; + // clang-format on + } + else if (isLitSubtype(controlType, ['Tokens'])) { + // Render a text input box and split on commas. + const value = this.settings[name] as string || ''; + const updateText = (e: Event) => { + const input = e.target! as HTMLInputElement; + this.settings[name] = input.value.split(',').map(val => val.trim()); + }; + return html``; + } else { // Render a text input box. - const value = this.settings[name] || ''; + const value = this.settings[name] as string || ''; const updateText = (e: Event) => { const input = e.target! as HTMLInputElement; this.settings[name] = input.value; diff --git a/lit_nlp/components/hotflip.py b/lit_nlp/components/hotflip.py index 2d77652b..117c0c42 100644 --- a/lit_nlp/components/hotflip.py +++ b/lit_nlp/components/hotflip.py @@ -44,6 +44,8 @@ NUM_EXAMPLES_DEFAULT = 5 MAX_FLIPS_KEY = "Maximum number of token flips" MAX_FLIPS_DEFAULT = 3 +TOKENS_TO_IGNORE_KEY = "Tokens to freeze" +TOKENS_TO_IGNORE_DEFAULT = [] class HotFlip(lit_components.Generator): @@ -79,6 +81,7 @@ def config_spec(self) -> types.Spec: return { NUM_EXAMPLES_KEY: types.TextSegment(default=str(NUM_EXAMPLES_DEFAULT)), MAX_FLIPS_KEY: types.TextSegment(default=str(MAX_FLIPS_DEFAULT)), + TOKENS_TO_IGNORE_KEY: types.Tokens(default=TOKENS_TO_IGNORE_DEFAULT) } def _subset_exists(self, cand_set, sets): @@ -88,9 +91,9 @@ def _subset_exists(self, cand_set, sets): return True return False - def _gen_tokens_to_flip(self, ntokens, max_flips): - for i in range(min(ntokens, max_flips)): - for s in itertools.combinations(range(ntokens), i+1): + def _gen_tokens_to_flip(self, token_idxs, max_flips): + for i in range(min(len(token_idxs), max_flips)): + for s in itertools.combinations(token_idxs, i+1): yield s def generate(self, @@ -105,6 +108,8 
@@ def generate(self, config[NUM_EXAMPLES_KEY]) if config else NUM_EXAMPLES_DEFAULT max_flips = int( config[MAX_FLIPS_KEY]) if config else MAX_FLIPS_DEFAULT + tokens_to_ignore = (config[TOKENS_TO_IGNORE_KEY] if config + else TOKENS_TO_IGNORE_DEFAULT) assert model is not None, "Please provide a model for this generator." logging.info(r"W3lc0m3 t0 H0tFl1p \o/") @@ -183,18 +188,22 @@ def generate(self, # TODO(ataly): Sort token sets of the same cardinality in decreasing # order of gradient (i.e., we wish to prioritize flipping tokens that # have the largest impact on the prediction.) - for token_positions in self._gen_tokens_to_flip(len(tokens), max_flips): + token_idxs_to_flip = [ + idx for idx in range(len(tokens)) + if tokens[idx] not in tokens_to_ignore] + for token_idxs in self._gen_tokens_to_flip( + token_idxs_to_flip, max_flips): if len(successful_counterfactuals) >= num_examples: return successful_counterfactuals # If a subset of the set of tokens have already been successful in # obtaining a flip, we continue. This ensure that we only consider # sets of token flips that are minimal. - if self._subset_exists(set(token_positions), successful_positions): + if self._subset_exists(set(token_idxs), successful_positions): continue logging.info("Selected tokens to flip: %s (positions=%s) with: %s", - [tokens[i] for i in token_positions], token_positions, - [replacement_tokens[i] for i in token_positions]) + [tokens[i] for i in token_idxs], token_idxs, + [replacement_tokens[i] for i in token_idxs]) # Create a new input to the model. 
# TODO(iftenney, bastings): enforce somewhere that this field has the @@ -203,7 +212,7 @@ def generate(self, input_text_field = input_spec[input_token_field].parent # pytype: disable=attribute-error counterfactual = copy.deepcopy(example) modified_tokens = copy.copy(tokens) - for j in token_positions: + for j in token_idxs: modified_tokens[j] = replacement_tokens[j] counterfactual[input_token_field] = modified_tokens # TODO(iftenney, bastings): call a model-provided detokenizer here? @@ -229,5 +238,5 @@ def generate(self, if counterfactual_prediction != orig_prediction: # Hotflip found successful_counterfactuals.append(counterfactual) - successful_positions.append(set(token_positions)) + successful_positions.append(set(token_idxs)) return successful_counterfactuals From 621eaaeb01ccc78349526aea598a9c04d7f7fb4e Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 12 Apr 2021 20:54:19 -0700 Subject: [PATCH 009/213] Silence type errors generated by new pytype features. PiperOrigin-RevId: 368140171 --- lit_nlp/examples/models/glue_models.py | 2 +- lit_nlp/examples/models/pretrained_lms.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lit_nlp/examples/models/glue_models.py b/lit_nlp/examples/models/glue_models.py index 378a851a..1103c2b3 100644 --- a/lit_nlp/examples/models/glue_models.py +++ b/lit_nlp/examples/models/glue_models.py @@ -213,7 +213,7 @@ def _postprocess(self, output: Dict[str, np.ndarray]): output["token_grad_" + self.config.text_b_name] = output["input_emb_grad"][slicer_b] if self.is_regression: - output["grad_class"] = None + output["grad_class"] = None # pytype: disable=container-type-mismatch else: # Return the label corresponding to the class index used for gradients. 
output["grad_class"] = self.config.labels[output["grad_class"]] diff --git a/lit_nlp/examples/models/pretrained_lms.py b/lit_nlp/examples/models/pretrained_lms.py index 784267ab..76b10396 100644 --- a/lit_nlp/examples/models/pretrained_lms.py +++ b/lit_nlp/examples/models/pretrained_lms.py @@ -111,7 +111,7 @@ def _postprocess(self, output: Dict[str, np.ndarray]): probas = output.pop("probas") # Predictions at every position, regardless of masking. - output["pred_tokens"] = self._get_topk_tokens(probas[slicer]) + output["pred_tokens"] = self._get_topk_tokens(probas[slicer]) # pytype: disable=container-type-mismatch return output From 03b272d218cfa8608d9a5b934296dd861211b599 Mon Sep 17 00:00:00 2001 From: James Wexler Date: Tue, 13 Apr 2021 09:54:46 -0700 Subject: [PATCH 010/213] Make Integrated Gradients options configurable in the UI. By default, explains argmax class but can be set to any class index. Also can set normalization and steps. Also updated LIME to by default explain the argmax class (instead of class 1). 
PiperOrigin-RevId: 368236153 --- lit_nlp/api/types.py | 2 +- lit_nlp/components/gradient_maps.py | 52 ++++++++++++++++++++-------- lit_nlp/components/lime_explainer.py | 43 +++++++++++++---------- 3 files changed, 64 insertions(+), 33 deletions(-) diff --git a/lit_nlp/api/types.py b/lit_nlp/api/types.py index 64315ec9..15a23d3f 100644 --- a/lit_nlp/api/types.py +++ b/lit_nlp/api/types.py @@ -350,4 +350,4 @@ class SalienceMap(LitType): @attr.s(auto_attribs=True, frozen=True, kw_only=True) class Boolean(LitType): """Boolean value.""" - pass + default: bool = False diff --git a/lit_nlp/components/gradient_maps.py b/lit_nlp/components/gradient_maps.py index c6df88d2..471f2581 100644 --- a/lit_nlp/components/gradient_maps.py +++ b/lit_nlp/components/gradient_maps.py @@ -31,6 +31,10 @@ JsonDict = types.JsonDict Spec = types.Spec +CLASS_KEY = 'Class index to explain' +NORMALIZATION_KEY = 'Normalize' +INTERPOLATION_KEY = 'Interpolation steps' + class GradientNorm(lit_components.Interpreter): """Salience map from gradient L2 norm.""" @@ -201,10 +205,6 @@ class IntegratedGradients(lit_components.Interpreter): label for all integral steps, since the argmax prediction may change. """ - def __init__(self, interpolation_steps=30): - # TODO(b/168042999): Make this parameter configurable in the UI. 
- self.interpolation_steps = interpolation_steps - def find_fields(self, input_spec: Spec, output_spec: Spec) -> List[Text]: # Find TokenGradients fields grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients) @@ -275,7 +275,9 @@ def get_baseline(self, embeddings: np.ndarray) -> np.ndarray: return baseline def get_salience_result(self, model_input: JsonDict, model: lit_model.Model, - model_output: JsonDict, grad_fields: List[Text]): + interpolation_steps: int, normalize: bool, + class_to_explain: int, model_output: JsonDict, + grad_fields: List[Text]): result = {} output_spec = model.output_spec() @@ -287,11 +289,15 @@ def get_salience_result(self, model_input: JsonDict, model: lit_model.Model, # The gradient class input is used to specify the target class of the # gradient calculation (if unspecified, this option defaults to the argmax, - # which could flip between interpolated inputs). + # which could flip between interpolated inputs). If class_to_explain is -1, + # then explain the argmax class. grad_class_key = cast(types.TokenGradients, output_spec[grad_fields[0]]).grad_target - # TODO(b/168042999): Add option to specify the class to explain in the UI. - grad_class = model_output[grad_class_key] + if class_to_explain == -1: + grad_class = model_output[grad_class_key] + else: + grad_class = cast(types.CategoryLabel, + output_spec[grad_class_key]).vocab[class_to_explain] interpolated_inputs = {} all_embeddings = [] @@ -308,11 +314,11 @@ def get_salience_result(self, model_input: JsonDict, model: lit_model.Model, # Get interpolated inputs from baseline to original embedding. # [interpolation_steps, num_tokens, emb_size] interpolated_inputs[embed_field] = self.get_interpolated_inputs( - baseline, embeddings, self.interpolation_steps) + baseline, embeddings, interpolation_steps) # Create model inputs and populate embedding field(s). 
inputs_with_embeds = [] - for i in range(self.interpolation_steps): + for i in range(interpolation_steps): input_copy = model_input.copy() # Interpolates embeddings for all inputs simultaneously. for embed_field in embeddings_fields: @@ -349,9 +355,9 @@ def get_salience_result(self, model_input: JsonDict, model: lit_model.Model, # [total_num_tokens] attributions = np.sum(integrated_gradients, axis=-1) - # TODO(b/168042999): Make normalization customizable in the UI. # [total_num_tokens] - scores = citrus_utils.normalize_scores(attributions) + scores = citrus_utils.normalize_scores( + attributions) if normalize else attributions for grad_field in grad_fields: # Format as salience map result. @@ -375,6 +381,13 @@ def run(self, model_outputs: Optional[List[JsonDict]] = None, config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]: """Run this component, given a model and input(s).""" + class_to_explain = int(config[CLASS_KEY] if config else + self.config_spec()[CLASS_KEY].default) + interpolation_steps = int(config[INTERPOLATION_KEY] if config else + self.config_spec()[INTERPOLATION_KEY].default) + normalization = (config[NORMALIZATION_KEY] if config + else self.config_spec()[NORMALIZATION_KEY].default) + # Find gradient fields to interpret input_spec = model.input_spec() output_spec = model.output_spec() @@ -389,8 +402,9 @@ def run(self, all_results = [] for model_output, model_input in zip(model_outputs, inputs): - result = self.get_salience_result(model_input, model, model_output, - grad_fields) + result = self.get_salience_result(model_input, model, interpolation_steps, + normalization, class_to_explain, + model_output, grad_fields) all_results.append(result) return all_results @@ -399,5 +413,15 @@ def is_compatible(self, model: lit_model.Model): model.input_spec(), model.output_spec()) return len(compatible_fields) + def config_spec(self) -> types.Spec: + return { + # TODO(lit-dev): Consider making class to predict strings using + # dropdowns on the 
front-end as opposed to class indicies. + CLASS_KEY: types.TextSegment(default='-1'), + NORMALIZATION_KEY: types.Boolean(default=True), + INTERPOLATION_KEY: types.Scalar( + min_val=5, max_val=100, default=30, step=1) + } + def meta_spec(self) -> types.Spec: return {'saliency': types.SalienceMap(autorun=False, signed=True)} diff --git a/lit_nlp/components/lime_explainer.py b/lit_nlp/components/lime_explainer.py index df636ada..a79cda55 100644 --- a/lit_nlp/components/lime_explainer.py +++ b/lit_nlp/components/lime_explainer.py @@ -39,11 +39,6 @@ MASK_KEY = 'Mask' NUM_SAMPLES_KEY = 'Number of samples' SEED_KEY = 'Seed' -CLASS_DEFAULT = 1 -KERNEL_WIDTH_DEFAULT = 256 -MASK_DEFAULT = '[MASK]' -NUM_SAMPLES_DEFAULT = 256 -SEED_DEFAULT = None def new_example(original_example: JsonDict, field: str, new_value: Any): @@ -85,13 +80,17 @@ def run( ) -> Optional[List[JsonDict]]: """Run this component, given a model and input(s).""" - class_to_explain = int(config[CLASS_KEY]) if config else CLASS_DEFAULT - kernel_width = int( - config[KERNEL_WIDTH_KEY]) if config else KERNEL_WIDTH_DEFAULT - mask_string = config[MASK_KEY] if config else MASK_DEFAULT - num_samples = int( - config[NUM_SAMPLES_KEY]) if config else NUM_SAMPLES_DEFAULT - seed = config[SEED_KEY] if config else SEED_DEFAULT + class_to_explain = int(config[CLASS_KEY] if config else + self.config_spec()[CLASS_KEY].default) + kernel_width = int(config[KERNEL_WIDTH_KEY] if config else + self.config_spec()[KERNEL_WIDTH_KEY].default) + num_samples = int(config[NUM_SAMPLES_KEY] if config else + self.config_spec()[NUM_SAMPLES_KEY].default) + mask_string = (config[MASK_KEY] if config + else self.config_spec()[MASK_KEY].default) + seed_str = (config[SEED_KEY] if config + else self.config_spec()[SEED_KEY].default) + seed = int(seed_str) if seed_str else None # Find keys of input (text) segments to explain. 
# Search in the input spec, since it's only useful to look at ones that are @@ -119,6 +118,14 @@ def run( predict_fn = functools.partial( _predict_fn, model=model, original_example=input_, pred_key=pred_key) + # If class_to_explain is -1, then explain the argmax class + if (isinstance(model.output_spec()[pred_key], types.MulticlassPreds) and + class_to_explain == -1): + pred = list(model.predict([input_]))[0] + class_to_explain_for_input = np.argmax(pred[pred_key]) + else: + class_to_explain_for_input = class_to_explain + # Explain each text segment in the input, keeping the others constant. for text_key in text_keys: input_string = input_[text_key] @@ -132,7 +139,7 @@ def run( sentence=input_string, predict_fn=functools.partial(predict_fn, text_key=text_key), # `class_to_explain` is ignored when predict_fn output is a scalar. - class_to_explain=class_to_explain, # Index of the class to explain. + class_to_explain=class_to_explain_for_input, num_samples=num_samples, tokenizer=str.split, mask_token=mask_string, @@ -152,11 +159,11 @@ def run( def config_spec(self) -> types.Spec: return { - CLASS_KEY: types.TextSegment(default=str(CLASS_DEFAULT)), - KERNEL_WIDTH_KEY: types.TextSegment(default=str(KERNEL_WIDTH_DEFAULT)), - MASK_KEY: types.TextSegment(default=MASK_DEFAULT), - NUM_SAMPLES_KEY: types.TextSegment(default=str(NUM_SAMPLES_DEFAULT)), - SEED_KEY: types.TextSegment(default=SEED_DEFAULT), + CLASS_KEY: types.TextSegment(default='-1'), + KERNEL_WIDTH_KEY: types.TextSegment(default='256'), + MASK_KEY: types.TextSegment(default='[MASK]'), + NUM_SAMPLES_KEY: types.TextSegment(default='256'), + SEED_KEY: types.TextSegment(default=''), } def is_compatible(self, model: lit_model.Model): From 096ff982eea0283f63bc69601ab388272e89335a Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Tue, 13 Apr 2021 22:07:44 +0000 Subject: [PATCH 011/213] More documentation and a TODO for the UD dataset --- lit_nlp/examples/models/stanza_models.py | 38 +++++++++++++++++++----- 
lit_nlp/examples/stanza_demo.py | 1 + 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py index 33ea832e..a776e480 100644 --- a/lit_nlp/examples/models/stanza_models.py +++ b/lit_nlp/examples/models/stanza_models.py @@ -74,8 +74,8 @@ def _predict(self, ex): doc = self.model(ex["sentence"]) prediction = {task: [] for task in self._output_spec} for sentence in doc.sentences: - # Get offset value to align task to tokens for multiple sentences - offset = len(prediction['tokens']) + # Get starting token of the offset to align task to tokens for multiple sentences + start_token = len(prediction['tokens']) prediction["tokens"].extend([word.text for word in sentence.words]) # Process each sequence task @@ -89,7 +89,7 @@ def _predict(self, ex): for entity in sentence.entities: # Stanza indexes start/end of entities on char. LIT needs them as token indexes start, end = entity_char_to_token(entity, sentence) - span_label = SpanLabel(start=start+offset, end=end+offset, label=entity.type) + span_label = SpanLabel(start=start+start_token, end=end+start_token, label=entity.type) prediction[task].append(span_label) else: raise ValueError(f"Invalid span task: '{task}'") @@ -100,8 +100,8 @@ def _predict(self, ex): if task == "deps": for relation in sentence.dependencies: label = relation[1] - span1 = relation[2].id + offset - span2 = relation[2].id + offset if label == "root" else relation[0].id + offset + span1 = relation[2].id + start_token + span2 = relation[2].id + start_token if label == "root" else relation[0].id + start_token # Relation lists have a root value at index 0, so subtract 1 to align them to tokens edge_label = EdgeLabel( (span1 - 1, span1), (span2 - 1, span2), label @@ -125,15 +125,37 @@ def output_spec(self): def entity_char_to_token(entity, sentence): """Takes Stanza entity and sentence objects and returns the start and end tokens for the entity + The misc value in a 
stanza sentence object contains a string with additional + information, separated by a pipe character. This string contains the + start_char and end_char for each token, along with other information. This is + extracted and used to match the start_char and end char values in a span + object to return the start and end tokens for the entity. + + Example entity: + {'text': 'Barrack Obama', + 'type': 'PERSON', + 'start_char': 0, + 'end_char': 13} + Example sentence: + [ + {'id': 1, + 'text': 'Barrack', + ..., + 'misc': 'start_char=0|end_char=7'}, + {'id': 2, + 'text': 'Obama', + ..., + 'misc': 'start_char=8|end_char=13'} + ] + Args: - entity: Stanza entity object - sentence: Stanza sentence object + entity: Stanza Span object + sentence: Stanza Sentence object Returns: Returns the token index of start and end locations for the entity """ start_token, end_token = None, None for i, v in enumerate(sentence.words): - # Misc is a string of values, separated by |, that contains start and end chars x = v.misc.split("|") if "start_char=" + str(entity.start_char) in x: start_token = i diff --git a/lit_nlp/examples/stanza_demo.py b/lit_nlp/examples/stanza_demo.py index 1eec82a9..c7dfc228 100644 --- a/lit_nlp/examples/stanza_demo.py +++ b/lit_nlp/examples/stanza_demo.py @@ -76,6 +76,7 @@ def main(_): } # Datasets for LIT demo + # TODO: Use the UD dataset (https://huggingface.co/datasets/universal_dependencies) datasets = { "SST2": glue.SST2Data(split="validation").slice[: FLAGS.max_examples], "blank": lit_dataset.Dataset({"text": lit_types.TextSegment()}, []), From 1253494250fe9b46c72dfccef7231215a81e9846 Mon Sep 17 00:00:00 2001 From: Tolga Bolukbasi Date: Wed, 14 Apr 2021 11:53:59 -0700 Subject: [PATCH 012/213] Allow sharing multiple examples from URL when launching LIT. The sync is one way. LIT still only syncs primary selection back to URL. 
PiperOrigin-RevId: 368476811 --- lit_nlp/client/services/url_service.ts | 62 +++++++++++++++++--------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/lit_nlp/client/services/url_service.ts b/lit_nlp/client/services/url_service.ts index c9d27d71..4a3ec946 100644 --- a/lit_nlp/client/services/url_service.ts +++ b/lit_nlp/client/services/url_service.ts @@ -33,8 +33,10 @@ export class UrlConfiguration { /** * For datapoints that are not in the original dataset, the fields * and their values are added directly into the url. + * LIT can load multiple examples from the url, but can only share + * primary selected example. */ - dataFields: Input = {}; + dataFields: {[key: number]: Input} = {}; selectedDataset?: string; hiddenModules: string[] = []; compareExamplesEnabled?: boolean; @@ -94,8 +96,13 @@ const NEW_DATASET_PATH = 'new_dataset_path'; const MAX_IDS_IN_URL_SELECTION = 100; const makeDataFieldKey = (key: string) => `${DATA_FIELDS_KEY_SUBSTRING}_${key}`; -const parseDataFieldKey = (key: string) => - key.replace(`${DATA_FIELDS_KEY_SUBSTRING}_`, ''); +const parseDataFieldKey = (key: string) => { + // Split string into two from the first underscore, + // data{index}_{feature}={val} -> [data{index}, {feature}={val}] + const pieces = key.split(/_([^]*)/, 2); + const indexStr = pieces[0].replace(DATA_FIELDS_KEY_SUBSTRING, ''); + return {fieldKey: pieces[1], dataIndex: +(indexStr || '0')}; +}; /** * Singleton service responsible for deserializing / serializing state to / from @@ -104,6 +111,9 @@ const parseDataFieldKey = (key: string) => export class UrlService extends LitService { /** Parse arrays in a url param, filtering out empty strings */ private urlParseArray(encoded: string) { + if (encoded == null) { + return []; + } const array = encoded.split(','); return array.filter(str => str !== ''); } @@ -134,7 +144,7 @@ export class UrlService extends LitService { const urlConfiguration = new UrlConfiguration(); const urlSearchParams = new 
URLSearchParams(window.location.search); - urlSearchParams.forEach((value: string, key: string) => { + for (const [key, value] of urlSearchParams) { if (key === SELECTED_MODELS_KEY) { urlConfiguration.selectedModels = this.urlParseArray(value); } else if (key === SELECTED_DATA_KEY) { @@ -153,16 +163,25 @@ export class UrlService extends LitService { urlConfiguration.layoutName = this.urlParseString(value); } else if (key === NEW_DATASET_PATH) { urlConfiguration.newDatasetPath = this.urlParseString(value); - } else if (key.includes(DATA_FIELDS_KEY_SUBSTRING)) { - const fieldKey = parseDataFieldKey(key); + } else if (key.startsWith(DATA_FIELDS_KEY_SUBSTRING)) { + const {fieldKey, dataIndex}: {fieldKey: string, dataIndex: number} = + parseDataFieldKey(key); // TODO(b/179788207) Defer parsing of data keys here as we do not have // access to the input spec of the dataset at the time // this is called. We convert array fields to their proper forms in // syncSelectedDatapointToUrl. - urlConfiguration.dataFields[fieldKey] = value; + if (!(dataIndex in urlConfiguration.dataFields)) { + urlConfiguration.dataFields[dataIndex] = {}; + } + // Warn if an example is overwritten, this only happens if url is + // malformed to contain the same index more than once. + if (fieldKey in urlConfiguration.dataFields[dataIndex]) { + console.log( + `Warning, data index ${dataIndex} is set more than once.`); + } + urlConfiguration.dataFields[dataIndex][fieldKey] = value; } - }); - + } return urlConfiguration; } @@ -263,14 +282,12 @@ export class UrlService extends LitService { selectionService: SelectionObservedByUrlService, ) { const urlConfiguration = appState.getUrlConfiguration(); - const fields = urlConfiguration.dataFields; - // Create a new dict and do not modify the urlConfiguration. This makes sure - // that this call works even if initialize app is called multiple times. 
- const outputFields: Input = {}; - - // If there are data fields set in the url, make a new datapoint - // from them. - if (Object.keys(fields).length) { + const dataFields = urlConfiguration.dataFields; + const dataToAdd = Object.values(dataFields).map((fields: Input) => { + // Create a new dict and do not modify the urlConfiguration. This makes + // sure that this call works even if initialize app is called multiple + // times. + const outputFields: Input = {}; const spec = appState.currentDatasetSpec; Object.keys(spec).forEach(key => { outputFields[key] = this.parseDataFieldValue(key, fields[key], spec); @@ -280,11 +297,16 @@ export class UrlService extends LitService { id: '', // will be overwritten meta: {source: 'url', added: true}, }; - const data = await appState.indexDatapoints([datum]); + return datum; + }); + // If there are data fields set in the url, make new datapoints + // from them and select all passed data points. + // TODO(b/185155960) Allow specifying selection for passed examples in url. + if (dataToAdd.length > 0) { + const data = await appState.indexDatapoints(dataToAdd); appState.commitNewDatapoints(data); - selectionService.setPrimarySelection(data[0].id); + selectionService.selectIds(data.map((d) => d.id)); } - // Otherwise, use the primary selected datapoint url param directly. 
else { const id = urlConfiguration.primarySelectedData; From e7fed06c25f9bd3eda588fb753e66df7985c04c4 Mon Sep 17 00:00:00 2001 From: Ankur Taly Date: Thu, 15 Apr 2021 10:15:30 -0700 Subject: [PATCH 013/213] Hotflip: Sort tokens by gradient before generating subsets PiperOrigin-RevId: 368664474 --- lit_nlp/components/hotflip.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/lit_nlp/components/hotflip.py b/lit_nlp/components/hotflip.py index 117c0c42..fe1bb964 100644 --- a/lit_nlp/components/hotflip.py +++ b/lit_nlp/components/hotflip.py @@ -46,6 +46,7 @@ MAX_FLIPS_DEFAULT = 3 TOKENS_TO_IGNORE_KEY = "Tokens to freeze" TOKENS_TO_IGNORE_DEFAULT = [] +MAX_FLIPPABLE_TOKENS = 10 class HotFlip(lit_components.Generator): @@ -104,12 +105,11 @@ def generate(self, """Identify minimal sets of token flips that alter the prediction.""" del dataset # Unused. - num_examples = int( - config[NUM_EXAMPLES_KEY]) if config else NUM_EXAMPLES_DEFAULT - max_flips = int( - config[MAX_FLIPS_KEY]) if config else MAX_FLIPS_DEFAULT - tokens_to_ignore = (config[TOKENS_TO_IGNORE_KEY] if config - else TOKENS_TO_IGNORE_DEFAULT) + config = config or {} + num_examples = int(config.get(NUM_EXAMPLES_KEY, NUM_EXAMPLES_DEFAULT)) + max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT)) + tokens_to_ignore = config.get(TOKENS_TO_IGNORE_KEY, + TOKENS_TO_IGNORE_DEFAULT) assert model is not None, "Please provide a model for this generator." logging.info(r"W3lc0m3 t0 H0tFl1p \o/") @@ -185,12 +185,20 @@ def generate(self, # We will iterate through this list (in toplogically sorted order) # and at each iteration, replace the selected tokens with corresponding # replacement tokens and checks if the prediction flips. - # TODO(ataly): Sort token sets of the same cardinality in decreasing - # order of gradient (i.e., we wish to prioritize flipping tokens that - # have the largest impact on the prediction.) 
- token_idxs_to_flip = [ - idx for idx in range(len(tokens)) - if tokens[idx] not in tokens_to_ignore] + # At each level of the topological sort, we will consider combinations + # by ordering tokens by gradient L2 (i.e., we wish to prioritize flipping + # tokens that may have the largest impact on the prediction.) + token_grads_l2 = np.sum(grads * grads, axis=-1) + # TODO(ataly, bastings): Consider sorting by attributions (either + # Integrated Gradients or Shapley values). + token_idxs_sorted_by_grads = np.argsort(token_grads_l2)[::-1] + token_idxs_to_flip = [idx for idx in token_idxs_sorted_by_grads + if tokens[idx] not in tokens_to_ignore] + + # If the number of tokens considered for flipping is larger than + # MAX_FLIPPABLE_TOKENS we only consider the top tokens. + token_idxs_to_flip = token_idxs_to_flip[:MAX_FLIPPABLE_TOKENS] + for token_idxs in self._gen_tokens_to_flip( token_idxs_to_flip, max_flips): if len(successful_counterfactuals) >= num_examples: From 6dfb8c062e467cbeee62551723c6dc180ab4e4f2 Mon Sep 17 00:00:00 2001 From: James Wexler Date: Fri, 16 Apr 2021 08:11:01 -0700 Subject: [PATCH 014/213] Rewrite of widget group layout/width logic. No longer uses flexbox to avoid a variety of issues with resizing. Modules.ts now calculates sets the width of each widget_group explicitly, as opposed to the groups calculating their own widths. This allows for smoother/faster width changes and more predictable behavior for the user when resizing widget groups. Removed transition animations of module groups to improve performance and visual smoothless as well. Also shrank width of classification module bars since that module took up unnecessary width. 
PiperOrigin-RevId: 368848905 --- lit_nlp/client/core/modules.ts | 130 ++++++++++++++++-- lit_nlp/client/core/widget_group.css | 6 - lit_nlp/client/core/widget_group.ts | 78 ++++------- .../client/modules/classification_module.ts | 2 +- 4 files changed, 147 insertions(+), 69 deletions(-) diff --git a/lit_nlp/client/core/modules.ts b/lit_nlp/client/core/modules.ts index 9172942b..dfc2acc5 100644 --- a/lit_nlp/client/core/modules.ts +++ b/lit_nlp/client/core/modules.ts @@ -18,20 +18,33 @@ /** * Client-side (UI) code for the LIT tool. */ - -import {customElement, html, LitElement, property} from 'lit-element'; +// tslint:disable:no-new-decorators +import {customElement, html, property} from 'lit-element'; +import {observable} from 'mobx'; import {classMap} from 'lit-html/directives/class-map'; import {styleMap} from 'lit-html/directives/style-map'; import '@material/mwc-icon'; +import {ReactiveElement} from '../lib/elements'; import {LitRenderConfig, RenderConfig} from '../services/modules_service'; import {ModulesService} from '../services/services'; import {app} from './lit_app'; import {LitModule} from './lit_module'; import {styles} from './modules.css'; -import {LitWidget} from './widget_group'; +import {LitWidget, MIN_GROUP_WIDTH_PX} from './widget_group'; + +// Number of columns in the full width of the layout. +const NUM_COLS = 12; +// Width of a minimized widget group. From widget_group.css. +const MINIMIZED_WIDTH_PX = 36; + +// Contains for each section (main section, or a tab), a mapping of widget +// groups to their calculated widths. +interface LayoutWidths { + [layoutSection: string]: number[]; +} /** * The component responsible for rendering the selected and available lit @@ -39,10 +52,12 @@ import {LitWidget} from './widget_group'; * to explicitly control when it rerenders (via the setRenderModulesCallback). 
*/ @customElement('lit-modules') -export class LitModules extends LitElement { +export class LitModules extends ReactiveElement { private readonly modulesService = app.getService(ModulesService); @property({type: Number}) mainSectionHeight = this.modulesService.getSetting('mainHeight') || 45; + @observable layoutWidths: LayoutWidths = {}; + private resizeObserver!: ResizeObserver; static get styles() { return styles; @@ -56,6 +71,62 @@ export class LitModules extends LitElement { this.modulesService.setRenderModulesCallback(() => { this.requestUpdate(); }); + + this.resizeObserver = new ResizeObserver(() => { + this.calculateWidths(this.modulesService.getRenderLayout()); + }); + this.resizeObserver.observe(this); + + this.reactImmediately( + () => this.modulesService.getRenderLayout(), renderLayout => { + this.calculateWidths(renderLayout); + }); + } + + // Calculate widths of all module groups in all panels. + calculateWidths(renderLayout: LitRenderConfig) { + const panelNames = Object.keys(renderLayout); + for (const panelName of panelNames) { + this.layoutWidths[panelName] = []; + this.calculatePanelWidths(panelName, renderLayout[panelName]); + } + } + + // Calculate widths of all module groups in a single panel. + calculatePanelWidths(panelName: string, panelConfig: RenderConfig[][]) { + // Get the number of minimized widget groups to calculate the total width + // available for non-minimized widgets. + let numMinimized = 0; + for (const configGroup of panelConfig) { + if (this.modulesService.isModuleGroupHidden(configGroup[0])) { + numMinimized +=1; + } + } + const widthAvailable = + window.innerWidth - MINIMIZED_WIDTH_PX * numMinimized; + + // Get the total number of columns requested for the non-minimized widget + // groups. 
+ let totalCols = 0; + for (const configGroup of panelConfig) { + if (this.modulesService.isModuleGroupHidden(configGroup[0])) { + continue; + } + const numColsList = configGroup.map(config => config.moduleType.numCols); + totalCols += Math.max(...numColsList); + } + // Ensure that when a panel requests less than the full width of columns + // that the widget groups still use up the entire width available. + const totalColsToUse = Math.min(totalCols, NUM_COLS); + + // Set the width for each widget group based on the maximum number of + // columns it's widgets have specified and the width available. + for (let i = 0; i < panelConfig.length; i++) { + const configGroup = panelConfig[i]; + const numColsList = configGroup.map(config => config.moduleType.numCols); + const width = Math.max(...numColsList) / totalColsToUse * widthAvailable; + this.layoutWidths[panelName][i] = width; + } } disconnectedCallback() { @@ -109,7 +180,7 @@ export class LitModules extends LitElement { // clang-format off return html`
- ${this.renderMainPanel(mainPanelConfig)} + ${this.renderWidgetGroups(mainPanelConfig, 'Main')}
@@ -146,15 +217,11 @@ export class LitModules extends LitElement { tabToSelect: string) { return compGroupNames.map((compGroupName) => { const configs = layout[compGroupName]; - const componentsHTML = - configs.map(configGroup => - html` - `); const selected = tabToSelect === compGroupName; const classes = classMap({selected, 'components-group-holder': true}); return html`
- ${componentsHTML} + ${this.renderWidgetGroups(configs, compGroupName)}
`; }); } @@ -182,8 +249,47 @@ export class LitModules extends LitElement { }); } - renderMainPanel(configs: RenderConfig[][]) { - return configs.map(configGroup =>html``); + renderWidgetGroups(configs: RenderConfig[][], section: string) { + // Calllback for widget isMinimized state changes. + const onMin = (event: Event) => { + // Recalculate the widget group widths in this section. + this.calculatePanelWidths(section, configs); + }; + + return configs.map((configGroup, i) => { + + // Callback from widget width drag events. + const onDrag = (event: Event) => { + // tslint:disable-next-line:no-any + const dragWidth = (event as any).detail.dragWidth; + + // If the dragged group isn't the right-most group, then balance the + // delta in width with the widget directly to it's left (so if a widget + // is expanded, then its adjacent widget is shrunk by the same amount). + if (i < configs.length - 1) { + const adjacentConfig = configs[i + 1]; + if (!this.modulesService.isModuleGroupHidden(adjacentConfig[0])) { + const widthChange = dragWidth - + this.layoutWidths[section][i]; + const oldAdjacentWidth = + this.layoutWidths[section][i + 1]; + this.layoutWidths[section][i + 1] = + Math.max(MIN_GROUP_WIDTH_PX, oldAdjacentWidth - widthChange); + } + } + + // Set the width of the dragged widget group. + this.layoutWidths[section][i] = dragWidth; + + this.requestUpdate(); + }; + + const width = this.layoutWidths[section] ? 
+ this.layoutWidths[section][i] : 0; + return html``; + }); } } diff --git a/lit_nlp/client/core/widget_group.css b/lit_nlp/client/core/widget_group.css index 926f85b7..a182682f 100644 --- a/lit_nlp/client/core/widget_group.css +++ b/lit_nlp/client/core/widget_group.css @@ -2,7 +2,6 @@ flex: var(--flex); min-width: var(--width); width: var(--width); - transition: all .25s; } @@ -13,7 +12,6 @@ } :host([maximized]) { - transition: none; margin: 15px 45px; padding: 0; width: calc(100vw - 90px) !important; @@ -25,10 +23,6 @@ box-shadow: rgba(0, 0, 0, 0.14) 0px 2px 2px 0px, rgba(0, 0, 0, 0.2) 0px 3px 1px -2px, rgba(0, 0, 0, 0.12) 0px 1px 5px 0px; } -:host([dragging]) { - transition: none; -} - .wrapper { padding: 4pt; height: 100%; diff --git a/lit_nlp/client/core/widget_group.ts b/lit_nlp/client/core/widget_group.ts index b2d7e0ed..73ebff5e 100644 --- a/lit_nlp/client/core/widget_group.ts +++ b/lit_nlp/client/core/widget_group.ts @@ -34,7 +34,13 @@ import {LitModule} from './lit_module'; import {styles as widgetStyles} from './widget.css'; import {styles as widgetGroupStyles} from './widget_group.css'; -const NUM_COLS = 12; +/** Minimum width for a widget group. */ +export const MIN_GROUP_WIDTH_PX = 100; + +// Width changes below this delta aren't bubbled up, to avoid unnecssary width +// recalculations. +const MIN_GROUP_WIDTH_DELTA_PX = 10; + /** * Renders a group of widgets (one per model, and one per datapoint if * compareDatapoints is enabled) for a single component. @@ -46,7 +52,7 @@ export class WidgetGroup extends LitElement { @property({ type: Boolean, reflect: true }) minimized = false; @property({ type: Boolean, reflect: true }) maximized = false; @property({ type: Boolean, reflect: true }) dragging = false; - @property({ type: Number }) userSetNumCols = 0; + @property({ type: Number}) width = 0; private widgetScrollTop = 0; private widgetScrollLeft = 0; // Not set as @observable since re-renders were not occuring when changed. 
@@ -134,7 +140,11 @@ export class WidgetGroup extends LitElement { const modulesInGroup = configGroup.length > 1; const duplicateAsRow = configGroup[0].moduleType.duplicateAsRow; - this.setWidthValues(configGroup, duplicateAsRow); + // Set width properties based on provided width. + const host = this.shadowRoot!.host as HTMLElement; + const width = `${this.width}px`; + host.style.setProperty('--width', width); + host.style.setProperty('--min-width', width); const wrapperClasses = classMap({ 'wrapper': true, @@ -213,36 +223,22 @@ export class WidgetGroup extends LitElement { renderExpander() { const dragged = (e: DragEvent) => { - // The sizes of the divs is a bit complicated because we are both - // setting hardcoded widths, but also allowing flexbox to expand - // the modules to fill remaining space. So, to have drag-to-resize - // be consistent, we know what width the div should be (based on - // the user's mouse), then back-calculate what the actual set vw - // width set should be so that flexbox will expand the module to - // the desired width. const holder = this.shadowRoot!.querySelector('.holder')!; - - // Actual div width and left positions (set by flexbox rendering). const left = holder.getBoundingClientRect().left; - const width = holder.getBoundingClientRect().width; - const fullWidth = window.innerWidth; - - // Ratio of flex-set width to our calculated width. - const flexRatio = width/(this.userSetNumCols/NUM_COLS * fullWidth); - - // Updated number of columns from the drag. const dragWidth = e.clientX - left; - if (dragWidth > 0) { - const numCols = dragWidth / fullWidth * NUM_COLS; - // For perf reasons, only update in incriments of .1 columns. 
- this.userSetNumCols = +Math.max(numCols/flexRatio, 1).toFixed(1); + + if (dragWidth > MIN_GROUP_WIDTH_PX && + Math.abs(dragWidth - this.width) > MIN_GROUP_WIDTH_DELTA_PX) { + const event = new CustomEvent('widget-group-drag', { + detail: { + dragWidth, + } + }); + this.dispatchEvent(event); } }; const dragStarted = () => { - if (!this.userSetNumCols) { - this.userSetNumCols = this.configGroup[0].moduleType.numCols; - } this.dragging = true; }; @@ -259,30 +255,6 @@ export class WidgetGroup extends LitElement { } else { return html``; } - - } - - /** Returns styling with flex set based off of max columns of all configs. */ - setWidthValues(configs: RenderConfig[], duplicateAsRow: boolean) { - const numColsList = configs.map(config => config.moduleType.numCols); - // In row duplication, the flex should be the sum of the child flexes, and - // in column duplication, it should be the maximum of the child flexes. - let maxFlex = duplicateAsRow ? numColsList.reduce((a, b) => a + b, 0) : - Math.max(...numColsList); - - // If the user manually set the number of columns, just use that instead. 
- if (this.userSetNumCols) { - maxFlex = this.userSetNumCols; - } - const width = this.flexGrowToWidth(maxFlex); - const host = this.shadowRoot!.host as HTMLElement; - host.style.setProperty('--flex', maxFlex.toString()); - host.style.setProperty('--width', width); - host.style.setProperty('--min-width', width); - } - - private flexGrowToWidth(flexGrow: number) { - return (flexGrow / NUM_COLS * 100).toFixed(3).toString() + '%'; } private initMinimized() { @@ -294,6 +266,12 @@ export class WidgetGroup extends LitElement { const config = this.configGroup[0]; this.modulesService.toggleHiddenModule(config, isMinimized); this.minimized = isMinimized; + const event = new CustomEvent('widget-group-minimized-changed', { + detail: { + isMinimized + } + }); + this.dispatchEvent(event); } } diff --git a/lit_nlp/client/modules/classification_module.ts b/lit_nlp/client/modules/classification_module.ts index 52041140..4e6c3e4b 100644 --- a/lit_nlp/client/modules/classification_module.ts +++ b/lit_nlp/client/modules/classification_module.ts @@ -150,7 +150,7 @@ export class ClassificationModule extends LitModule { barStyle['margin-left'] = `${margin}%`; barStyle['margin-right'] = `${margin}%`; const holderStyle: {[name: string]: string} = {}; - holderStyle['width'] = '200px'; + holderStyle['width'] = '100px'; holderStyle['height'] = '20px'; holderStyle['display'] = 'flex'; holderStyle['position'] = 'relative'; From 489330fd526daed0d04988cbb3c5a5c524212cf0 Mon Sep 17 00:00:00 2001 From: James Wexler Date: Fri, 16 Apr 2021 09:08:00 -0700 Subject: [PATCH 015/213] Add plots for scalar input features. Also adds display of value of primary selected example, if there is one. 
PiperOrigin-RevId: 368858422 --- lit_nlp/client/modules/scalar_module.css | 10 +++++ lit_nlp/client/modules/scalar_module.ts | 56 ++++++++++++++++++------ lit_nlp/client/services/services.ts | 1 + lit_nlp/examples/coref/model.py | 8 ---- 4 files changed, 54 insertions(+), 21 deletions(-) diff --git a/lit_nlp/client/modules/scalar_module.css b/lit_nlp/client/modules/scalar_module.css index 761beeec..c644fd7c 100644 --- a/lit_nlp/client/modules/scalar_module.css +++ b/lit_nlp/client/modules/scalar_module.css @@ -67,3 +67,13 @@ padding-top: 5px; } +.axis-title { + display: flex; + justify-content: space-between; + align-items: center; + flex: 1; +} + +.selected-value { + padding-right: 8px; +} diff --git a/lit_nlp/client/modules/scalar_module.ts b/lit_nlp/client/modules/scalar_module.ts index 252cada0..4d810cb0 100644 --- a/lit_nlp/client/modules/scalar_module.ts +++ b/lit_nlp/client/modules/scalar_module.ts @@ -26,10 +26,10 @@ const seedrandom = require('seedrandom'); // from //third_party/javascript/typi import {app} from '../core/lit_app'; import {LitModule} from '../core/lit_module'; -import {D3Selection, IndexedInput, ModelInfoMap, NumericSetting, Preds, Spec} from '../lib/types'; +import {D3Selection, formatForDisplay, IndexedInput, ModelInfoMap, ModelSpec, NumericSetting, Preds, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys, getThresholdFromMargin, isLitSubtype} from '../lib/utils'; import {FocusData} from '../services/focus_service'; -import {ClassificationService, ColorService, FocusService, RegressionService} from '../services/services'; +import {ClassificationService, ColorService, GroupService, FocusService, RegressionService} from '../services/services'; import {styles} from './scalar_module.css'; import {styles as sharedStyles} from './shared_styles.css'; @@ -82,6 +82,7 @@ export class ScalarModule extends LitModule { private readonly colorService = app.getService(ColorService); private readonly classificationService = 
app.getService(ClassificationService); + private readonly groupService = app.getService(GroupService); private readonly regressionService = app.getService(RegressionService); private readonly focusService = app.getService(FocusService); @@ -98,6 +99,11 @@ export class ScalarModule extends LitModule { @observable private plotWidth = ScalarModule.maxPlotWidth; @observable private plotHeight = ScalarModule.minPlotHeight; + @computed + private get inputKeys() { + return this.groupService.numericalFeatureNames; + } + @computed private get scalarKeys() { const outputSpec = this.appState.currentModelSpecs[this.model].spec.output; @@ -341,6 +347,9 @@ export class ScalarModule extends LitModule { const pred = Object.assign( {}, classificationPreds[i], scalarPreds[i], regressionPreds[i], {id: currId}); + for (const inputKey of this.inputKeys) { + pred[inputKey] = currentInputData[i].data[inputKey]; + } preds.push(pred); } @@ -380,6 +389,8 @@ export class ScalarModule extends LitModule { scoreRange[0] = scoreRange[0] - .1; scoreRange[1] = scoreRange[1] + .1; } + } else if (this.inputKeys.indexOf(key) !== -1) { + scoreRange = this.groupService.numericalFeatureRanges[key]; } return d3.scaleLinear().domain(scoreRange).range([ @@ -408,6 +419,17 @@ export class ScalarModule extends LitModule { ]); } + private getValue(preds: Preds, spec: ModelSpec, key: string, label: string) { + // If for a multiclass prediction, return the top label score. + if (isLitSubtype(spec.output[key], 'MulticlassPreds')) { + const predictionLabels = spec.output[key].vocab!; + const index = predictionLabels.indexOf(label); + return preds[key][index]; + } + // Otherwise, return the raw value. + return preds[key]; + } + /** * Re-renders threshold bar at the new threshold value and updates datapoint * colors. 
@@ -422,7 +444,7 @@ export class ScalarModule extends LitModule { const scatterplot = item as SVGGElement; const key = (item as HTMLElement).dataset['key']; - if (key == null) { + if (key == null || this.inputKeys.indexOf(key) !== -1) { return; } @@ -721,15 +743,7 @@ export class ScalarModule extends LitModule { circles .attr( 'cx', - (d) => { - if (isLitSubtype(spec.output[key], 'MulticlassPreds')) { - const predictionLabels = spec.output[key].vocab!; - const index = predictionLabels.indexOf(label); - return xScale(d[key][index]); - } - // Otherwise, return the regression score. - return xScale(d[key]); - }) + (d) => xScale(this.getValue(d, spec, key, label))) .attr( 'cy', (d) => { @@ -798,6 +812,7 @@ export class ScalarModule extends LitModule { ${this.scalarKeys.map(key => this.renderPlot(key, ''))} ${this.classificationKeys.map(key => this.renderClassificationGroup(key))} + ${this.inputKeys.map(key => this.renderPlot(key, ''))}
`; // clang-format on @@ -830,6 +845,17 @@ export class ScalarModule extends LitModule { this.numPlotsRendered++; const axisTitle = label ? `${key}:${label}` : key; + let selectedValue = ''; + if (this.selectionService.primarySelectedId != null) { + const selectedIndex = this.appState.getIndexById( + this.selectionService.primarySelectedId); + if (selectedIndex != null && this.preds[selectedIndex] != null) { + const spec = this.appState.getModelSpec(this.model); + const displayVal = formatForDisplay( + this.getValue(this.preds[selectedIndex], spec, key, label)); + selectedValue = `Value: ${displayVal}`; + } + } // clang-format off const toggleCollapse = () => { const isHidden = (this.isPlotHidden.get(axisTitle) == null) ? @@ -845,7 +871,11 @@ export class ScalarModule extends LitModule { {'display': `${isHidden ? 'none': 'block'}`}); return html`
-
${axisTitle} +
+
+
${axisTitle}
+
${selectedValue}
+
${isHidden ? 'expand_more': 'expand_less'} diff --git a/lit_nlp/client/services/services.ts b/lit_nlp/client/services/services.ts index f29637da..3d6760cc 100644 --- a/lit_nlp/client/services/services.ts +++ b/lit_nlp/client/services/services.ts @@ -19,6 +19,7 @@ export {ApiService} from './api_service'; export {ClassificationService} from './classification_service'; export {ColorService} from './color_service'; export {FocusService} from './focus_service'; +export {GroupService} from './group_service'; export {ModulesService} from './modules_service'; export {RegressionService} from './regression_service'; export {SelectionService} from './selection_service'; diff --git a/lit_nlp/examples/coref/model.py b/lit_nlp/examples/coref/model.py index 159cc330..d4830161 100644 --- a/lit_nlp/examples/coref/model.py +++ b/lit_nlp/examples/coref/model.py @@ -93,10 +93,6 @@ def predict_minibatch(self, inputs: List[JsonDict]): span1=orig_edge.span1, span2=orig_edge.span2, label=ep['proba']) preds[edge['src_idx']]['coref'].append(new_edge) for ex, p in zip(inputs, preds): - # TODO(b/172975096): allow plotting of scalars from input data, - # so we don't need to add this to the predictions. - if 'pf_bls' in ex: - p['pf_bls'] = ex['pf_bls'] # Choose an answer if there are only two target edges. if len(p['coref']) == 2 and 'answer' in ex: probas = np.array([ep.label for ep in p['coref']]) @@ -136,8 +132,4 @@ def output_spec(self): 'pred_answer': lit_types.MulticlassPreds( vocab=winogender.ANSWER_VOCAB, parent='answer'), - # TODO(b/172975096): allow plotting of scalars from input data, - # so we don't need to add this to the predictions. - 'pf_bls': - lit_types.Scalar(), } From 926c76198e975d2df868c7ac433da7bad626c572 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Fri, 16 Apr 2021 13:25:33 -0700 Subject: [PATCH 016/213] Move table controls to data table module, since these were quite specific to that use case. 
This greatly simplifies table.ts, which no longer needs a toolbar or column visibility controls. - Remove columnVisibility map; just pass names as strings instead. - lit-data-table can accept data as records, and will use column names to extract fields. - Simplify rendering logic to compute header widths. - Remove dead codepaths and redundant state. PiperOrigin-RevId: 368908615 --- lit_nlp/client/elements/table.css | 61 ---- lit_nlp/client/elements/table.ts | 306 ++++++++---------- .../client/modules/classification_module.ts | 9 +- lit_nlp/client/modules/data_table_module.css | 39 +++ lit_nlp/client/modules/data_table_module.ts | 240 ++++++++++---- lit_nlp/client/modules/metrics_module.ts | 29 +- lit_nlp/client/modules/regression_module.ts | 23 +- 7 files changed, 378 insertions(+), 329 deletions(-) diff --git a/lit_nlp/client/elements/table.css b/lit_nlp/client/elements/table.css index e0148b9e..15a415d4 100644 --- a/lit_nlp/client/elements/table.css +++ b/lit_nlp/client/elements/table.css @@ -1,36 +1,11 @@ :host { --table-header-height: 40px; - --toolbar-height: 43px; -} - -.toolbar{ - font-family: 'Google Sans', sans-serif !important; - display: flex; - flex-direction: row; - justify-content: space-between; - height: var(--toolbar-height); -} - -#toolbar-buttons{ - margin-bottom: 6px; } #holder{ - display: flex; - flex-direction: column; height: 100%; } -.table-container { - flex: 1; - overflow: visible; - height: 100% -} - -.toolbar + .table-container { - height: calc(100% - var(--toolbar-height)); -} - #rows-container{ width: 100%; height: calc(100% - var(--table-header-height)); @@ -198,42 +173,6 @@ td div { overflow: auto; } -/* For in-line icons in a */ -[data-icon] { - margin: 0; -} - -[data-icon]:before { - font-family: 'Material Icons'; - content: attr(data-icon); - vertical-align: middle; -} - -/* Column settings */ -.column-button { - position: absolute; - right: 15px; - top: 5px; - width: 100px; -} - -.column-dropdown-hide { - display:none; -} - 
-.column-dropdown { - top: 50px; - right: 15px; - padding: 10px; - background: white; - border: 1px solid gray; - position: absolute; - visibility: visible; - z-index: 1000; - max-height: calc(100% - 80px); - overflow: auto; -} - /* TODO(lit-dev): Make the table image width configurable. */ .table-img{ width: 100px; diff --git a/lit_nlp/client/elements/table.ts b/lit_nlp/client/elements/table.ts index d475a782..28375981 100644 --- a/lit_nlp/client/elements/table.ts +++ b/lit_nlp/client/elements/table.ts @@ -28,7 +28,8 @@ import {ascending, descending} from 'd3'; // array helpers. import {customElement, html, property, TemplateResult} from 'lit-element'; import {classMap} from 'lit-html/directives/class-map'; import {styleMap} from 'lit-html/directives/style-map'; -import {computed, observable} from 'mobx'; +import {action, computed, observable} from 'mobx'; + import {ReactiveElement} from '../lib/elements'; import {chunkWords} from '../lib/utils'; @@ -38,7 +39,13 @@ import {styles} from './table.css'; export type TableEntry = string|number|TemplateResult; type SortableTableEntry = string|number; /** Wrapper types for the data supplied to the data table */ -export type TableData = TableEntry[]; +export type TableData = TableEntry[]|{[key: string]: TableEntry}; + +/** Internal data, including metadata */ +interface TableRowInternal { + inputIndex: number; /* index in original this.data */ + rowData: TableEntry[]; +} /** Callback for selection */ export type OnSelectCallback = (selectedIndices: number[]) => void; @@ -60,16 +67,29 @@ const IMAGE_PREFIX = 'data:image'; */ @customElement('lit-data-table') export class DataTable extends ReactiveElement { - @observable @property({type: Array}) data: TableData[] = []; - @observable @property({type: Array}) selectedIndices: number[] = []; + // observable.struct is necessary to avoid spurious updates + // if this object identity changes. 
This can happen if plain JS data is + // passed to this property, as it will subsequently be proxied by mobx. + // The structural comparison ensures that this proxying will not trigger + // updates if the underlying data does not change. + // TODO(lit-dev): investigate any performance implications of this deep + // comparison, as this may run more frequently than we'd like. + // TODO(lit-dev): consider observable.ref or observable.shallow; + // see https://mobx.js.org/observable-state.html#available-annotations. + // This could save performance, since calling code can always do [...data] + // to generate a new reference and force a refresh if needed. + @observable.struct @property({type: Array}) data: TableData[] = []; + @observable.struct @property({type: Array}) columnNames: string[] = []; + @observable.struct @property({type: Array}) selectedIndices: number[] = []; @observable @property({type: Number}) primarySelectedIndex: number = -1; @observable @property({type: Number}) referenceSelectedIndex: number = -1; + // TODO(lit-dev): consider a custom reaction to make this more responsive, + // instead of triggering a full re-render. 
@observable @property({type: Number}) focusedIndex: number = -1; - @observable @property({type: Boolean}) selectionDisabled: boolean = false; - @observable @property({type: Boolean}) controlsEnabled: boolean = false; - @observable - @property({type: Object}) - columnVisibility = new Map(); + + // Mode controls + @observable @property({type: Boolean}) selectionEnabled: boolean = false; + @observable @property({type: Boolean}) searchEnabled: boolean = false; // Callbacks @property({type: Object}) onClick: OnPrimarySelectCallback|undefined; @@ -87,17 +107,17 @@ export class DataTable extends ReactiveElement { @observable private showColumnMenu = false; @observable private columnMenuName = ''; @observable private readonly columnSearchQueries = new Map(); - @observable private filterSelected = false; - @observable columnDropdownVisible = false; @observable private headerWidths: number[] = []; - // Sorted data. We manage updates with a reaction to enable "sticky" behavior. - private stickySortedData?: TableData[]|null = null; + // Sorted data. We manage updates with a reaction to enable "sticky" behavior, + // where subsequent sorts are based on the last sort rather than the original + // inputs (i.e. this.data). This way, you can do useful compound sorts like in + // a typical spreadsheet program. + @observable private stickySortedData?: TableRowInternal[]|null = null; private resizeObserver!: ResizeObserver; private selectedIndicesSetForRender = new Set(); - private rowIndexToDataIndex = new Map(); private shiftSelectionStartIndex = 0; private shiftSelectionEndIndex = 0; @@ -111,12 +131,17 @@ export class DataTable extends ReactiveElement { }); this.resizeObserver.observe(container); - // Clear "sticky" sorted data if the inputs change. + // If inputs changed, re-sort data based on the new inputs. 
this.reactImmediately(() => this.rowFilteredData, filteredData => { this.stickySortedData = null; + this.requestUpdate(); }); - this.reactImmediately(() => this.columnVisibility, columnVisibility => { - this.computeHeaderWidths(); + // If sort settings are changed, re-sort data optionally using result of + // previous sort. + const triggerSort = () => [this.sortName, this.sortAscending]; + this.reactImmediately(triggerSort, () => { + this.stickySortedData = this.getSortedData(this.displayData); + this.requestUpdate(); }); } @@ -149,11 +174,6 @@ export class DataTable extends ReactiveElement { return colEntry instanceof TemplateResult ? 0 : colEntry; } - @computed - get columnNames(): string[] { - return Array.from(this.columnVisibility.keys()); - } - @computed get sortIndex(): number|undefined { return (this.sortName == null) ? undefined : @@ -161,22 +181,43 @@ export class DataTable extends ReactiveElement { } /** - * This computed returns the data filtered by row (filtering by column - * happens in render()). + * First pass processing input data to canonical form. + * The data goes through several stages before rendering: + * - this.data is the input data (bound via element property) + * - indexedData converts to parallel-list form and adds input indices + * for later reference + * - rowFilteredData filters to a subset of rows, based on search criteria + * - getSortedData() is called in a reaction to sort this if sort-by-column + * is used. The result is stored in this.stickySortedData so that future + * sorts can be "stable" rather than re-sorting from the input. + * - displayData is stickySortedData, or rowFilteredData if that is unset. + * This is used to actually render the table. */ @computed - get rowFilteredData(): TableData[] { - const data = this.data.slice(); - const selectedIndices = new Set(this.selectedIndices); + get indexedData(): TableRowInternal[] { + // Convert any objects to simple arrays by selecting fields. 
+ const convertedData: TableEntry[][] = this.data.map((d: TableData) => { + if (d instanceof Array) return d; + return this.columnNames.map(k => d[k]); + }); + return convertedData.map((rowData: TableEntry[], inputIndex: number) => { + return {inputIndex, rowData}; + }); + } - const rowFilteredData = data.filter((item) => { + /** + * This computed returns the data filtered by row. + */ + @computed + get rowFilteredData(): TableRowInternal[] { + return this.indexedData.filter((item) => { let isShownByTextFilter = true; // Apply column search filters for (const [key, value] of this.columnSearchQueries) { const index = this.columnNames.indexOf(key); if (index === -1) return; - const col = item[index]; + const col = item.rowData[index]; if (typeof col === 'string') { isShownByTextFilter = isShownByTextFilter && col.search(new RegExp(value)) !== -1; @@ -187,37 +228,26 @@ export class DataTable extends ReactiveElement { col.toString() === value; } } - - let isShownBySelectedFilter = true; - if (this.filterSelected) { - const areSomeSelected = this.selectedIndices.length > 0; - isShownBySelectedFilter = - areSomeSelected ? selectedIndices.has(+item[0]) : true; - } - return isShownByTextFilter && isShownBySelectedFilter; + return isShownByTextFilter; }); - return rowFilteredData; } - getSortedData(): TableData[] { - const source = this.stickySortedData ?? this.rowFilteredData; + getSortedData(source: TableRowInternal[]): TableRowInternal[] { let sortedData = source.slice(); if (this.sortName != null) { sortedData = sortedData.sort( (a, b) => (this.sortAscending ? ascending : descending)( - this.getSortableEntry(a[this.sortIndex!]), - this.getSortableEntry(b[this.sortIndex!]))); + this.getSortableEntry(a.rowData[this.sortIndex!]), + this.getSortableEntry(b.rowData[this.sortIndex!]))); } - - // Store a mapping from the row to data indices. - // TODO(lit-dev): remove hard-coded dependence on first column as index. 
- this.rowIndexToDataIndex = - new Map(sortedData.map((d, index) => [index, +d[0]])); - - this.stickySortedData = sortedData; return sortedData; } + @computed + get displayData(): TableRowInternal[] { + return this.stickySortedData ?? this.rowFilteredData; + } + private setShiftSelectionSpan(startIndex: number, endIndex: number) { this.shiftSelectionStartIndex = startIndex; this.shiftSelectionEndIndex = endIndex; @@ -228,10 +258,14 @@ export class DataTable extends ReactiveElement { index <= this.shiftSelectionEndIndex; } + private getInputIndexFromRowIndex(rowIndex: number) { + return this.displayData[rowIndex].inputIndex; + } + private selectFromRange( selectedIndices: Set, start: number, end: number, select = true) { for (let rowIndex = start; rowIndex <= end; rowIndex++) { - const dataIndex = this.rowIndexToDataIndex.get(rowIndex); + const dataIndex = this.getInputIndexFromRowIndex(rowIndex); if (dataIndex == null) return; if (select) { @@ -248,13 +282,11 @@ export class DataTable extends ReactiveElement { } /** Logic for handling row / multirow selection */ - private handleRowClick(e: MouseEvent, rowIndex: number) { + @action + private handleRowClick(e: MouseEvent, dataIndex: number, rowIndex: number) { let selectedIndices = new Set(this.selectedIndices); let doChangeSelectedSet = true; - const dataIndex = this.rowIndexToDataIndex.get(rowIndex); - if (dataIndex == null) return; - if (this.onClick != null) { this.onClick(dataIndex); return; @@ -330,19 +362,14 @@ export class DataTable extends ReactiveElement { } /** Logic for handling row hover */ - private handleRowMouseEnter(e: MouseEvent, rowIndex: number) { - const dataIndex = this.rowIndexToDataIndex.get(rowIndex); - if (dataIndex == null) return; + private handleRowMouseEnter(e: MouseEvent, dataIndex: number) { this.hoveredIndex = dataIndex; - if (this.onHover != null) { this.onHover(this.hoveredIndex); return; } } - private handleRowMouseLeave(e: MouseEvent, rowIndex: number) { - const dataIndex = 
this.rowIndexToDataIndex.get(rowIndex); - if (dataIndex == null) return; + private handleRowMouseLeave(e: MouseEvent, dataIndex: number) { if (dataIndex === this.hoveredIndex) { this.hoveredIndex = null; } @@ -353,10 +380,11 @@ export class DataTable extends ReactiveElement { } } + @action private setPrimarySelectedIndex(rowIndex: number) { let primaryIndex = -1; if (rowIndex !== -1) { - const dataIndex = this.rowIndexToDataIndex.get(rowIndex); + const dataIndex = this.getInputIndexFromRowIndex(rowIndex); if (dataIndex == null) return; primaryIndex = dataIndex; @@ -365,21 +393,33 @@ export class DataTable extends ReactiveElement { this.onPrimarySelect(primaryIndex); } + /** + * Imperative controls, intended to be used by a containing module + * such as data_table_module.ts + */ + @computed + get isDefaultView() { + return this.sortName === undefined && this.columnSearchQueries.size === 0; + } + + resetView() { + this.columnSearchQueries.clear(); + this.sortName = undefined; // reset to input ordering + this.showColumnMenu = false; // hide search bar + // Reset sticky sort and re-render from input data. + this.stickySortedData = null; + this.requestUpdate(); + } + + getVisibleDataIdxs(): number[] { + return this.displayData.map(d => d.inputIndex); + } + render() { // Make a private, temporary set of selectedIndices to simplify lookup // in the row render method this.selectedIndicesSetForRender = new Set(this.selectedIndices); - const data = this.getSortedData(); - const columns = Array.from(this.columnVisibility.keys()); - - // Only show columns that are set as visible in the column dropdown. - const columnFilteredData = data.map((row) => { - return row.filter((entry, i) => { - return this.columnVisibility.get(columns[i]); - }); - }); - // Synchronizes the horizontal scrolling of the header with the rows. 
const onScroll = (e: Event) => { const header = this.shadowRoot!.getElementById('header-container'); @@ -389,70 +429,21 @@ export class DataTable extends ReactiveElement { } }; - const onClickSelectAll = () => { - this.selectedIndices = columnFilteredData.map((d, index) => +d[0]); - this.onSelect([...this.selectedIndices]); - }; - - const isDefaultView = this.sortName === undefined && - this.columnSearchQueries.size === 0 && !this.filterSelected; - const onClickResetView = () => { - this.columnSearchQueries.clear(); - this.sortName = undefined; // reset to input ordering - this.filterSelected = false; - }; - - const toggleFilterSelected = () => { - this.filterSelected = !this.filterSelected; - }; - - const onToggleShowColumn = () => { - this.columnDropdownVisible = !this.columnDropdownVisible; - }; - - const visibleColumns = - this.columnNames.filter((key) => this.columnVisibility.get(key)); - // clang-format off return html`
- ${this.controlsEnabled ? html` -
- -
- - - -
- ${this.renderColumnDropdown()} -
` : null} -
-
- -
-
- - - ${columnFilteredData.map((d, rowIndex) => this.renderRow(d, rowIndex))} - -
+
+
+
+ + + ${this.displayData.map((d, rowIndex) => this.renderRow(d, rowIndex))} + +
+
`; // clang-format on @@ -466,6 +457,7 @@ export class DataTable extends ReactiveElement { if (index >= this.headerWidths.length) return; const headerWidth = this.headerWidths[index]; const width = headerWidth ? `${headerWidth}px` : ''; + let searchText = this.columnSearchQueries.get(title); if (searchText === undefined) { searchText = ''; @@ -537,7 +529,7 @@ export class DataTable extends ReactiveElement {
${title}
- ${this.controlsEnabled ? html` + ${this.searchEnabled ? html` ` : null} @@ -546,7 +538,7 @@ export class DataTable extends ReactiveElement { arrow_drop_down
- ${this.controlsEnabled ? html` + ${this.searchEnabled ? html`
{ - if (this.selectionDisabled) return; - this.handleRowClick(e, rowIndex); + if (!this.selectionEnabled) return; + this.handleRowClick(e, dataIndex, rowIndex); }; const mouseEnter = (e: MouseEvent) => { - this.handleRowMouseEnter(e, rowIndex); + this.handleRowMouseEnter(e, dataIndex); }; const mouseLeave = (e: MouseEvent) => { - this.handleRowMouseLeave(e, rowIndex); + this.handleRowMouseLeave(e, dataIndex); }; // clang-format off return html` - ${data.map((d => { + ${data.rowData.map((d => { if (typeof d === "string" && d.startsWith(IMAGE_PREFIX)) { return html``; } else { @@ -599,39 +590,6 @@ export class DataTable extends ReactiveElement { `; // clang-format on } - - renderColumnDropdown() { - // clang-format off - return html` -
- ${this.columnNames.filter((column) => column !== 'index') - .map(key => this.renderDropdownItem(key))} -
- `; - // clang-format on - } - - renderDropdownItem(key: string) { - const checked = this.columnVisibility.get(key); - if (checked == null) return; - - const toggleChecked = () => { - this.columnVisibility.set(key, !checked); - this.computeHeaderWidths(); - }; - - return html` -
- - -
- `; - } } declare global { diff --git a/lit_nlp/client/modules/classification_module.ts b/lit_nlp/client/modules/classification_module.ts index 4e6c3e4b..739f1c84 100644 --- a/lit_nlp/client/modules/classification_module.ts +++ b/lit_nlp/client/modules/classification_module.ts @@ -22,6 +22,7 @@ import {observable} from 'mobx'; import {app} from '../core/lit_app'; import {LitModule} from '../core/lit_module'; +import {TableData} from '../elements/table'; import {formatBoolean, IndexedInput, ModelInfoMap, Preds, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys} from '../lib/utils'; import {ClassificationInfo} from '../services/classification_service'; @@ -161,7 +162,7 @@ export class ClassificationModule extends LitModule { } private renderRow(fieldName: string, prediction: DisplayInfo[]) { - const rows: Array> = prediction.map((pred) => { + const rows: TableData[] = prediction.map((pred) => { const row = [ pred['label'], formatBoolean(pred['isGroundTruth']!), @@ -172,15 +173,13 @@ export class ClassificationModule extends LitModule { return row; }); const columnNames = ["Class", "Label", "Predicted", "Score", "Score Bar"]; - const columnVisibility = new Map(); - columnNames.forEach((name) => {columnVisibility.set(name, true);}); return html`
${fieldName}
`; } diff --git a/lit_nlp/client/modules/data_table_module.css b/lit_nlp/client/modules/data_table_module.css index e69de29b..6762e234 100644 --- a/lit_nlp/client/modules/data_table_module.css +++ b/lit_nlp/client/modules/data_table_module.css @@ -0,0 +1,39 @@ +.module-toolbar { + justify-content: space-between; + height: 32px; /* for buttons */ +} + +#toolbar-buttons{ + display: flex; + flex-direction: row; +} + +/* For in-line icons in a */ +[data-icon] { + margin: 0; +} + +[data-icon]:before { + font-family: 'Material Icons'; + content: attr(data-icon); + vertical-align: middle; +} + +/* Column settings */ +.column-dropdown-hide { + display:none; +} + +/* TODO(b/173445400): align this with rest of UI spec */ +.column-dropdown { + top: 32px; + right: 4px; + padding: 10px; + background: white; + border: 1px solid gray; + position: absolute; + visibility: visible; + z-index: 1000; + max-height: calc(100% - 48px); + overflow: auto; +} diff --git a/lit_nlp/client/modules/data_table_module.ts b/lit_nlp/client/modules/data_table_module.ts index 5ad11c90..a0ccb8c9 100644 --- a/lit_nlp/client/modules/data_table_module.ts +++ b/lit_nlp/client/modules/data_table_module.ts @@ -18,12 +18,12 @@ // tslint:disable:no-new-decorators import '../elements/checkbox'; -import {customElement, html} from 'lit-element'; +import {customElement, html, query} from 'lit-element'; import {computed, observable} from 'mobx'; import {app} from '../core/lit_app'; import {LitModule} from '../core/lit_module'; -import {TableData} from '../elements/table'; +import {DataTable, TableData} from '../elements/table'; import {formatForDisplay, IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {compareArrays, findSpecKeys, shortenId} from '../lib/utils'; import {ClassificationInfo} from '../services/classification_service'; @@ -60,7 +60,16 @@ export class DataTableModule extends LitModule { modelPredToClassificationInfo = new Map(); @observable modelPredToRegressionInfo = new Map(); 
@observable searchText = ''; - @observable filterSelected = false; + + // Module options / configuration state + @observable private filterSelected: boolean = false; + @observable private columnDropdownVisible: boolean = false; + + // Persistent selection state + @observable private selectedInputData: IndexedInput[] = []; + + // Child components + @query('lit-data-table') private readonly table?: DataTable; @computed get dataSpec(): Spec { @@ -80,11 +89,26 @@ export class DataTableModule extends LitModule { return ['index', 'id', ...this.keys]; } + @computed + get filteredData(): IndexedInput[] { + return this.filterSelected ? this.selectedInputData : + this.appState.currentInputData; + } + + @computed + get sortedData(): IndexedInput[] { + // TODO(lit-dev): pre-compute the index chains for each point, since + // this might get slow if we have a lot of counterfactuals. + return this.filteredData.slice().sort( + (a, b) => compareArrays( + this.reversedAncestorIndices(a), this.reversedAncestorIndices(b))); + } + @computed get selectedRowIndices(): number[] { - return this.selectionService.selectedIds - .map((id) => this.appState.getIndexById(id)) - .filter((index) => index !== -1); + return this.sortedData + .map((ex, i) => this.selectionService.isIdSelected(ex.id) ? i : -1) + .filter(i => i !== -1); } /** @@ -131,19 +155,11 @@ export class DataTableModule extends LitModule { // TODO(lit-dev): figure out why this updates so many times; // it gets run _four_ times every time a new datapoint is added. @computed - get data(): TableData[] { - const inputData = this.appState.currentInputData; - - // TODO(lit-dev): pre-compute the index chains for each point, since - // this might get slow if we have a lot of counterfactuals. 
- const sortedData = inputData.slice().sort( - (a, b) => compareArrays( - this.reversedAncestorIndices(a), this.reversedAncestorIndices(b))); - + get tableData(): TableData[] { // TODO(b/160170742): Make data table render immediately once the // non-prediction data is available, then fetch predictions asynchronously // and enable the additional columns when ready. - return sortedData.map((d) => { + return this.sortedData.map((d) => { let displayId = shortenId(d.id); displayId = displayId ? displayId + '...' : ''; // Add an asterisk for generated examples @@ -163,7 +179,9 @@ export class DataTableModule extends LitModule { // are filtered before rendering. const predictionInfoColumns = Array.from(this.columnVisibility.keys()) - .filter((column) => !this.defaultColumns.includes(column)); + .filter( + (column) => !this.defaultColumns.includes(column) && + this.columnVisibility.get(column)); predictionInfoColumns.forEach((columnName: string) => { const entry = this.keysToTableEntry.get(this.getTableKey(rowName, columnName)); @@ -171,21 +189,23 @@ export class DataTableModule extends LitModule { }); } - return [ - index, displayId, - ...this.keys.map( - (key) => formatForDisplay(d.data[key], this.dataSpec[key])), - ...predictionInfoEntries - ]; + const dataEntries = + this.keys.filter(k => this.columnVisibility.get(k)) + .map(k => formatForDisplay(d.data[k], this.dataSpec[k])); + + const ret: TableData = [index]; + if (this.columnVisibility.get('id')) { + ret.push(displayId); + } + return [...ret, ...dataEntries, ...predictionInfoEntries]; }); } firstUpdated() { const getCurrentInputData = () => this.appState.currentInputData; this.reactImmediately(getCurrentInputData, currentInputData => { - if (currentInputData != null) { - this.updatePredictionInfo(currentInputData); - } + if (currentInputData == null) return; + this.updatePredictionInfo(currentInputData); }); const getCurrentModels = () => this.appState.currentModels; this.react(getCurrentModels, currentModels => { 
@@ -199,6 +219,13 @@ export class DataTableModule extends LitModule { this.react(getKeys, keys => { this.updateColumns(); }); + this.reactImmediately( + () => this.selectionService.selectedOrAllInputData, inputData => { + this.selectedInputData = inputData; + if (this.table) { + this.table.resetView(); + } + }); this.updateColumns(); } @@ -324,23 +351,28 @@ export class DataTableModule extends LitModule { return `${row}:${column}`; } + /** + * Table callbacks receive indices corresponding to the rows of + * this.tableData, which matches this.sortedData. + * We need to map those back to global ids for selection purposes. + */ + getIdFromTableIndex(tableIndex: number) { + return this.sortedData[tableIndex]?.id; + } - onSelect(selectedRowIndices: number[]) { - const ids = selectedRowIndices - .map(index => this.appState.currentInputData[index]?.id) + onSelect(tableDataIndices: number[]) { + const ids = tableDataIndices.map(i => this.getIdFromTableIndex(i)) .filter(id => id != null); - this.selectionService.selectIds(ids); + this.selectionService.selectIds(ids, this); } - onPrimarySelect(index: number) { - const id = - index === -1 ? null : this.appState.currentInputData[index]?.id ?? null; - this.selectionService.setPrimarySelection(id); + onPrimarySelect(tableIndex: number) { + const id = this.getIdFromTableIndex(tableIndex); + this.selectionService.setPrimarySelection(id, this); } - onHover(index: number|null) { - const id = - index == null ? null : this.appState.currentInputData[index]?.id; + onHover(tableIndex: number|null) { + const id = tableIndex != null ? 
this.getIdFromTableIndex(tableIndex) : null; if (id == null) { this.focusService.clearFocus(); } else { @@ -348,54 +380,138 @@ export class DataTableModule extends LitModule { } } - render() { - const onSelect = (selectedIndices: number[]) => { - this.onSelect(selectedIndices); + renderDropdownItem(key: string) { + const checked = this.columnVisibility.get(key); + if (checked == null) return; + + const toggleChecked = () => { + this.columnVisibility.set(key, !checked); }; - const onPrimarySelect = (index: number) => { - this.onPrimarySelect(index); + + // clang-format off + return html` +
+ + +
+ `; + // clang-format on + } + + renderColumnDropdown() { + const names = [...this.columnVisibility.keys()].filter(c => c !== 'index'); + const classes = + this.columnDropdownVisible ? 'column-dropdown' : 'column-dropdown-hide'; + // clang-format off + return html` +
+ ${names.map(key => this.renderDropdownItem(key))} +
+ `; + // clang-format on + } + + renderControls() { + const onClickResetView = () => { + this.table!.resetView(); }; - const onHover = (index: number|null) => { - this.onHover(index); + + const onClickSelectAll = () => { + this.onSelect(this.table!.getVisibleDataIdxs()); }; - const primarySelectedIndex = - this.appState.getIndexById(this.selectionService.primarySelectedId); + const onToggleShowColumn = () => { + this.columnDropdownVisible = !this.columnDropdownVisible; + }; + // clang-format off + return html` + { this.filterSelected = !this.filterSelected; }} + > +
+ + + +
+ ${this.renderColumnDropdown()} + `; + // clang-format on + } + + renderTable() { + const tableDataIds = this.sortedData.map(d => d.id); + const indexOfId = (id: string|null) => + id != null ? tableDataIds.indexOf(id) : -1; + + const primarySelectedIndex = + indexOfId(this.selectionService.primarySelectedId); - const focusData = this.focusService.focusData; // Set focused index if a datapoint is focused according to the focus // service. If the focusData is null then nothing is focused. If focusData // contains a value in the "io" field then the focus is on a subfield of // a datapoint, as opposed to a datapoint itself. + const focusData = this.focusService.focusData; const focusedIndex = focusData == null || focusData.io != null ? -1 : - this.appState.getIndexById(focusData.datapointId); + indexOfId(focusData.datapointId); // Handle reference selection, if in compare examples mode. let referenceSelectedIndex = -1; if (this.appState.compareExamplesEnabled) { const referenceSelectionService = app.getServiceArray(SelectionService)[1]; - referenceSelectedIndex = this.appState.getIndexById( - referenceSelectionService.primarySelectedId); + referenceSelectedIndex = + indexOfId(referenceSelectionService.primarySelectedId); } + const columnNames = [...this.columnVisibility.keys()].filter( + k => this.columnVisibility.get(k)); + + // clang-format off return html` - + { this.onSelect(idxs); }} + .onPrimarySelect=${(i: number) => { this.onPrimarySelect(i); }} + .onHover=${(i: number|null)=> { this.onHover(i); }} + searchEnabled + selectionEnabled + > + `; + // clang-format on + } + render() { + // clang-format off + return html` +
+
+ ${this.renderControls()} +
+
+ ${this.renderTable()} +
+
`; + // clang-format on } static shouldDisplayModule(modelSpecs: ModelInfoMap, datasetSpec: Spec) { diff --git a/lit_nlp/client/modules/metrics_module.ts b/lit_nlp/client/modules/metrics_module.ts index 88454b03..0ec64cce 100644 --- a/lit_nlp/client/modules/metrics_module.ts +++ b/lit_nlp/client/modules/metrics_module.ts @@ -266,7 +266,7 @@ export class MetricsModule extends LitModule { /** Convert the metricsMap information into table data for display. */ @computed get tableData(): TableHeaderAndData { - const rows = [] as TableData[]; + const tableRows = [] as TableData[]; const allMetricNames = new Set(); Object.values(this.metricsMap).forEach(row => { Object.keys(row.headMetrics).forEach(metricsType => { @@ -279,7 +279,7 @@ export class MetricsModule extends LitModule { const metricNames = [...allMetricNames]; - Object.values(this.metricsMap).forEach(row => { + for (const row of Object.values(this.metricsMap)) { const rowMetrics = metricNames.map(metricKey => { const [metricsType, metricName] = metricKey.split(": "); if (row.headMetrics[metricsType] == null) { @@ -304,16 +304,17 @@ export class MetricsModule extends LitModule { }); const tableRow = [ - rows.length, row.model, row.selection, ...rowFacets, row.predKey, - row.exampleIds.length, ...rowMetrics]; - rows.push(tableRow); - }); + row.model, row.selection, ...rowFacets, row.predKey, + row.exampleIds.length, ...rowMetrics + ]; + tableRows.push(tableRow); + } return { - 'header': - ["id", 'Model', 'From', ...this.selectedFacets, 'Field', 'N', - ...metricNames], - 'data': rows + 'header': [ + 'Model', 'From', ...this.selectedFacets, 'Field', 'N', ...metricNames + ], + 'data': tableRows }; } @@ -333,17 +334,11 @@ export class MetricsModule extends LitModule { } renderTable() { - const columnNames = this.tableData.header; - const columnVisibility = new Map(); - columnNames.forEach((name) => { - columnVisibility.set(name, name !== "id"); - }); // TODO(b/180903904): Add onSelect behavior to rows for selection. 
return html` `; } diff --git a/lit_nlp/client/modules/regression_module.ts b/lit_nlp/client/modules/regression_module.ts index 6e5640e7..53f41ffb 100644 --- a/lit_nlp/client/modules/regression_module.ts +++ b/lit_nlp/client/modules/regression_module.ts @@ -21,6 +21,7 @@ import {observable} from 'mobx'; import {app} from '../core/lit_app'; import {LitModule} from '../core/lit_module'; +import {TableData} from '../elements/table'; import {IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {doesOutputSpecContain, findSpecKeys} from '../lib/utils'; import {RegressionService} from '../services/services'; @@ -101,7 +102,7 @@ export class RegressionModule extends LitModule { const scoreFields: string[] = findSpecKeys(spec.output, 'RegressionScore'); - const rows: string[][] = []; + const rows: TableData[] = []; let hasParent = false; // Per output, display score, and parent field and error if available. for (const scoreField of scoreFields) { @@ -123,22 +124,24 @@ export class RegressionModule extends LitModule { errorScore = error.toFixed(4); } } - rows.push([scoreField, parentScore, score, errorScore]); + rows.push({ + 'Field': scoreField, + 'Ground truth': parentScore, + 'Score': score, + 'Error': errorScore + }); } // If no fields have ground truth scores to compare then don't display the // ground truth-related columns. - const columnNames = ["Field", "Ground truth", "Score", "Error"]; - const columnVisibility = new Map(); - columnNames.forEach((name) => { - columnVisibility.set( - name, hasParent || (name !== 'Ground truth' && name !== 'Error')); - }); + const columnNames = hasParent ? + ['Field', 'Ground truth', 'Score', 'Error'] : + ['Field', 'Score']; return html` `; } From 1e722603af71625718aa15123bdb11ec6fc189b9 Mon Sep 17 00:00:00 2001 From: Ellen Jiang Date: Mon, 19 Apr 2021 14:05:43 -0700 Subject: [PATCH 017/213] Changes LIT TCAV's t-test to check against the scores of a random sample, and adds SST2 BERT tiny as a test model. 
PiperOrigin-RevId: 369298147 --- lit_nlp/components/metrics_test.py | 55 ++++--- lit_nlp/components/tcav.py | 32 +++- lit_nlp/components/tcav_test.py | 229 ++++++++++++++--------------- lit_nlp/lib/testing_utils.py | 18 ++- 4 files changed, 177 insertions(+), 157 deletions(-) diff --git a/lit_nlp/components/metrics_test.py b/lit_nlp/components/metrics_test.py index d5c8dcac..ecffa42a 100644 --- a/lit_nlp/components/metrics_test.py +++ b/lit_nlp/components/metrics_test.py @@ -39,7 +39,7 @@ def test_compute(self): result = regression_metrics.compute([1, 2, 3, 4], [1, 2, 3, 4], types.RegressionScore(), types.RegressionScore()) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'mse': 0, 'pearsonr': 1.0, 'spearmanr': 1.0 @@ -49,7 +49,7 @@ def test_compute(self): result = regression_metrics.compute([1, 2, 3, 4], [1, 2, 5.5, 6.3], types.RegressionScore(), types.RegressionScore()) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'mse': 2.885, 'pearsonr': 0.96566, 'spearmanr': 1.0 @@ -59,16 +59,16 @@ def test_compute(self): result = regression_metrics.compute([1, 2, 3, 4], [-5, -10, 5, 6], types.RegressionScore(), types.RegressionScore()) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'mse': 47.0, 'pearsonr': 0.79559, - 'spearmanr': 0.79999 + 'spearmanr': 0.799999 }) # Empty labels and predictions result = regression_metrics.compute([], [], types.RegressionScore(), types.RegressionScore()) - testing_utils.assert_dicts_almost_equal(self, result, {}) + testing_utils.assert_deep_almost_equal(self, result, {}) class MulticlassMetricsTest(absltest.TestCase): @@ -90,7 +90,7 @@ def test_compute(self): ['1', '2', '0', '1'], [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1', '2'], null_idx=0)) - 
testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'accuracy': 1.0, 'f1': 1.0, 'precision': 1.0, @@ -103,13 +103,12 @@ def test_compute(self): [[.1, .4, .5], [0, .1, .9], [.1, 0, .9], [0, 1, 0]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1', '2'], null_idx=0)) - testing_utils.assert_dicts_almost_equal( - self, result, { - 'accuracy': 0.5, - 'f1': 0.57143, - 'precision': 0.5, - 'recall': 0.66666 - }) + testing_utils.assert_deep_almost_equal(self, result, { + 'accuracy': 0.5, + 'f1': 0.57143, + 'precision': 0.5, + 'recall': 0.66667 + }) # All incorrect predictions. result = multiclass_metrics.compute( @@ -117,7 +116,7 @@ def test_compute(self): [[.1, .4, .5], [.2, .7, .1], [.1, 0, .9], [1, 0, 0]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1', '2'], null_idx=0)) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'accuracy': 0.0, 'f1': 0.0, 'precision': 0.0, @@ -129,13 +128,13 @@ def test_compute(self): ['1', '2', '0', '1'], [[.1, .4, .5], [0, .1, .9], [.1, 0, .9], [0, 1, 0]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1', '2'])) - testing_utils.assert_dicts_almost_equal(self, result, {'accuracy': 0.5}) + testing_utils.assert_deep_almost_equal(self, result, {'accuracy': 0.5}) # Empty labels and predictions result = multiclass_metrics.compute([], [], types.CategoryLabel(), types.MulticlassPreds( vocab=['0', '1', '2'], null_idx=0)) - testing_utils.assert_dicts_almost_equal(self, result, {}) + testing_utils.assert_deep_almost_equal(self, result, {}) class MulticlassPairedMetricsTest(absltest.TestCase): @@ -163,7 +162,7 @@ def test_compute(self): ['1', '1', '0', '0'], [[0, 1], [0, 1], [1, 0], [1, 0]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1'], null_idx=0), indices, metas) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, 
{ 'mean_jsd': 0.0, 'num_pairs': 2, 'swap_rate': 0.0 @@ -174,7 +173,7 @@ def test_compute(self): ['1', '1', '0', '0'], [[0, 1], [1, 0], [1, 0], [1, 0]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1'], null_idx=0), indices, metas) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'mean_jsd': 0.34657, 'num_pairs': 2, 'swap_rate': 0.5 @@ -185,7 +184,7 @@ def test_compute(self): ['1', '1', '0', '0'], [[0, 1], [1, 0], [1, 0], [0, 1]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1'], null_idx=0), indices, metas) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'mean_jsd': 0.69315, 'num_pairs': 2, 'swap_rate': 1.0 @@ -196,7 +195,7 @@ def test_compute(self): ['1', '1', '0', '0'], [[0, 1], [1, 0], [1, 0], [0, 1]], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1']), indices, metas) - testing_utils.assert_dicts_almost_equal(self, result, { + testing_utils.assert_deep_almost_equal(self, result, { 'mean_jsd': 0.69315, 'num_pairs': 2, 'swap_rate': 1.0 @@ -206,7 +205,7 @@ def test_compute(self): result = multiclass_paired_metrics.compute_with_metadata( [], [], types.CategoryLabel(), types.MulticlassPreds(vocab=['0', '1'], null_idx=0), [], []) - testing_utils.assert_dicts_almost_equal(self, result, {}) + testing_utils.assert_deep_almost_equal(self, result, {}) class CorpusBLEUTest(absltest.TestCase): @@ -228,28 +227,28 @@ def test_compute(self): ['This is a test.', 'Test two', 'A third test example'], ['This is a test.', 'Test two', 'A third test example'], types.GeneratedText(), types.GeneratedText()) - testing_utils.assert_dicts_almost_equal(self, result, - {'corpus_bleu': 100.00000}) + testing_utils.assert_deep_almost_equal(self, result, + {'corpus_bleu': 100.00000}) # Some incorrect predictions. 
result = corpusblue_metrics.compute( ['This is a test.', 'Test one', 'A third test'], ['This is a test.', 'Test two', 'A third test example'], types.GeneratedText(), types.GeneratedText()) - testing_utils.assert_dicts_almost_equal(self, result, - {'corpus_bleu': 68.037493}) + testing_utils.assert_deep_almost_equal(self, result, + {'corpus_bleu': 68.037493}) result = corpusblue_metrics.compute( ['This is a test.', 'Test one', 'A third test'], ['these test.', 'Test two', 'A third test example'], types.GeneratedText(), types.GeneratedText()) - testing_utils.assert_dicts_almost_equal(self, result, - {'corpus_bleu': 29.508062388758525}) + testing_utils.assert_deep_almost_equal(self, result, + {'corpus_bleu': 29.508062388758525}) # Empty labels and predictions result = corpusblue_metrics.compute([], [], types.GeneratedText(), types.GeneratedText()) - testing_utils.assert_dicts_almost_equal(self, result, {}) + testing_utils.assert_deep_almost_equal(self, result, {}) if __name__ == '__main__': diff --git a/lit_nlp/components/tcav.py b/lit_nlp/components/tcav.py index 1bcf189b..d76a6ca8 100644 --- a/lit_nlp/components/tcav.py +++ b/lit_nlp/components/tcav.py @@ -16,12 +16,13 @@ """Quantitative Testing with Concept Activation Vectors (TCAV).""" import random -from typing import Any, List, Optional, cast, Sequence, Text +from typing import Any, List, Optional, Sequence, Text, cast import attr from lit_nlp.api import components as lit_components from lit_nlp.api import dataset as lit_dataset from lit_nlp.api import model as lit_model + from lit_nlp.api import types import numpy as np import scipy.stats @@ -33,7 +34,7 @@ IndexedInput = types.IndexedInput Spec = types.Spec -NUM_SPLITS = 20 # TODO(lit-dev): Make this configurable in the UI. +NUM_SPLITS = 15 # TODO(lit-dev): Make this configurable in the UI. 
@attr.s(auto_attribs=True, kw_only=True) @@ -74,11 +75,11 @@ class TCAV(lit_components.Interpreter): - MulticlassPreds (`probas`) """ - def hyp_test(self, scores): + def hyp_test(self, scores, random_scores): """Returns the p-value for a two-sided t-test on the TCAV score.""" # The null hypothesis is 0.5, since a TCAV score of 0.5 would indicate # the concept does not affect the prediction positively or negatively. - _, p_val = scipy.stats.ttest_1samp(scores, 0.5) + _, p_val = scipy.stats.ttest_ind(scores, random_scores) return p_val def run_with_metadata( @@ -157,8 +158,26 @@ def _subsample(examples, n): config.test_size, config.random_state)) + random_results = [] + # Get tcav scores on random splits. + for _ in range(NUM_SPLITS): + concept_split_outputs = _subsample(dataset_outputs, n) + comparison_split_outputs = _subsample(non_concept_outputs, n) + random_results.append(self._run_tcav(concept_split_outputs, + comparison_split_outputs, + dataset_outputs, + config.class_to_explain, + emb_layer, + grad_layer, + grad_class_key, + config.test_size, + config.random_state)) + cav_scores = [res['score'] for res in concept_results] - p_val = self.hyp_test(cav_scores) + random_scores = [res['score'] for res in random_results] + p_val = self.hyp_test(cav_scores, random_scores) + + random_mean = np.mean(random_scores) # Get index of CAV result with the highest accuracy. accuracies = [res['accuracy'] for res in concept_results] @@ -166,7 +185,8 @@ def _subsample(examples, n): # Many CAVS are trained and checked for statistical testing to calculate # the p-value. The values of the first CAV are returned. 
- results = {'result': concept_results[index], 'p_val': p_val} + results = {'result': concept_results[index], 'p_val': p_val, + 'random_mean': random_mean} return [results] def _get_training_data(self, comparison_outputs, concept_outputs, emb_layer): diff --git a/lit_nlp/components/tcav_test.py b/lit_nlp/components/tcav_test.py index f470640e..04fc8e2b 100644 --- a/lit_nlp/components/tcav_test.py +++ b/lit_nlp/components/tcav_test.py @@ -16,154 +16,105 @@ """Tests for lit_nlp.components.gradient_maps.""" import random -from typing import List from absl.testing import absltest from lit_nlp.api import dataset as lit_dataset -from lit_nlp.api import model as lit_model from lit_nlp.api import types as lit_types from lit_nlp.components import tcav +# TODO(lit-dev): Move glue_models out of lit_nlp/examples +from lit_nlp.examples.models import glue_models +from lit_nlp.lib import testing_utils import numpy as np + JsonDict = lit_types.JsonDict +Spec = lit_types.Spec -class TestModelClassificationTCAV(lit_model.Model): - """Implements lit.Model interface for testing TCAV. - - Returns the same output for every input. 
- """ - - # LIT API implementation - def input_spec(self): - return {'input_embs': lit_types.TokenEmbeddings(align='tokens', - required=False), - 'segment': lit_types.TextSegment()} - - def output_spec(self): - return { - 'probas': - lit_types.MulticlassPreds( - parent='label', vocab=['0', '1'], null_idx=0), - 'cls_emb': - lit_types.Embeddings(), - 'cls_grad': - lit_types.Gradients(grad_for='cls_emb', grad_target='grad_class'), - 'grad_class': - lit_types.CategoryLabel() - } +BERT_TINY_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz' # pylint: disable=line-too-long +import transformers +BERT_TINY_PATH = transformers.file_utils.cached_path(BERT_TINY_PATH, +extract_compressed_file=True) - def predict_minibatch(self, inputs: List[JsonDict], **kw): - output = { - 'probas': np.array([0.2, 0.8]), - 'cls_emb': [1, 0, 0, 0], - 'cls_grad': [2, 0, 0, 0], - 'grad_class': '1' - } - return map(lambda x: output, inputs) - -class TCAVTest(absltest.TestCase): +class ModelBasedTCAVTest(absltest.TestCase): def setUp(self): - super(TCAVTest, self).setUp() + super(ModelBasedTCAVTest, self).setUp() self.tcav = tcav.TCAV() - - def test_hyp_test(self): - # t-test where p-value = 1. - scores = [0, 0, 0.5, 0.5, 1, 1] - result = self.tcav.hyp_test(scores) - self.assertEqual(1, result) - - # t-test where p-value ~ 0. 
- scores = [0.1, 0.13, 0.19, 0.09, 0.12, 0.1] - result = self.tcav.hyp_test(scores) - self.assertAlmostEqual(1.7840024559935266e-06, result) - - def test_compute_tcav_score(self): - dir_deriv_positive_class = [1] - result = self.tcav.compute_tcav_score(dir_deriv_positive_class) - self.assertAlmostEqual(1, result) - - dir_deriv_positive_class = [0] - result = self.tcav.compute_tcav_score(dir_deriv_positive_class) - self.assertAlmostEqual(0, result) - - dir_deriv_positive_class = [1, -5, 4, 6.5, -3, -2.5, 0, 2] - result = self.tcav.compute_tcav_score(dir_deriv_positive_class) - self.assertAlmostEqual(0.5, result) + self.model = glue_models.SST2Model(BERT_TINY_PATH) def test_tcav(self): random.seed(0) # Sets seed since create_comparison_splits() uses random. # Basic test with dummy outputs from the model. examples = [ - {'segment': 'a'}, - {'segment': 'b'}, - {'segment': 'c'}, - {'segment': 'd'}, - {'segment': 'e'}, - {'segment': 'f'}, - {'segment': 'g'}, - {'segment': 'h'}] + {'sentence': 'a'}, + {'sentence': 'b'}, + {'sentence': 'c'}, + {'sentence': 'd'}, + {'sentence': 'e'}, + {'sentence': 'f'}, + {'sentence': 'g'}, + {'sentence': 'h'}] indexed_inputs = [ { 'id': '1', 'data': { - 'segment': 'a' + 'sentence': 'a' } }, { 'id': '2', 'data': { - 'segment': 'b' + 'sentence': 'b' } }, { 'id': '3', 'data': { - 'segment': 'c' + 'sentence': 'c' } }, { 'id': '4', 'data': { - 'segment': 'd' + 'sentence': 'd' } }, { 'id': '5', 'data': { - 'segment': 'e' + 'sentence': 'e' } }, { 'id': '6', 'data': { - 'segment': 'f' + 'sentence': 'f' } }, { 'id': '7', 'data': { - 'segment': 'g' + 'sentence': 'g' } }, { 'id': '8', 'data': { - 'segment': 'h' + 'sentence': 'h' } }, { 'id': '9', 'data': { - 'segment': 'i' + 'sentence': 'i' } }, ] - model = TestModelClassificationTCAV() - dataset_spec = {'segment': lit_types.TextSegment()} + + dataset_spec = {'sentence': lit_types.TextSegment()} dataset = lit_dataset.Dataset(dataset_spec, examples) config = { 'concept_set_ids': ['1', '3', '4', '8'], 
@@ -171,24 +122,30 @@ def test_tcav(self): 'grad_layer': 'cls_grad', 'random_state': 0 } - result = self.tcav.run_with_metadata(indexed_inputs, model, dataset, + result = self.tcav.run_with_metadata(indexed_inputs, self.model, dataset, config=config) self.assertLen(result, 1) expected = { - 'p_val': 0.0, + 'p_val': 0.13311, + 'random_mean': 0.56667, 'result': { - 'score': 1.0, - 'cos_sim': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 'score': 0.33333, + 'cos_sim': [ + 0.088691, -0.12179, 0.16013, + 0.24840, -0.09793, 0.05166, + -0.21578, -0.06560, -0.14759 + ], 'dot_prods': [ - 1.6669444907484283, 1.6669444907484283, 1.6669444907484283, - 1.6669444907484283, 1.6669444907484283, 1.6669444907484283, - 1.6669444907484283, 1.6669444907484283, 1.6669444907484283 + 189.085096, -266.36317, 344.350498, + 547.144949, -211.663965, 112.502439, + -472.72066, -144.529598, -323.31888 ], - 'accuracy': 0.3333333333333333 + 'accuracy': 0.66667 } } - self.assertDictEqual(expected, result[0]) + + testing_utils.assert_deep_almost_equal(self, expected, result[0]) def test_tcav_sample_from_positive(self): # Tests the case where more concept examples are passed than non-concept @@ -198,66 +155,65 @@ def test_tcav_sample_from_positive(self): # Basic test with dummy outputs from the model. 
examples = [ - {'segment': 'a'}, - {'segment': 'b'}, - {'segment': 'c'}, - {'segment': 'd'}, - {'segment': 'e'}, - {'segment': 'f'}, - {'segment': 'g'}, - {'segment': 'h'}] + {'sentence': 'a'}, + {'sentence': 'b'}, + {'sentence': 'c'}, + {'sentence': 'd'}, + {'sentence': 'e'}, + {'sentence': 'f'}, + {'sentence': 'g'}, + {'sentence': 'h'}] indexed_inputs = [ { 'id': '1', 'data': { - 'segment': 'a' + 'sentence': 'a' } }, { 'id': '2', 'data': { - 'segment': 'b' + 'sentence': 'b' } }, { 'id': '3', 'data': { - 'segment': 'c' + 'sentence': 'c' } }, { 'id': '4', 'data': { - 'segment': 'd' + 'sentence': 'd' } }, { 'id': '5', 'data': { - 'segment': 'e' + 'sentence': 'e' } }, { 'id': '6', 'data': { - 'segment': 'f' + 'sentence': 'f' } }, { 'id': '7', 'data': { - 'segment': 'g' + 'sentence': 'g' } }, { 'id': '8', 'data': { - 'segment': 'h' + 'sentence': 'h' } }, ] - model = TestModelClassificationTCAV() - dataset_spec = {'segment': lit_types.TextSegment()} + dataset_spec = {'sentence': lit_types.TextSegment()} dataset = lit_dataset.Dataset(dataset_spec, examples) config = { 'concept_set_ids': ['1', '3', '4', '5', '8'], @@ -265,25 +221,64 @@ def test_tcav_sample_from_positive(self): 'grad_layer': 'cls_grad', 'random_state': 0 } - result = self.tcav.run_with_metadata(indexed_inputs, model, dataset, + result = self.tcav.run_with_metadata(indexed_inputs, self.model, dataset, config=config) self.assertLen(result, 1) expected = { - 'p_val': 0.0, + 'p_val': 0.80489, + 'random_mean': 0.53333, 'result': { - 'score': 1.0, - 'cos_sim': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 'score': 0.8, + 'cos_sim': [ + 0.09527, -0.20442, 0.05141, + 0.14985, 0.06750, -0.28244, + -0.11022, -0.14479 + ], 'dot_prods': [ - 2.0589251447995237e-14, 2.0589251447995237e-14, - 2.0589251447995237e-14, 2.0589251447995237e-14, - 2.0589251447995237e-14, 2.0589251447995237e-14, - 2.0589251447995237e-14, 2.0589251447995237e-14 + 152.48776, -335.64998, 82.99588, + 247.80113, 109.53684, -461.81805, + -181.29095, 
-239.47817 ], - 'accuracy': 0.5 + 'accuracy': 1.0 } } - self.assertDictEqual(expected, result[0]) + + testing_utils.assert_deep_almost_equal(self, expected, result[0]) + + +class TCAVTest(absltest.TestCase): + + def setUp(self): + super(TCAVTest, self).setUp() + self.tcav = tcav.TCAV() + self.model = glue_models.SST2Model(BERT_TINY_PATH) + + def test_hyp_test(self): + # t-test where p-value != 1. + scores = [0, 0, 0.5, 0.5, 1, 1] + random_scores = [3, 5, -8, -100, 0, -90] + result = self.tcav.hyp_test(scores, random_scores) + self.assertAlmostEqual(0.1415165926492605, result) + + # t-test where p-value = 1. + scores = [0.1, 0.13, 0.19, 0.09, 0.12, 0.1] + random_scores = [0.1, 0.13, 0.19, 0.09, 0.12, 0.1] + result = self.tcav.hyp_test(scores, random_scores) + self.assertEqual(1.0, result) + + def test_compute_tcav_score(self): + dir_deriv_positive_class = [1] + result = self.tcav.compute_tcav_score(dir_deriv_positive_class) + self.assertAlmostEqual(1, result) + + dir_deriv_positive_class = [0] + result = self.tcav.compute_tcav_score(dir_deriv_positive_class) + self.assertAlmostEqual(0, result) + + dir_deriv_positive_class = [1, -5, 4, 6.5, -3, -2.5, 0, 2] + result = self.tcav.compute_tcav_score(dir_deriv_positive_class) + self.assertAlmostEqual(0.5, result) def test_get_trained_cav(self): # 1D inputs. 
diff --git a/lit_nlp/lib/testing_utils.py b/lit_nlp/lib/testing_utils.py index df08fc80..7f974a1a 100644 --- a/lit_nlp/lib/testing_utils.py +++ b/lit_nlp/lib/testing_utils.py @@ -24,6 +24,7 @@ from lit_nlp.api import model as lit_model from lit_nlp.api import types as lit_types import numpy as np +import numpy.testing as npt JsonDict = lit_types.JsonDict @@ -177,9 +178,14 @@ def fake_projection_input(n, num_dims): return [{'x': rng.rand(num_dims)} for i in range(n)] -def assert_dicts_almost_equal(testcase, result, actual, places=3): - """Checks if provided dicts are almost equal.""" - if set(result.keys()) != set(actual.keys()): - testcase.fail('results and actual have different keys') - for key in result: - testcase.assertAlmostEqual(result[key], actual[key], places=places) +def assert_deep_almost_equal(testcase, result, actual, places=5): + """Checks if provided inputs are almost equal, recurses on dicts values.""" + if isinstance(result, (int, float)): + testcase.assertAlmostEqual(result, actual, places=places) + elif isinstance(result, (list)): + npt.assert_array_almost_equal(result, actual, decimal=places) + elif isinstance(result, dict): + if set(result.keys()) != set(actual.keys()): + testcase.fail('results and actual have different keys') + for key in result: + assert_deep_almost_equal(testcase, result[key], actual[key]) From 2913710582851efc08a8c6607dffb2e17d361037 Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Tue, 20 Apr 2021 15:17:52 +0000 Subject: [PATCH 018/213] Fix typo and add demo instructions --- lit_nlp/examples/models/stanza_models.py | 4 ++-- lit_nlp/examples/stanza_demo.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py index a776e480..95c17437 100644 --- a/lit_nlp/examples/models/stanza_models.py +++ b/lit_nlp/examples/models/stanza_models.py @@ -132,14 +132,14 @@ def entity_char_to_token(entity, sentence): object to return the start 
and end tokens for the entity. Example entity: - {'text': 'Barrack Obama', + {'text': 'Barack Obama', 'type': 'PERSON', 'start_char': 0, 'end_char': 13} Example sentence: [ {'id': 1, - 'text': 'Barrack', + 'text': 'Barack', ..., 'misc': 'start_char=0|end_char=7'}, {'id': 2, diff --git a/lit_nlp/examples/stanza_demo.py b/lit_nlp/examples/stanza_demo.py index c7dfc228..b24c6090 100644 --- a/lit_nlp/examples/stanza_demo.py +++ b/lit_nlp/examples/stanza_demo.py @@ -14,7 +14,8 @@ # ============================================================================== # Lint at: python3 """Example demo loading Stanza models. -To run with the demo: +To run the demo: + pip install stanza python -m lit_nlp.examples.stanza_demo --port=5432 Then navigate to localhost:5432 to access the demo UI. """ From a6b57f61f54324c7d6de4a230bc69af9ea375687 Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Tue, 20 Apr 2021 19:39:27 +0000 Subject: [PATCH 019/213] Formatting update --- lit_nlp/examples/models/stanza_models.py | 37 ++++++++++++++---------- lit_nlp/examples/stanza_demo.py | 14 ++++----- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py index 95c17437..de3ccdcd 100644 --- a/lit_nlp/examples/models/stanza_models.py +++ b/lit_nlp/examples/models/stanza_models.py @@ -13,18 +13,18 @@ # limitations under the License. # ============================================================================== # Lint as: python3 -"""Wrapper for Stanza model""" +"""Wrapper for Stanza model.""" +from lit_nlp.api import dtypes from lit_nlp.api import model as lit_model from lit_nlp.api import types as lit_types -from lit_nlp.api import dtypes SpanLabel = dtypes.SpanLabel EdgeLabel = dtypes.EdgeLabel class StanzaTagger(lit_model.Model): - """Stanza Model wrapper""" + """Stanza Model wrapper.""" def __init__(self, model, tasks): """Initialize with Stanza model and a dictionary of tasks. 
@@ -32,7 +32,8 @@ def __init__(self, model, tasks): Args: model: A Stanza model tasks: A dictionary of tasks, grouped by task type. - Keys are the grouping, which should be one of ('sequence', 'span', 'edge'). + Keys are the grouping, which should be one of: + ('sequence', 'span', 'edge'). Values are a list of stanza task names as strings. """ self.model = model @@ -74,22 +75,27 @@ def _predict(self, ex): doc = self.model(ex["sentence"]) prediction = {task: [] for task in self._output_spec} for sentence in doc.sentences: - # Get starting token of the offset to align task to tokens for multiple sentences - start_token = len(prediction['tokens']) + # Get starting token of the offset to align task for multiple sentences + start_token = len(prediction["tokens"]) prediction["tokens"].extend([word.text for word in sentence.words]) # Process each sequence task for task in self.sequence_tasks: - prediction[task].extend([word.to_dict()[task] for word in sentence.words]) + prediction[task].extend( + [word.to_dict()[task] for word in sentence.words]) # Process each span task for task in self.span_tasks: # Mention is currently the only span task if task == "mention": for entity in sentence.entities: - # Stanza indexes start/end of entities on char. LIT needs them as token indexes + # Stanza indexes start/end of entities on char. 
LIT needs them as + # token indexes start, end = entity_char_to_token(entity, sentence) - span_label = SpanLabel(start=start+start_token, end=end+start_token, label=entity.type) + span_label = SpanLabel( + start=start + start_token, + end=end + start_token, + label=entity.type) prediction[task].append(span_label) else: raise ValueError(f"Invalid span task: '{task}'") @@ -101,11 +107,12 @@ def _predict(self, ex): for relation in sentence.dependencies: label = relation[1] span1 = relation[2].id + start_token - span2 = relation[2].id + start_token if label == "root" else relation[0].id + start_token - # Relation lists have a root value at index 0, so subtract 1 to align them to tokens - edge_label = EdgeLabel( - (span1 - 1, span1), (span2 - 1, span2), label - ) + span2_index = 2 if label == "root" else 0 + span2 = relation[span2_index].id + start_token + # Relation lists have a root value at index 0, so subtract 1 to + # align them to tokens + edge_label = EdgeLabel((span1 - 1, span1), (span2 - 1, span2), + label) prediction[task].append(edge_label) else: raise ValueError(f"Invalid edge task: '{task}'") @@ -123,7 +130,7 @@ def output_spec(self): def entity_char_to_token(entity, sentence): - """Takes Stanza entity and sentence objects and returns the start and end tokens for the entity + """Takes Stanza entity and sentence objects and returns the start and end tokens for the entity. The misc value in a stanza sentence object contains a string with additional information, separated by a pipe character. This string contains the diff --git a/lit_nlp/examples/stanza_demo.py b/lit_nlp/examples/stanza_demo.py index b24c6090..0be4e30f 100644 --- a/lit_nlp/examples/stanza_demo.py +++ b/lit_nlp/examples/stanza_demo.py @@ -14,6 +14,7 @@ # ============================================================================== # Lint at: python3 """Example demo loading Stanza models. 
+ To run the demo: pip install stanza python -m lit_nlp.examples.stanza_demo --port=5432 @@ -21,16 +22,14 @@ """ from absl import app from absl import flags - -import lit_nlp.api.dataset as lit_dataset -import lit_nlp.api.types as lit_types -from lit_nlp.examples.datasets import glue -from lit_nlp.examples.models import stanza_models from lit_nlp import dev_server from lit_nlp import server_flags +import lit_nlp.api.dataset as lit_dataset +import lit_nlp.api.types as lit_types from lit_nlp.components import scrambler from lit_nlp.components import word_replacer - +from lit_nlp.examples.datasets import glue +from lit_nlp.examples.models import stanza_models import stanza FLAGS = flags.FLAGS @@ -77,7 +76,8 @@ def main(_): } # Datasets for LIT demo - # TODO: Use the UD dataset (https://huggingface.co/datasets/universal_dependencies) + # TODO(nbroestl): Use the UD dataset + # (https://huggingface.co/datasets/universal_dependencies) datasets = { "SST2": glue.SST2Data(split="validation").slice[: FLAGS.max_examples], "blank": lit_dataset.Dataset({"text": lit_types.TextSegment()}, []), From dbeaa5efd95217ca6fdd7be3560c58e7c3934bcb Mon Sep 17 00:00:00 2001 From: Noah Broestl Date: Tue, 20 Apr 2021 21:02:28 +0000 Subject: [PATCH 020/213] Another format update --- lit_nlp/examples/models/stanza_models.py | 12 ++++---- lit_nlp/examples/stanza_demo.py | 38 ++++++++++++------------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/lit_nlp/examples/models/stanza_models.py b/lit_nlp/examples/models/stanza_models.py index de3ccdcd..fdf90366 100644 --- a/lit_nlp/examples/models/stanza_models.py +++ b/lit_nlp/examples/models/stanza_models.py @@ -43,11 +43,11 @@ def __init__(self, model, tasks): self.edge_tasks = tasks["edge"] self._input_spec = { - "sentence": lit_types.TextSegment(), + "sentence": lit_types.TextSegment(), } self._output_spec = { - "tokens": lit_types.Tokens(), + "tokens": lit_types.Tokens(), } # Output spec based on specified tasks @@ -82,7 +82,7 @@ 
def _predict(self, ex): # Process each sequence task for task in self.sequence_tasks: prediction[task].extend( - [word.to_dict()[task] for word in sentence.words]) + [word.to_dict()[task] for word in sentence.words]) # Process each span task for task in self.span_tasks: @@ -93,9 +93,9 @@ def _predict(self, ex): # token indexes start, end = entity_char_to_token(entity, sentence) span_label = SpanLabel( - start=start + start_token, - end=end + start_token, - label=entity.type) + start=start + start_token, + end=end + start_token, + label=entity.type) prediction[task].append(span_label) else: raise ValueError(f"Invalid span task: '{task}'") diff --git a/lit_nlp/examples/stanza_demo.py b/lit_nlp/examples/stanza_demo.py index 0be4e30f..201a1f24 100644 --- a/lit_nlp/examples/stanza_demo.py +++ b/lit_nlp/examples/stanza_demo.py @@ -35,27 +35,27 @@ FLAGS = flags.FLAGS flags.DEFINE_list( - "sequence_tasks", - ["upos", "xpos", "lemma"], - "Sequence tasks to load and use for prediction. Defaults to all sequence tasks", + "sequence_tasks", + ["upos", "xpos", "lemma"], + "Sequence tasks to load and use for prediction. Defaults to all sequence tasks", ) flags.DEFINE_list( - "span_tasks", - ["mention"], - "Span tasks to load and use for prediction. Only mentions are included in this demo", + "span_tasks", + ["mention"], + "Span tasks to load and use for prediction. Only mentions are included in this demo", ) flags.DEFINE_list( - "edge_tasks", - ["deps"], - "Span tasks to load and use for prediction. Only deps are included in this demo", + "edge_tasks", + ["deps"], + "Span tasks to load and use for prediction. Only deps are included in this demo", ) flags.DEFINE_string("language", "en", "Language to load for Stanza model.") flags.DEFINE_integer( - "max_examples", None, "Maximum number of examples to load into LIT." + "max_examples", None, "Maximum number of examples to load into LIT." 
) @@ -63,35 +63,35 @@ def main(_): # Set Tasks as a dictionary with task groups as # keys and values as lists of strings of Stanza task names tasks = { - "sequence": FLAGS.sequence_tasks, - "span": FLAGS.span_tasks, - "edge": FLAGS.edge_tasks, + "sequence": FLAGS.sequence_tasks, + "span": FLAGS.span_tasks, + "edge": FLAGS.edge_tasks, } # Get the correct model for the language stanza.download(FLAGS.language) pretrained_model = stanza.Pipeline(FLAGS.language) models = { - "stanza": stanza_models.StanzaTagger(pretrained_model, tasks), + "stanza": stanza_models.StanzaTagger(pretrained_model, tasks), } # Datasets for LIT demo # TODO(nbroestl): Use the UD dataset # (https://huggingface.co/datasets/universal_dependencies) datasets = { - "SST2": glue.SST2Data(split="validation").slice[: FLAGS.max_examples], - "blank": lit_dataset.Dataset({"text": lit_types.TextSegment()}, []), + "SST2": glue.SST2Data(split="validation").slice[: FLAGS.max_examples], + "blank": lit_dataset.Dataset({"text": lit_types.TextSegment()}, []), } # Add generators generators = { - "scrambler": scrambler.Scrambler(), - "word_replacer": word_replacer.WordReplacer(), + "scrambler": scrambler.Scrambler(), + "word_replacer": word_replacer.WordReplacer(), } # Start the LIT server. See server_flags.py for server options. lit_demo = dev_server.Server( - models, datasets, generators, **server_flags.get_flags() + models, datasets, generators, **server_flags.get_flags() ) lit_demo.serve() From 35fa4dda2746f838ba7542bb87328ea7712225df Mon Sep 17 00:00:00 2001 From: Ellen Jiang Date: Wed, 21 Apr 2021 10:21:20 -0700 Subject: [PATCH 021/213] Fixes a precision error by switching testing utils helper to use numpy.testing.assert_allclose(). 
PiperOrigin-RevId: 369682142 --- lit_nlp/lib/testing_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lit_nlp/lib/testing_utils.py b/lit_nlp/lib/testing_utils.py index 7f974a1a..8eadc33e 100644 --- a/lit_nlp/lib/testing_utils.py +++ b/lit_nlp/lib/testing_utils.py @@ -178,12 +178,13 @@ def fake_projection_input(n, num_dims): return [{'x': rng.rand(num_dims)} for i in range(n)] -def assert_deep_almost_equal(testcase, result, actual, places=5): +def assert_deep_almost_equal(testcase, result, actual, places=4): """Checks if provided inputs are almost equal, recurses on dicts values.""" if isinstance(result, (int, float)): testcase.assertAlmostEqual(result, actual, places=places) elif isinstance(result, (list)): - npt.assert_array_almost_equal(result, actual, decimal=places) + rtol = 10 ** (-1 * places) + npt.assert_allclose(result, actual, rtol=rtol) elif isinstance(result, dict): if set(result.keys()) != set(actual.keys()): testcase.fail('results and actual have different keys') From 75fec950d00769c681980ccf0ccbd907459eb114 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 21 Apr 2021 10:38:23 -0700 Subject: [PATCH 022/213] Escape key to exit fullscreen windows. - Global settings dialogue can also be exited by clicking background. - Escape to exit settings dialogue or maximized module PiperOrigin-RevId: 369686011 --- lit_nlp/client/core/modules.ts | 10 ++++++++++ lit_nlp/client/modules/global_settings.css | 4 ++-- lit_nlp/client/modules/global_settings.ts | 9 ++++++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/lit_nlp/client/core/modules.ts b/lit_nlp/client/core/modules.ts index dfc2acc5..779f1c81 100644 --- a/lit_nlp/client/core/modules.ts +++ b/lit_nlp/client/core/modules.ts @@ -81,6 +81,16 @@ export class LitModules extends ReactiveElement { () => this.modulesService.getRenderLayout(), renderLayout => { this.calculateWidths(renderLayout); }); + + // Escape key to exit full-screen modules. 
+ document.addEventListener('keydown', (e: KeyboardEvent) => { + if (e.key === 'Escape') { + for (const e of this.shadowRoot!.querySelectorAll( + 'lit-widget-group[maximized]')) { + e.removeAttribute('maximized'); + } + } + }); } // Calculate widths of all module groups in all panels. diff --git a/lit_nlp/client/modules/global_settings.css b/lit_nlp/client/modules/global_settings.css index b5454380..699236fe 100644 --- a/lit_nlp/client/modules/global_settings.css +++ b/lit_nlp/client/modules/global_settings.css @@ -12,12 +12,12 @@ top: 0; left: 0; opacity: .2; - transition: opacity 500ms; - pointer-events: none; + transition: opacity 250ms; } #overlay.hide { opacity: 0; + visibility: hidden; } #global-settings { diff --git a/lit_nlp/client/modules/global_settings.ts b/lit_nlp/client/modules/global_settings.ts index 66a2f5fd..2756751d 100644 --- a/lit_nlp/client/modules/global_settings.ts +++ b/lit_nlp/client/modules/global_settings.ts @@ -154,7 +154,8 @@ export class GlobalSettingsComponent extends MobxLitElement { // clang-format off return html`
-
+
{ this.close(); }}>
Configure LIT
@@ -656,6 +657,12 @@ export class GlobalSettingsComponent extends MobxLitElement {
${buttonsHTML}
`; } + + firstUpdated() { + document.addEventListener('keydown', (e: KeyboardEvent) => { + if (e.key === 'Escape') this.close(); + }); + } } declare global { From 9fd08fd79520755fecb8fc75295fa459391ebb4c Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 21 Apr 2021 11:04:48 -0700 Subject: [PATCH 023/213] Internal change. PiperOrigin-RevId: 369692304 --- lit_nlp/components/static_preds.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lit_nlp/components/static_preds.py b/lit_nlp/components/static_preds.py index 1b8e6694..e61ece8f 100644 --- a/lit_nlp/components/static_preds.py +++ b/lit_nlp/components/static_preds.py @@ -31,6 +31,13 @@ def key_fn(self, example: JsonDict) -> str: reduced_example = {k: example[k] for k in self.input_identifier_keys} return caching.input_hash(reduced_example) + def description(self): + return self._description + + @property + def input_dataset(self): + return self._all_inputs + def __init__(self, inputs: lit_dataset.Dataset, preds: lit_dataset.Dataset, @@ -43,8 +50,10 @@ def __init__(self, input_identifier_keys: (optional), list of keys to treat as identifiers for matching inputs. 
If None, will use all fields in inputs.spec() """ - self._output_spec = preds.spec() + self._all_inputs = inputs self._input_spec = inputs.spec() + self._output_spec = preds.spec() + self._description = preds.description() self.input_identifier_keys = input_identifier_keys or self._input_spec.keys( ) # Filter to only the identifier keys From 27c7f5991315e90e674c71ec2bea7afbe35eb138 Mon Sep 17 00:00:00 2001 From: Ankur Taly Date: Wed, 21 Apr 2021 14:37:47 -0700 Subject: [PATCH 024/213] Hotflip: Option to drop tokens instead of flipping PiperOrigin-RevId: 369737574 --- lit_nlp/components/hotflip.py | 58 +++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/lit_nlp/components/hotflip.py b/lit_nlp/components/hotflip.py index fe1bb964..0559064f 100644 --- a/lit_nlp/components/hotflip.py +++ b/lit_nlp/components/hotflip.py @@ -46,6 +46,8 @@ MAX_FLIPS_DEFAULT = 3 TOKENS_TO_IGNORE_KEY = "Tokens to freeze" TOKENS_TO_IGNORE_DEFAULT = [] +DROP_TOKENS_KEY = "Drop tokens instead of flipping" +DROP_TOKENS_DEFAULT = False MAX_FLIPPABLE_TOKENS = 10 @@ -97,6 +99,18 @@ def _gen_tokens_to_flip(self, token_idxs, max_flips): for s in itertools.combinations(token_idxs, i+1): yield s + def _drop_tokens(self, tokens, token_idxs): + # Returns a copy of 'tokens' with all tokens at indices specified in + # 'token_idxs' dropped. + return [t for i, t in enumerate(tokens) if i not in token_idxs] + + def _replace_tokens(self, tokens, token_idxs, + replacement_tokens): + # Returns a copy of 'tokens' with all tokens at indices specified in + # 'token_idxs' replaced with corresponding tokens in 'replacement_tokens'. 
+ return [replacement_tokens[j] if j in token_idxs else t + for j, t in enumerate(tokens)] + def generate(self, example: JsonDict, model: lit_model.Model, @@ -110,6 +124,7 @@ def generate(self, max_flips = int(config.get(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT)) tokens_to_ignore = config.get(TOKENS_TO_IGNORE_KEY, TOKENS_TO_IGNORE_DEFAULT) + drop_tokens = bool(config.get(DROP_TOKENS_KEY, DROP_TOKENS_DEFAULT)) assert model is not None, "Please provide a model for this generator." logging.info(r"W3lc0m3 t0 H0tFl1p \o/") @@ -170,16 +185,21 @@ def generate(self, token_embs = orig_output[token_emb_fields[0]] assert token_embs.shape[0] == grads.shape[0] - # We take a dot product of each input token gradient (grads) with the - # embedding table (embed) - # TODO(ataly): Only consider tokens that have the same part-of-speech - # tag as the original token (and a certain cosine similarity with the - # original token) - replacement_token_ids = np.argmin( - (np.expand_dims(embed, 1) @ grads.T).squeeze(1), axis=0) - - replacement_tokens = [inv_vocab[id] for id in replacement_token_ids] - logging.info("Replacement tokens: %s", replacement_tokens) + if drop_tokens: + # Update max_flips so that it is at most len(tokens) - 1 (we don't + # want to drop all tokens!) + max_flips = min(len(tokens)-1, max_flips) + else: + # Identify replacement tokens. + # We take a dot product of each input token gradient (grads) with the + # embedding table (embed) and pick the argmin embedding. + # TODO(ataly): Only consider tokens that have the same part-of-speech + # tag as the original token (and a certain cosine similarity with the + # original token) + replacement_token_ids = np.argmin( + (np.expand_dims(embed, 1) @ grads.T).squeeze(1), axis=0) + replacement_tokens = [inv_vocab[id] for id in replacement_token_ids] + logging.info("Replacement tokens: %s", replacement_tokens) # Consider all combinations of tokens upto length max_flips. 
# We will iterate through this list (in toplogically sorted order) @@ -209,19 +229,23 @@ def generate(self, if self._subset_exists(set(token_idxs), successful_positions): continue - logging.info("Selected tokens to flip: %s (positions=%s) with: %s", - [tokens[i] for i in token_idxs], token_idxs, - [replacement_tokens[i] for i in token_idxs]) - # Create a new input to the model. # TODO(iftenney, bastings): enforce somewhere that this field has the # same name in the input and output specs. input_token_field = token_field input_text_field = input_spec[input_token_field].parent # pytype: disable=attribute-error counterfactual = copy.deepcopy(example) - modified_tokens = copy.copy(tokens) - for j in token_idxs: - modified_tokens[j] = replacement_tokens[j] + if drop_tokens: + modified_tokens = self._drop_tokens(tokens, token_idxs) + logging.info("Selected tokens to drop: %s (positions=%s)", + [tokens[i] for i in token_idxs], token_idxs) + else: + modified_tokens = self._replace_tokens(tokens, token_idxs, + replacement_tokens) + logging.info( + "Selected tokens to flip: %s (positions=%s) with: %s", + [tokens[i] for i in token_idxs], token_idxs, + [replacement_tokens[i] for i in token_idxs]) counterfactual[input_token_field] = modified_tokens # TODO(iftenney, bastings): call a model-provided detokenizer here? # Though in general tokenization isn't invertible and it's possible for From ae5411fa24eea1cc350e41f3f90e8ca48915306a Mon Sep 17 00:00:00 2001 From: Tolga Bolukbasi Date: Fri, 23 Apr 2021 12:58:32 -0700 Subject: [PATCH 025/213] Make table handle fields with missing data more gracefully. 
PiperOrigin-RevId: 370140696 --- lit_nlp/client/elements/table.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/client/elements/table.ts b/lit_nlp/client/elements/table.ts index 28375981..4b256d2f 100644 --- a/lit_nlp/client/elements/table.ts +++ b/lit_nlp/client/elements/table.ts @@ -583,7 +583,7 @@ export class DataTable extends ReactiveElement { return html``; } else { return (d instanceof TemplateResult) ? d : - html`
${chunkWords(d.toString())}
`; + html`
${d ? chunkWords(d.toString()) : ''}
`; } }))} From 7b98da147138818131a7ffd08ba44a66aabc915c Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 28 Apr 2021 17:40:48 -0700 Subject: [PATCH 026/213] Improvements to model and dataset selection in main toolbar. - Use a clean button style - Quick-select for compatible models PiperOrigin-RevId: 371022221 --- lit_nlp/client/modules/app_toolbar.css | 85 ++++++++-- lit_nlp/client/modules/app_toolbar.ts | 205 +++++++++++++++++-------- 2 files changed, 217 insertions(+), 73 deletions(-) diff --git a/lit_nlp/client/modules/app_toolbar.css b/lit_nlp/client/modules/app_toolbar.css index 3b4da68d..98c0b62b 100644 --- a/lit_nlp/client/modules/app_toolbar.css +++ b/lit_nlp/client/modules/app_toolbar.css @@ -4,24 +4,23 @@ z-index: 3; } -#toolbar { - display: flex; - flex-direction: row; - align-items: center; -} - #headline { width: 100vw; display: flex; + flex-direction: row; justify-content: space-between; align-items: center; background-color: #2f8c9b; font-family: 'Google Sans' !important; color: white; - font-size: 14pt; + font-size: 12pt; letter-spacing: +0.1; } +#title-group { + font-size: 14pt; +} + .headline-section { display: flex; align-items: center; @@ -40,7 +39,7 @@ mwc-icon.icon-button { min-width: 24px; --mdc-icon-size: 24px; cursor: pointer; - margin-left: 5pt; + margin: 0pt 2pt; } mwc-icon.icon-button:hover { @@ -58,12 +57,74 @@ mwc-icon.icon-button:active { display: block; } -.status-item { - margin-left: 1em; +.vertical-separator { + background: #098591; + width: 2px; + height: 1.2rem; + padding: 0; + margin: 0px 4px; +} + +/* For in-line icons in a */ +.material-icon { + font-family: 'Material Icons'; + vertical-align: middle; + margin: 0; +} + +.material-icon-outlined { + font-family: 'Material Icons Outlined'; + vertical-align: middle; + margin: 0; +} + +/* Custom button style to work with the dark background */ +button { font-size: 12pt; + color: white; } -.status-text-underline { - text-decoration: underline; +/* disabled, unselected, default, 
focus, hover, selected */ + +.headline-button { + margin: 0 4px; + background: transparent; + border-radius: 4px; + border: 1px solid #098591; + font-family: Roboto; + color: #E4F7FB; /* Cyan/50 */ } +.headline-button.unbordered { + border: 1px solid transparent; +} + +.headline-button.unselected { + color: #A1E4F2; /* Cyan/200 */ +} + +.headline-button.selected { + border: 1px solid #E4F7FB; /* Cyan/50 */ +} + +.headline-button:focus { + background: rgba(138, 180, 248, 0.12); + outline: none; +} + +.headline-button:hover { + background: rgba(26, 115, 232, 0.04); + opacity: .7; + border: 1px solid #BDC1C6; +} + +.headline-button:active { + background: #E8F0FE; + box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), + 0px 1px 3px 1px rgba(60, 64, 67, 0.15); + opacity: .4; +} + +.headline-button:disabled { + color: rgba(60, 64, 67, 0.38); +} diff --git a/lit_nlp/client/modules/app_toolbar.ts b/lit_nlp/client/modules/app_toolbar.ts index 96aff41a..18bf3ba8 100644 --- a/lit_nlp/client/modules/app_toolbar.ts +++ b/lit_nlp/client/modules/app_toolbar.ts @@ -26,11 +26,12 @@ import './main_toolbar'; import {MobxLitElement} from '@adobe/lit-mobx'; import {customElement, html, query} from 'lit-element'; +import {classMap} from 'lit-html/directives/class-map'; import {app} from '../core/lit_app'; import {datasetDisplayName} from '../lib/types'; import {copyToClipboard} from '../lib/utils'; -import {AppState, ModulesService, StatusService} from '../services/services'; +import {AppState, ModulesService, SettingsService, StatusService} from '../services/services'; import {styles} from './app_toolbar.css'; import {GlobalSettingsComponent, TabName} from './global_settings'; @@ -48,6 +49,7 @@ export class ToolbarComponent extends MobxLitElement { } private readonly appState = app.getService(AppState); + private readonly settingsService = app.getService(SettingsService); private readonly statusService = app.getService(StatusService); private readonly modulesService = 
app.getService(ModulesService); @@ -62,46 +64,151 @@ export class ToolbarComponent extends MobxLitElement { jumpToSettingsTab(targetTab: TabName) { if (this.globalSettingsElement === undefined) return; - this.globalSettingsElement.selectedTab = targetTab; - this.globalSettingsElement.open(); + if (this.globalSettingsElement.isOpen && + this.globalSettingsElement.selectedTab === targetTab) { + this.globalSettingsElement.close(); + } else { + this.globalSettingsElement.selectedTab = targetTab; + this.globalSettingsElement.open(); + } } - onCopyLinkClick() { + renderStatusAndTitle() { + let title = 'Language Interpretability Tool'; + if (this.appState.initialized && this.appState.metadata.pageTitle) { + title = this.appState.metadata.pageTitle; + } + // clang-format off + return html` + + `; + // clang-format on + } + + renderModelInfo() { + const compatibleModels = + Object.keys(this.appState.metadata.models) + .filter( + model => this.settingsService.isDatasetValidForModels( + this.appState.currentDataset, [model])); + + if (2 <= compatibleModels.length && compatibleModels.length <= 4) { + // If we have more than one compatible model (but not too many), + // show the in-line selector. + const modelChips = compatibleModels.map(name => { + const isSelected = this.appState.currentModels.includes(name); + const classes = { + 'headline-button': true, + 'unselected': !isSelected, // not the same as default; see CSS + 'selected': isSelected, + }; + const icon = isSelected ? 
'check_box' : 'check_box_outline_blank'; + const updateModelSelection = () => { + const modelSet = new Set(this.appState.currentModels); + if (modelSet.has(name)) { + modelSet.delete(name); + } else { + modelSet.add(name); + } + this.settingsService.updateSettings({'models': [...modelSet]}); + this.requestUpdate(); + }; + // clang-format off + return html` + + `; + // clang-format on + }); + // clang-format off + return html` + ${modelChips} + + `; + // clang-format on + } else { + // Otherwise, give a regular button that opens the models menu. + // clang-format off + return html` + + `; + // clang-format on + } + } + + renderDatasetInfo() { + // clang-format off + return html` +
+ + `; + // clang-format on + } + + renderConfigControls() { + // clang-format off + return html` + ${this.appState.initialized ? this.renderModelInfo() : null} + ${this.appState.initialized ? this.renderDatasetInfo() : null} +
+
+ { this.jumpToSettingsTab("Layout"); }}> + view_compact + +
+
+ + settings + +
+ `; + // clang-format on + } + + onClickCopyLink() { const urlBase = (this.appState.metadata.canonicalURL || window.location.host); copyToClipboard(urlBase + window.location.search); } - renderModelAndDatasetInfo() { - const modelsPrefix = - this.appState.currentModels.length > 1 ? 'Models' : 'Model'; - const modelsText = html` - ${modelsPrefix}: - - ${this.appState.currentModels.join(', ')} - `; - const datasetText = html` - Dataset: - - ${datasetDisplayName(this.appState.currentDataset)} - `; + renderRightCorner() { // clang-format off return html` -
-
{ this.jumpToSettingsTab("Models"); }}> - ${modelsText} -
-
{ this.jumpToSettingsTab("Dataset"); }}> - ${datasetText} -
-
+ `; // clang-format on } - render() { const doRenderToolbar = (this.appState.initialized && @@ -113,39 +220,15 @@ export class ToolbarComponent extends MobxLitElement { ${this.appState.initialized ? html`` : null}
-
-
-
-
- - ${this.statusService.hasError ? - html`` : - html``} - - ${this.appState.initialized && this.appState.metadata.pageTitle ? - this.appState.metadata.pageTitle : "Language Interpretability Tool"} -
- ${this.appState.initialized ? this.renderModelAndDatasetInfo() : null} -
-
-
- - link - -
-
- { this.jumpToSettingsTab("Layout"); }}> - view_compact - -
-
- - settings - -
-
+
+
+ ${this.renderStatusAndTitle()} +
+
+ ${this.renderConfigControls()} +
+
+ ${this.renderRightCorner()}
${doRenderToolbar? html`` : null} From a61d38c6e1de6f296a2006652e6346d8e289b406 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 28 Apr 2021 19:02:11 -0700 Subject: [PATCH 027/213] Pin header and main-area heights in the widget, so we don't get extra vertical scollbars in the main or tray areas. PiperOrigin-RevId: 371032145 --- lit_nlp/client/core/widget_group.css | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lit_nlp/client/core/widget_group.css b/lit_nlp/client/core/widget_group.css index a182682f..05ecfaf4 100644 --- a/lit_nlp/client/core/widget_group.css +++ b/lit_nlp/client/core/widget_group.css @@ -38,7 +38,7 @@ } .holder { - height: 100%; + height: calc(100% - 28px); display: flex; flex-direction: column; position: relative; @@ -46,6 +46,7 @@ } .header { + height: 28px; display: flex; padding-left: 2pt; line-height: 20pt; From 1b1d8abcdfc37eb3fa645a2e9b9c61de994f1e18 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Wed, 28 Apr 2021 20:11:49 -0700 Subject: [PATCH 028/213] Fix rendering of false-y values that are not 'undefined' PiperOrigin-RevId: 371038940 --- lit_nlp/client/elements/table.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lit_nlp/client/elements/table.ts b/lit_nlp/client/elements/table.ts index 4b256d2f..aa662a06 100644 --- a/lit_nlp/client/elements/table.ts +++ b/lit_nlp/client/elements/table.ts @@ -583,7 +583,7 @@ export class DataTable extends ReactiveElement { return html``; } else { return (d instanceof TemplateResult) ? d : - html`
${d ? chunkWords(d.toString()) : ''}
`; + html`
${d !== undefined ? chunkWords(String(d)) : ''}
`; } }))} From b5118be4835e3add30a992e97d833abf24897b46 Mon Sep 17 00:00:00 2001 From: Ian Tenney Date: Thu, 29 Apr 2021 08:15:40 -0700 Subject: [PATCH 029/213] Data table set to "only show selected" by default, for more informative view when driving selections from other modules. Ignore self-selection, so that working within the data table still works normally. - Also ignore selections from URLService on page load. - Use a toggle switch instead of a checkbox for the control. PiperOrigin-RevId: 371122353 --- lit_nlp/client/modules/data_table_module.css | 9 ++++++ lit_nlp/client/modules/data_table_module.ts | 30 +++++++++++++++----- lit_nlp/client/modules/main_toolbar.css | 6 ---- lit_nlp/client/modules/shared_styles.css | 8 ++++++ lit_nlp/client/services/url_service.ts | 12 ++++---- 5 files changed, 46 insertions(+), 19 deletions(-) diff --git a/lit_nlp/client/modules/data_table_module.css b/lit_nlp/client/modules/data_table_module.css index 6762e234..8377aa2d 100644 --- a/lit_nlp/client/modules/data_table_module.css +++ b/lit_nlp/client/modules/data_table_module.css @@ -37,3 +37,12 @@ max-height: calc(100% - 48px); overflow: auto; } + +.switch-container { + display: flex; + align-items: center; +} + +.switch-container > * { + margin-right: 4px; +} diff --git a/lit_nlp/client/modules/data_table_module.ts b/lit_nlp/client/modules/data_table_module.ts index a0ccb8c9..654b38a4 100644 --- a/lit_nlp/client/modules/data_table_module.ts +++ b/lit_nlp/client/modules/data_table_module.ts @@ -16,6 +16,7 @@ */ // tslint:disable:no-new-decorators +import '@material/mwc-switch'; import '../elements/checkbox'; import {customElement, html, query} from 'lit-element'; @@ -28,7 +29,7 @@ import {formatForDisplay, IndexedInput, ModelInfoMap, Spec} from '../lib/types'; import {compareArrays, findSpecKeys, shortenId} from '../lib/utils'; import {ClassificationInfo} from '../services/classification_service'; import {RegressionInfo} from '../services/regression_service'; -import 
{ClassificationService, FocusService, RegressionService, SelectionService} from '../services/services'; +import {ClassificationService, FocusService, RegressionService, SelectionService, UrlService} from '../services/services'; import {styles} from './data_table_module.css'; import {styles as sharedStyles} from './shared_styles.css'; @@ -62,12 +63,16 @@ export class DataTableModule extends LitModule { @observable searchText = ''; // Module options / configuration state - @observable private filterSelected: boolean = false; + @observable private filterSelected: boolean = true; @observable private columnDropdownVisible: boolean = false; // Persistent selection state @observable private selectedInputData: IndexedInput[] = []; + // Used to manage selection response; should not be interacted + // with directly. + private readonly urlService = app.getService(UrlService); + // Child components @query('lit-data-table') private readonly table?: DataTable; @@ -219,9 +224,19 @@ export class DataTableModule extends LitModule { this.react(getKeys, keys => { this.updateColumns(); }); + // React to change in selection. this.reactImmediately( () => this.selectionService.selectedOrAllInputData, inputData => { - this.selectedInputData = inputData; + // Don't react to selection set from within this module. + if (this.selectionService.lastUser === this) { + return; + } + // If selection set from URL, also show the full dataset only. + if (this.selectionService.lastUser === this.urlService) { + this.selectedInputData = this.appState.currentInputData; + } else { + this.selectedInputData = inputData; + } if (this.table) { this.table.resetView(); } @@ -427,10 +442,11 @@ export class DataTableModule extends LitModule { // clang-format off return html` - { this.filterSelected = !this.filterSelected; }} - > +
{ this.filterSelected = !this.filterSelected; }}> + + Only show selected +
- -
- `; + const row: {[key: string]: TableEntry} = {}; + for (const key of fieldNames) { + const editable = + !this.appState.currentModelRequiredInputSpecKeys.includes(key); + row[key] = editable ? this.renderEntry(generated.data, key) + : generated.data[key]; + } + row['Controls'] = + html` +
+ + +
`; + rows.push(row); // clang-format on - }); - }); + } + } + return rows; } - renderHeader() { + renderOverallControls() { const onAddAll = async () => { await this.createNewDatapoints(this.generated); this.resetEditedData(); }; - if (this.totalNumGenerated <= 0) { - return null; - } // clang-format off return html` - `; diff --git a/lit_nlp/components/nearest_neighbors.py b/lit_nlp/components/nearest_neighbors.py new file mode 100644 index 00000000..60435658 --- /dev/null +++ b/lit_nlp/components/nearest_neighbors.py @@ -0,0 +1,99 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Lint as: python3 +"""Finds the k nearest neighbors to an input embedding.""" + + +from typing import List, Optional, Sequence + +import attr +from lit_nlp.api import components as lit_components +from lit_nlp.api import dataset as lit_dataset +from lit_nlp.api import model as lit_model +from lit_nlp.api import types +import numpy as np +from scipy.spatial import distance + + +JsonDict = types.JsonDict +IndexedInput = types.IndexedInput +Spec = types.Spec + + +@attr.s(auto_attribs=True, kw_only=True) +class NearestNeighborsConfig(object): + """Config options for Nearest Neighbors component.""" + embedding_name: str = '' + num_neighbors: Optional[int] = 10 + dataset_name: Optional[str] = '' + + +class NearestNeighbors(lit_components.Interpreter): + """Computes nearest neighbors of an example embedding. + + Required Model Output: + - Embeddings (`emb_layer`) to return the input embeddings + for a layer + """ + + def run_with_metadata( + self, + indexed_inputs: Sequence[IndexedInput], + model: lit_model.Model, + dataset: lit_dataset.IndexedDataset, + model_outputs: Optional[List[JsonDict]] = None, + config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]: + """Finds the nearest neighbors of the example specified in the config. + + Args: + indexed_inputs: the dataset example to find nearest neighbors for. + model: the model being explained. + dataset: the dataset which the current examples belong to. + model_outputs: optional model outputs from calling model.predict(inputs). + config: a config which should specify: + { + 'num_neighbors': [the number of nearest neighbors to return] + 'dataset_name': [the name of the dataset (used for caching)] + 'embedding_name': [the name of the embedding field to use] + } + + Returns: + A JsonDict containing the a list of num_neighbors nearest neighbors, + where each has the example id and distance from the main example. 
+ """ + config = NearestNeighborsConfig(**config) + + dataset_outputs = list(model.predict_with_metadata( + dataset.indexed_examples, dataset_name=config.dataset_name)) + + example_outputs = list(model.predict_with_metadata( + indexed_inputs, dataset_name=config.dataset_name)) + # TODO(lit-dev): Add support for selecting nearest neighbors of a set. + if len(example_outputs) != 1: + raise ValueError('More than one selected example was passed in.') + example_output = example_outputs[0] + + # [emb_size] + dataset_embs = [output[config.embedding_name] for output in dataset_outputs] + example_embs = [example_output[config.embedding_name]] + distances = distance.cdist(example_embs, dataset_embs)[0] + sorted_indices = np.argsort(distances) + k = config.num_neighbors + k_nearest_neighbors = [ + {'id': dataset.indexed_examples[original_index]['id'], + 'nn_distance': distances[original_index] + } for original_index in sorted_indices[:k]] + + return [{'nearest_neighbors': k_nearest_neighbors}] diff --git a/lit_nlp/components/nearest_neighbors_test.py b/lit_nlp/components/nearest_neighbors_test.py new file mode 100644 index 00000000..248ae676 --- /dev/null +++ b/lit_nlp/components/nearest_neighbors_test.py @@ -0,0 +1,102 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Lint as: python3 +"""Tests for lit_nlp.components.gradient_maps.""" + +from typing import List + +from absl.testing import absltest +from lit_nlp.api import dataset as lit_dataset +from lit_nlp.api import model as lit_model +from lit_nlp.api import types as lit_types +from lit_nlp.components import nearest_neighbors +from lit_nlp.lib import caching # for hash id fn +from lit_nlp.lib import testing_utils +import numpy as np + + +JsonDict = lit_types.JsonDict + + +class TestModelNearestNeighbors(lit_model.Model): + """Implements lit.Model interface for nearest neighbors. + + Returns the same output for every input. + """ + + # LIT API implementation + def max_minibatch_size(self, **unused_kw): + return 3 + + def input_spec(self): + return {'segment': lit_types.TextSegment} + + def output_spec(self): + return {'probas': lit_types.MulticlassPreds( + parent='label', + vocab=['0', '1'], + null_idx=0), + 'input_embs': lit_types.TokenEmbeddings(align='tokens'), + } + + def predict_minibatch(self, inputs: List[JsonDict], **kw): + embs = [np.array([0, 0, 0, 0]), + np.array([1, 1, 1, 0]), + np.array([5, 8, -10, 0])] + probas = np.array([0.2, 0.8]) + return [{'probas': probas, 'input_embs': embs[i]} + for i, _ in enumerate(inputs)] + + +class NearestNeighborTest(absltest.TestCase): + + def setUp(self): + super(NearestNeighborTest, self).setUp() + self.nearest_neighbors = nearest_neighbors.NearestNeighbors() + + def test_run_nn(self): + examples = [ + { + 'segment': 'a' + }, + { + 'segment': 'b' + }, + { + 'segment': 'c' + }, + ] + indexed_inputs = [{'id': caching.input_hash(ex), 'data': ex} + for ex in examples] + + model = TestModelNearestNeighbors() + dataset = lit_dataset.IndexedDataset(id_fn=caching.input_hash, + indexed_examples=indexed_inputs) + config = { + 'embedding_name': 'input_embs', + 'num_neighbors': 2, + } + result = self.nearest_neighbors.run_with_metadata([indexed_inputs[1]], 
+ model, dataset, + config=config) + expected = {'nearest_neighbors': [ + {'id': '1', 'nn_distance': 0.0}, + {'id': '0', 'nn_distance': 1.7320508075688772}]} + + self.assertLen(result, 1) + testing_utils.assert_deep_almost_equal(self, expected, result[0]) + +if __name__ == '__main__': + absltest.main() diff --git a/lit_nlp/lib/testing_utils.py b/lit_nlp/lib/testing_utils.py index 8eadc33e..ad44202f 100644 --- a/lit_nlp/lib/testing_utils.py +++ b/lit_nlp/lib/testing_utils.py @@ -183,8 +183,12 @@ def assert_deep_almost_equal(testcase, result, actual, places=4): if isinstance(result, (int, float)): testcase.assertAlmostEqual(result, actual, places=places) elif isinstance(result, (list)): - rtol = 10 ** (-1 * places) - npt.assert_allclose(result, actual, rtol=rtol) + if all(isinstance(n, (int, float)) for n in result): + rtol = 10 ** (-1 * places) + npt.assert_allclose(result, actual, rtol=rtol) + elif all(isinstance(n, dict) for n in result): + for i in range(len(result)): + assert_deep_almost_equal(testcase, result[i], actual[i]) elif isinstance(result, dict): if set(result.keys()) != set(actual.keys()): testcase.fail('results and actual have different keys') From d94a9287046579c6ba0adcab3cae38d48e56083e Mon Sep 17 00:00:00 2001 From: Ankur Taly Date: Tue, 11 May 2021 11:01:22 -0700 Subject: [PATCH 043/213] Hotflip: Add support for regression models PiperOrigin-RevId: 373190015 --- lit_nlp/components/hotflip.py | 173 +++++++++++++++-------- lit_nlp/components/hotflip_test.py | 217 +++++++++++++++++++++-------- 2 files changed, 275 insertions(+), 115 deletions(-) diff --git a/lit_nlp/components/hotflip.py b/lit_nlp/components/hotflip.py index e228c6d0..eee7a37d 100644 --- a/lit_nlp/components/hotflip.py +++ b/lit_nlp/components/hotflip.py @@ -16,7 +16,7 @@ """Hotflip generator that perturbs input tokens to flip the prediction. 
A hotflip is defined as a counterfactual sentence that alters one or more -tokens in the input sentence in order to to obtain a different prediction +tokens in the input sentence in order to obtain a different prediction from the input sentence. A hotflip is considered minimal if no strict subset of the applied token flips @@ -59,6 +59,8 @@ TOKENS_TO_IGNORE_DEFAULT = [] DROP_TOKENS_KEY = "Drop tokens instead of flipping" DROP_TOKENS_DEFAULT = False +REGRESSION_THRESH_KEY = "Regression threshold" +REGRESSION_THRESH_DEFAULT = 0.0 MAX_FLIPPABLE_TOKENS = 10 @@ -69,7 +71,11 @@ class HotFlip(lit_components.Generator): each token and uses them to heuristically estimate the impact of perturbing the token. - This generator is currently only supported on classification models. + This generator works for both classification and regression models. In the + case of classification models, the returned counterfactuals are guaranteed to + have a different prediction class as the original example. In the case of + regression models, the returned counterfactuals are guaranteed to be on the + opposite side of a user-provided threshold as the original example. """ def find_fields( @@ -86,14 +92,37 @@ def find_fields( return [f for f in fields if getattr(output_spec[f], "align", None) == align_field] + def _get_tokens_and_gradients(self, + output_spec: JsonDict, + output: JsonDict): + """Returns a dictionary mapping token fields to tokens and gradients.""" + # Find gradient fields + grad_fields = self.find_fields(output_spec, types.TokenGradients, + None) + if len(grad_fields) == 0: # pylint: disable=g-explicit-length-test + return {} + + ret = {} + for grad_field in grad_fields: + # Get tokens, token gradients and token embeddings. 
+ token_field = output_spec[grad_field].align # pytype: disable=attribute-error + tokens = output[token_field] + grads = output[grad_field] + ret[token_field] = [tokens, grads] + return ret + def config_spec(self) -> types.Spec: return { NUM_EXAMPLES_KEY: types.TextSegment(default=str(NUM_EXAMPLES_DEFAULT)), MAX_FLIPS_KEY: types.TextSegment(default=str(MAX_FLIPS_DEFAULT)), TOKENS_TO_IGNORE_KEY: types.Tokens(default=TOKENS_TO_IGNORE_DEFAULT), - DROP_TOKENS_KEY: types.TextSegment(default=str(DROP_TOKENS_DEFAULT)), + DROP_TOKENS_KEY: types.Boolean(default=DROP_TOKENS_DEFAULT), PREDICTION_KEY: types.FieldMatcher(spec="output", - types=["MulticlassPreds"]) + types=["MulticlassPreds", + "RegressionScore"]), + REGRESSION_THRESH_KEY: types.TextSegment( + default=str(REGRESSION_THRESH_DEFAULT)), + } def _subset_exists(self, cand_set, sets): @@ -109,7 +138,7 @@ def _gen_token_idxs_to_flip( token_grads: np.ndarray, max_flips: int, tokens_to_ignore: List[str], - drop_tokens: Optional[bool] = False) -> Iterator[Tuple[int, ...]]: + drop_tokens: bool = False) -> Iterator[Tuple[int, ...]]: """Generates sets of token positions that are eligible for flipping.""" if drop_tokens: # Update max_flips so that it is at most len(tokens) - 1 (we don't @@ -160,6 +189,25 @@ def _flip_tokens(self, [replacement_tokens[i] for i in token_idxs]) return modified_tokens + def _create_cf(self, + example: JsonDict, + token_field: str, + text_field: str, + tokens: List[str], + token_idxs: Tuple[int, ...], + drop_tokens: bool, + replacement_tokens: List[str]) -> JsonDict: + cf = copy.deepcopy(example) + modified_tokens = self._flip_tokens( + tokens, token_idxs, drop_tokens, replacement_tokens) + # TODO(iftenney, bastings): call a model-provided detokenizer here? + # Though in general tokenization isn't invertible and it's possible for + # HotFlip to produce wordpiece sequences that don't correspond to any + # input string. 
+ cf[token_field] = modified_tokens + cf[text_field] = " ".join(modified_tokens) + return cf + def _update_label(self, example: JsonDict, example_output: JsonDict, @@ -176,24 +224,34 @@ def _update_label(self, def _is_hotflip(self, cf_output: JsonDict, orig_output: JsonDict, - pred_key: Text) -> bool: - """Check if cf_output and orig_output specify different prediction classes.""" - cf_pred_class = np.argmax(cf_output[pred_key]) - orig_pred_class = np.argmax(orig_output[pred_key]) + pred_key: Text, + is_regression: bool = False, + regression_thresh: Optional[float] = None) -> bool: + """Check if cf_output and orig_output specify different prediciton classes.""" + if is_regression: + # regression model. We use the provided threshold to binarize the output. + cf_pred_class = cf_output[pred_key] <= regression_thresh + orig_pred_class = orig_output[pred_key] <= regression_thresh + else: + cf_pred_class = np.argmax(cf_output[pred_key]) + orig_pred_class = np.argmax(orig_output[pred_key]) return cf_pred_class != orig_pred_class def _get_replacement_tokens( self, embedding_matrix: np.ndarray, inv_vocab: List[Text], - token_grads: np.ndarray) -> List[str]: + token_grads: np.ndarray, + orig_output: JsonDict, + direction: int = -1) -> List[str]: """Identifies replacement tokens for each token position.""" + token_grads = token_grads * direction # Compute dot product of each input token gradient with the embedding # matrix, and pick the argmin. # TODO(ataly): Only consider tokens that have the same part-of-speech # tag as the original token and/or a certain cosine similarity with the # original token. 
- replacement_token_ids = np.argmin( + replacement_token_ids = np.argmax( (np.expand_dims(embedding_matrix, 1) @ token_grads.T).squeeze(1), axis=0) replacement_tokens = [inv_vocab[id] for id in replacement_token_ids] @@ -214,30 +272,43 @@ def generate(self, TOKENS_TO_IGNORE_DEFAULT) drop_tokens = bool(config.get(DROP_TOKENS_KEY, DROP_TOKENS_DEFAULT)) pred_key = config.get(PREDICTION_KEY, "") + regression_thresh = float(config.get(REGRESSION_THRESH_KEY, + REGRESSION_THRESH_DEFAULT)) assert model is not None, "Please provide a model for this generator." input_spec = model.input_spec() output_spec = model.output_spec() assert pred_key, "Please provide the prediction key" assert pred_key in output_spec, "Invalid prediction key" - assert isinstance(output_spec[pred_key], types.MulticlassPreds), ( - "Only classification models are supported") + + is_regression = False + if isinstance(output_spec[pred_key], types.RegressionScore): + is_regression = True + else: + assert isinstance(output_spec[pred_key], types.MulticlassPreds), ( + "Only classification or regression models are supported") logging.info(r"W3lc0m3 t0 H0tFl1p \o/") logging.info("Original example: %r", example) - # Find gradient fields to use for HotFlip - grad_fields = self.find_fields(output_spec, types.TokenGradients, - None) - if len(grad_fields) == 0: # pylint: disable=g-explicit-length-test - logging.info("No gradient fields found. Cannot use HotFlip. :-(") - return [] # Cannot generate examples without gradients. - logging.info("Found gradient fields for HotFlip use: %s", str(grad_fields)) - # Get model outputs. - logging.info("Performing a forward/backward pass on the input example.") + logging.info("Performing a forward pass on the input example.") orig_output = list(model.predict([example]))[0] logging.info(orig_output.keys()) + # Get tokens (corresponding to each text input field) and corresponding + # gradients. 
+ tokens_and_gradients = self._get_tokens_and_gradients( + output_spec, orig_output) + if len(tokens_and_gradients) == 0: # pylint: disable=g-explicit-length-test + logging.info("No token or gradient fields found. Cannot use HotFlip. :-(") + return [] # Cannot generate examples without tokens or gradients. + + # Copy tokens into input example. + example = copy.deepcopy(example) + for token_field, v in tokens_and_gradients.items(): + tokens, _ = v + example[token_field] = tokens + # Get model word embeddings and vocab. inv_vocab, embedding_matrix = model.get_embedding_table() assert len(inv_vocab) == embedding_matrix.shape[0], ( @@ -245,28 +316,26 @@ def generate(self, logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab), embedding_matrix.shape) - # TODO(lit-team): use only 1 sequence as input (configurable in UI). successful_cfs = [] - successful_positions = [] + # TODO(lit-team): use only 1 sequence as input (configurable in UI). # TODO(lit-team): Refactor the following code so that it's not so deeply # nested (and easier to track loop state). - for grad_field in grad_fields: - # Get tokens, token gradients and token embeddings. 
- token_field = output_spec[grad_field].align # pytype: disable=attribute-error - tokens = orig_output[token_field] - grads = orig_output[grad_field] - token_emb_fields = self.find_fields(output_spec, types.TokenEmbeddings, - token_field) - assert len(token_emb_fields) == 1, "Found multiple token embeddings" - token_embs = orig_output[token_emb_fields[0]] - assert token_embs.shape[0] == grads.shape[0] - + for token_field, v in tokens_and_gradients.items(): + tokens, grads = v + text_field = input_spec[token_field].parent # pytype: disable=attribute-error + logging.info("Identifying Hotflips for input field: %s", str(text_field)) replacement_tokens = None if not drop_tokens: + direction = -1 + if is_regression: + # We want the replacements to increase the prediction score if the + # original score is below the threshold, and decrease otherwise. + direction = (1 if orig_output[pred_key] <= regression_thresh else -1) replacement_tokens = self._get_replacement_tokens( - embedding_matrix, inv_vocab, grads) + embedding_matrix, inv_vocab, grads, direction) logging.info("Replacement tokens: %s", replacement_tokens) + successful_positions = [] for token_idxs in self._gen_token_idxs_to_flip( tokens, grads, max_flips, tokens_to_ignore, drop_tokens): if len(successful_cfs) >= num_examples: @@ -278,28 +347,20 @@ def generate(self, continue # Create counterfactual. - cf = copy.deepcopy(example) - modified_tokens = self._flip_tokens( - tokens, token_idxs, drop_tokens, replacement_tokens) - cf[token_field] = modified_tokens - # TODO(iftenney, bastings): call a model-provided detokenizer here? - # Though in general tokenization isn't invertible and it's possible for - # HotFlip to produce wordpiece sequences that don't correspond to any - # input string. - text_field = input_spec[token_field].parent # pytype: disable=attribute-error - cf[text_field] = " ".join(modified_tokens) - - # Get model outputs. 
+ cf = self._create_cf(example, token_field, text_field, tokens, + token_idxs, drop_tokens, replacement_tokens) + # Obtain model prediction. + logging.info("Performing a forward pass on counterfactual candidate.") cf_output = list(model.predict([cf]))[0] - if self._is_hotflip(cf_output, orig_output, pred_key): - # Hotflip found - # Update label if multi-class prediction. - # TODO(lit-dev): provide a general system for handling labels on - # generated examples. - self._update_label(cf, cf_output, output_spec, pred_key) - + if self._is_hotflip(cf_output, orig_output, pred_key, + is_regression, regression_thresh): + # Hotflip found! successful_cfs.append(cf) successful_positions.append(set(token_idxs)) + if not is_regression: + # Update label if multi-class prediction. + # TODO(lit-dev): provide a general system for handling labels on + # generated examples. + self._update_label(cf, cf_output, output_spec, pred_key) return successful_cfs - diff --git a/lit_nlp/components/hotflip_test.py b/lit_nlp/components/hotflip_test.py index 8461abad..07367288 100644 --- a/lit_nlp/components/hotflip_test.py +++ b/lit_nlp/components/hotflip_test.py @@ -34,100 +34,199 @@ class ModelBasedHotflipTest(absltest.TestCase): def setUp(self): super(ModelBasedHotflipTest, self).setUp() self.hotflip = hotflip.HotFlip() - self.model = glue_models.SST2Model(BERT_TINY_PATH) - self.pred_key = 'probas' - self.config = {hotflip.PREDICTION_KEY: self.pred_key} + + # Classification model that clasifies a given input sentence. + self.classification_model = glue_models.SST2Model(BERT_TINY_PATH) + self.classification_config = {hotflip.PREDICTION_KEY: 'probas'} + + # Regression model determining similarity between two input sentences. 
+ self.regression_model = glue_models.STSBModel(STSB_PATH) + self.regression_config = {hotflip.PREDICTION_KEY: 'score'} def test_find_fields(self): - fields = self.hotflip.find_fields(self.model.output_spec(), + fields = self.hotflip.find_fields(self.classification_model.output_spec(), lit_types.MulticlassPreds) self.assertEqual(['probas'], fields) - fields = self.hotflip.find_fields(self.model.output_spec(), + fields = self.hotflip.find_fields(self.classification_model.output_spec(), lit_types.TokenGradients, 'tokens_sentence') self.assertEqual(['token_grad_sentence'], fields) def test_find_fields_empty(self): - fields = self.hotflip.find_fields(self.model.output_spec(), + fields = self.hotflip.find_fields(self.classification_model.output_spec(), lit_types.TokenGradients, 'input_embs_sentence') self.assertEmpty(fields) def test_hotflip_num_ex(self): ex = {'sentence': 'this long movie is terrible.'} - self.config[hotflip.NUM_EXAMPLES_KEY] = 0 + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 0 self.assertEmpty( - self.hotflip.generate(ex, self.model, None, self.config)) - self.config[hotflip.NUM_EXAMPLES_KEY] = 1 + self.hotflip.generate(ex, self.classification_model, None, + self.classification_config)) + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 1 + self.assertLen( + self.hotflip.generate(ex, self.classification_model, None, + self.classification_config), 1) + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 2 self.assertLen( - self.hotflip.generate(ex, self.model, None, self.config), 1) - self.config[hotflip.NUM_EXAMPLES_KEY] = 2 + self.hotflip.generate(ex, self.classification_model, None, + self.classification_config), 2) + + def test_hotflip_num_ex_multi_input(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this short movie is great.'} + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 2 + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh self.assertLen( - 
self.hotflip.generate(ex, self.model, None, self.config), 2) + self.hotflip.generate(ex, self.regression_model, None, + self.regression_config), 2) def test_hotflip_freeze_tokens(self): ex = {'sentence': 'this long movie is terrible.'} - self.config[hotflip.NUM_EXAMPLES_KEY] = 10 - self.config[hotflip.TOKENS_TO_IGNORE_KEY] = ['terrible'] - generated = self.hotflip.generate( - ex, self.model, None, self.config) - for gen in generated: - self.assertLen(gen['tokens_sentence'], 6) - self.assertEqual('terrible', gen['tokens_sentence'][4]) - - self.config[hotflip.NUM_EXAMPLES_KEY] = 10 - self.config[hotflip.TOKENS_TO_IGNORE_KEY] = ['terrible', 'long'] - generated = self.hotflip.generate( - ex, self.model, None, self.config) - for gen in generated: - self.assertEqual('long', gen['tokens_sentence'][1]) - self.assertEqual('terrible', gen['tokens_sentence'][4]) + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 10 + self.classification_config[hotflip.TOKENS_TO_IGNORE_KEY] = ['terrible'] + cfs = self.hotflip.generate( + ex, self.classification_model, None, self.classification_config) + for cf in cfs: + tokens = cf['tokens_sentence'] + self.assertLen(tokens, 6) + self.assertEqual('terrible', tokens[4]) + + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 10 + self.classification_config[hotflip.TOKENS_TO_IGNORE_KEY] = ['long', + 'terrible'] + cfs = self.hotflip.generate( + ex, self.classification_model, None, self.classification_config) + for cf in cfs: + tokens = cf['tokens_sentence'] + self.assertEqual('terrible', tokens[4]) + self.assertEqual('long', tokens[1]) + + def test_hotflip_freeze_tokens_multi_input(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this long movie is great.'} + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 10 + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh + self.regression_config[hotflip.TOKENS_TO_IGNORE_KEY] = ['long', 'terrible'] + cfs = self.hotflip.generate(ex, 
self.regression_model, None, + self.regression_config) + for cf in cfs: + tokens1 = cf['tokens_sentence1'] + tokens2 = cf['tokens_sentence2'] + self.assertEqual('terrible', tokens1[4]) + self.assertEqual('long', tokens1[1]) + self.assertEqual('long', tokens2[1]) def test_hotflip_drops(self): ex = {'sentence': 'this long movie is terrible.'} - self.config[hotflip.NUM_EXAMPLES_KEY] = 1 - self.config[hotflip.DROP_TOKENS_KEY] = True - generated = self.hotflip.generate( - ex, self.model, None, self.config) - self.assertLess(len(generated[0]['tokens_sentence']), 6) + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 1 + self.classification_config[hotflip.DROP_TOKENS_KEY] = True + cfs = self.hotflip.generate( + ex, self.classification_model, None, self.classification_config) + self.assertLess(len(list(cfs)[0]['tokens_sentence']), 6) + + def test_hotflip_drops_multi_input(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this short movie is great.'} + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 10 + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh + self.regression_config[hotflip.DROP_TOKENS_KEY] = True + cfs = self.hotflip.generate(ex, self.regression_model, None, + self.regression_config) + for cf in cfs: + self.assertLessEqual(len(cf['tokens_sentence1']), 6) + self.assertLessEqual(len(cf['tokens_sentence2']), 6) def test_hotflip_max_flips(self): ex = {'sentence': 'this long movie is terrible.'} - self.config[hotflip.NUM_EXAMPLES_KEY] = 1 - self.config[hotflip.MAX_FLIPS_KEY] = 1 - generated = self.hotflip.generate( - ex, self.model, None, self.config) - self.assertLen(generated, 1) - - num_flipped = 0 - pred = list(self.model.predict([ex]))[0] - pred_tokens = pred['tokens_sentence'] - gen_tokens = generated[0]['tokens_sentence'] - for i in range(len(gen_tokens)): - if gen_tokens[i] != pred_tokens[i]: - num_flipped += 1 - self.assertEqual(1, num_flipped) + ex_output = 
list(self.classification_model.predict([ex]))[0] + ex_tokens = ex_output['tokens_sentence'] - ex = {'sentence': 'this long movie is terrible and horrible.'} - self.config[hotflip.NUM_EXAMPLES_KEY] = 1 - self.config[hotflip.MAX_FLIPS_KEY] = 1 - generated = self.hotflip.generate( - ex, self.model, None, self.config) - self.assertEmpty(generated) + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 1 + self.classification_config[hotflip.MAX_FLIPS_KEY] = 1 + cfs = self.hotflip.generate( + ex, self.classification_model, None, self.classification_config) + cf_tokens = list(cfs)[0]['tokens_sentence'] + self.assertEqual(1, sum([1 for i, t in enumerate(cf_tokens) + if t != ex_tokens[i]])) - def test_hotflip_changes_pred(self): + ex = {'sentence': 'this long movie is terrible and horrible.'} + self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 1 + self.classification_config[hotflip.MAX_FLIPS_KEY] = 1 + cfs = self.hotflip.generate( + ex, self.classification_model, None, self.classification_config) + self.assertEmpty(cfs) + + def test_hotflip_max_flips_multi_input(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this short movie is great.'} + ex_output = list(self.regression_model.predict([ex]))[0] + ex_tokens1 = ex_output['tokens_sentence1'] + ex_tokens2 = ex_output['tokens_sentence2'] + + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 20 + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh + self.regression_config[hotflip.MAX_FLIPS_KEY] = 1 + cfs = self.hotflip.generate(ex, self.regression_model, None, + self.regression_config) + for cf in cfs: + # Number of flips in each field should be no more than MAX_FLIPS. 
+ cf_tokens1 = cf['tokens_sentence1'] + cf_tokens2 = cf['tokens_sentence2'] + self.assertLessEqual(sum([1 for i, t in enumerate(cf_tokens1) + if t != ex_tokens1[i]]), 1) + self.assertLessEqual(sum([1 for i, t in enumerate(cf_tokens2) + if t != ex_tokens2[i]]), 1) + + def test_hotflip_only_flip_one_field(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this short movie is great.'} + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 10 + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh + cfs = self.hotflip.generate(ex, self.regression_model, None, + self.regression_config) + for cf in cfs: + self.assertTrue( + (cf['sentence1'] == ex['sentence1']) or + (cf['sentence2'] == ex['sentence2'])) + + def test_hotflip_changes_pred_class(self): + # Test with a classification model. ex = {'sentence': 'this long movie is terrible.'} - pred = list(self.model.predict([ex]))[0] - pred_class = str(np.argmax(pred['probas'])) + ex_output = list(self.classification_model.predict([ex]))[0] + pred_class = str(np.argmax(ex_output['probas'])) self.assertEqual('0', pred_class) - generated = self.hotflip.generate(ex, self.model, None, self.config) - for gen in generated: - self.assertEqual('1', gen['label']) + cfs = self.hotflip.generate(ex, self.classification_model, None, + self.classification_config) + cf_outputs = self.classification_model.predict(cfs) + for cf_output in cf_outputs: + self.assertNotEqual(np.argmax(ex_output['probas']), + np.argmax(cf_output['probas'])) + + def test_hotflip_changes_regression_score(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this short movie is great.'} + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 2 + ex_output = list(self.regression_model.predict([ex]))[0] + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh + cfs = self.hotflip.generate(ex, self.regression_model, None, + self.regression_config) + cf_outputs = 
self.regression_model.predict(cfs) + for cf_output in cf_outputs: + self.assertNotEqual((ex_output['score'] <= thresh), + (cf_output['score'] <= thresh)) def test_hotflip_fails_without_pred_key(self): ex = {'sentence': 'this long movie is terrible.'} with self.assertRaises(AssertionError): - self.hotflip.generate(ex, self.model, None, None) + self.hotflip.generate(ex, self.classification_model, None, None) if __name__ == '__main__': From 2a39d949590ef9c01954184db684354dfc4463bd Mon Sep 17 00:00:00 2001 From: James Wexler Date: Tue, 11 May 2021 12:56:59 -0700 Subject: [PATCH 044/213] Style confusion matrix. - Add total columns/rows. - Add percentages. - Improve visuals, based on mocks. PiperOrigin-RevId: 373214531 --- lit_nlp/client/elements/data_matrix.css | 37 +- lit_nlp/client/elements/data_matrix.ts | 323 +++++++++++------- .../client/modules/confusion_matrix_module.ts | 7 + 3 files changed, 239 insertions(+), 128 deletions(-) diff --git a/lit_nlp/client/elements/data_matrix.css b/lit_nlp/client/elements/data_matrix.css index 9b072240..b66f2505 100644 --- a/lit_nlp/client/elements/data_matrix.css +++ b/lit_nlp/client/elements/data_matrix.css @@ -1,9 +1,27 @@ .cell { + font-family: 'Roboto Mono', monospace; cursor: pointer; - background: #e6e6e6; padding: 2px 4px; border-radius: 2px; - text-align: center; + border: 2px solid transparent; +} + +.cell-container { + display: flex; + min-width: 90px; +} + +.percentage { + width: 50%; + min-width: 50%; + max-width: 50%; + text-align: end; + padding-right: 6px; +} + +.val { + width: 50%; + text-align: end; } .header-cell { @@ -19,16 +37,15 @@ text-align: end; } -.cell.diagonal { - background: #b5bcc3; -} - -.cell.selected { - background: #ffe839; +.total-title-cell { + padding: 2px 4px; + color: #5f6368; } -.cell:hover { - background: #faf49f; +.total-cell { + padding: 2px 4px; + font-family: 'Roboto Mono', monospace; + color: #5f6368; } .axis-title { diff --git a/lit_nlp/client/elements/data_matrix.ts 
b/lit_nlp/client/elements/data_matrix.ts index 65fac036..25f0f532 100644 --- a/lit_nlp/client/elements/data_matrix.ts +++ b/lit_nlp/client/elements/data_matrix.ts @@ -15,11 +15,13 @@ * limitations under the License. */ +import * as d3 from 'd3'; import '@material/mwc-icon-button-toggle'; // tslint:disable:no-new-decorators import {customElement, html, LitElement, property} from 'lit-element'; import {classMap} from 'lit-html/directives/class-map'; -import {observable} from 'mobx'; +import {styleMap} from 'lit-html/directives/style-map'; +import {computed, observable} from 'mobx'; import {styles} from './data_matrix.css'; @@ -48,132 +50,183 @@ export class DataMatrix extends LitElement { @property({type: Array}) colLabels: string[] = []; @property({type: Array}) rowLabels: string[] = []; - render() { - if (this.matrixCells.length === 0) { - return null; - } - - const rowsWithNonZeroCounts = new Set(); - const colsWithNonZeroCounts = new Set(); + @computed get totalIds(): number { + let totalIds = 0; for (let rowIndex = 0; rowIndex < this.matrixCells.length; rowIndex++) { const row = this.matrixCells[rowIndex]; for (let colIndex = 0; colIndex < row.length; colIndex++) { const cell = row[colIndex]; - if (cell.ids.length > 0) { - rowsWithNonZeroCounts.add(this.rowLabels[rowIndex]); - colsWithNonZeroCounts.add(this.colLabels[colIndex]); - } + totalIds += cell.ids.length; } } + return totalIds; + } - // Render a clickable column header cell. 
- const renderColHeader = (label: string, colIndex: number) => { - const onColClick = () => { - const cells = this.matrixCells.map((cells) => cells[colIndex]); - const allSelected = cells.every((cell) => cell.selected); - cells.forEach((cell) => { - cell.selected = !allSelected; - }); - this.updateSelection(); - this.requestUpdate(); - }; - if (this.hideEmptyLabels && !colsWithNonZeroCounts.has(label)) { - return null; - } - const classes = classMap({ - 'header-cell': true, - 'align-bottom': this.verticalColumnLabels, - 'label-vertical': this.verticalColumnLabels - }); - // clang-format off - return html` - -
${label}
- - `; - // clang-format on - }; + @computed + get colorScale() { + return d3.scaleLinear() + .domain([0, this.totalIds]) + // Need to cast to numbers due to d3 typing. + .range(["#F5F5F5" as unknown as number, "#006064" as unknown as number]); + } - // Render a clickable confusion matrix cell. - const renderCell = (rowIndex: number, colIndex: number) => { - if (this.matrixCells[rowIndex]?.[colIndex] == null) { - return null; + private updateSelection() { + let ids: string[] = []; + for (const cellInfo of this.matrixCells.flat()) { + if (cellInfo.selected) { + ids = ids.concat(cellInfo.ids); } - const cellInfo = this.matrixCells[rowIndex][colIndex]; - const cellClasses = classMap({ - cell: true, - selected: cellInfo.selected, - diagonal: colIndex === rowIndex, - }); - const onCellClick = () => { - cellInfo.selected = !cellInfo.selected; - this.updateSelection(); - this.requestUpdate(); - }; - if (this.hideEmptyLabels && - !colsWithNonZeroCounts.has(this.colLabels[colIndex])) { - return null; + } + const event = new CustomEvent('matrix-selection', { + detail: { + ids, + } + }); + this.dispatchEvent(event); + } + + private renderColHeader(label: string, colIndex: number, + colsWithNonZeroCounts: Set) { + const onColClick = () => { + const cells = this.matrixCells.map((cells) => cells[colIndex]); + const allSelected = cells.every((cell) => cell.selected); + for (const cell of cells) { + cell.selected = !allSelected; } - return html` - - ${cellInfo.ids.length} - `; + this.updateSelection(); + this.requestUpdate(); + }; + if (this.hideEmptyLabels && !colsWithNonZeroCounts.has(label)) { + return null; + } + const classes = classMap({ + 'header-cell': true, + 'align-bottom': this.verticalColumnLabels, + 'label-vertical': this.verticalColumnLabels + }); + // clang-format off + return html` + +
${label}
+ + `; + // clang-format on + } + + // Render a clickable confusion matrix cell. + private renderCell(rowIndex: number, colIndex: number, + colsWithNonZeroCounts: Set) { + if (this.matrixCells[rowIndex]?.[colIndex] == null) { + return null; + } + const cellInfo = this.matrixCells[rowIndex][colIndex]; + const onCellClick = () => { + cellInfo.selected = !cellInfo.selected; + this.updateSelection(); + this.requestUpdate(); }; + if (this.hideEmptyLabels && + !colsWithNonZeroCounts.has(this.colLabels[colIndex])) { + return null; + } + const backgroundColor = this.colorScale(cellInfo.ids.length); + const percentage = cellInfo.ids.length / this.totalIds * 100; + const textColor = percentage > 50 ? 'white' : 'black'; + const border = cellInfo.selected ? + '2px solid #12B5CB' : '2px solid transparent'; + const cellStyle = styleMap({ + background: `${backgroundColor}`, + color: `${textColor}`, + border + }); + return html` + +
+
${percentage.toFixed(1)}%
+
(${cellInfo.ids.length})
+
+ `; + } - // Render a row of the confusion matrix, starting with the clickable - // row header. - const renderRow = (rowLabel: string, rowIndex: number) => { - const onRowClick = () => { - const cells = this.matrixCells[rowIndex]; - const allSelected = cells.every((cell) => cell.selected); - cells.forEach((cell) => { - cell.selected = !allSelected; - }); - this.updateSelection(); - this.requestUpdate(); - }; - if (this.hideEmptyLabels && !rowsWithNonZeroCounts.has(rowLabel)) { - return null; + // Render a row of the confusion matrix, starting with the clickable + // row header. + private renderRow(rowLabel: string, rowIndex: number, + rowsWithNonZeroCounts: Set, + colsWithNonZeroCounts: Set) { + const onRowClick = () => { + const cells = this.matrixCells[rowIndex]; + const allSelected = cells.every((cell) => cell.selected); + for (const cell of cells) { + cell.selected = !allSelected; } - // clang-format off - return html` - - ${rowIndex === 0 ? html` - -
${this.rowTitle}
- ` - : null} - - ${rowLabel} - - ${this.colLabels.map( - (colLabel, colIndex) => renderCell(rowIndex, colIndex))} - `; - // clang-format on + this.updateSelection(); + this.requestUpdate(); }; + if (this.hideEmptyLabels && !rowsWithNonZeroCounts.has(rowLabel)) { + return null; + } + let totalRowIds = 0; + for (const cell of this.matrixCells[rowIndex]) { + totalRowIds += cell.ids.length; + } + // clang-format off + return html` + + ${rowIndex === 0 ? html` + +
${this.rowTitle}
+ ` + : null} + + ${rowLabel} + + ${this.colLabels.map( + (colLabel, colIndex) => this.renderCell( + rowIndex, colIndex,colsWithNonZeroCounts))} + ${this.renderTotalCell(totalRowIds)} + `; + // clang-format on + } + + private renderTotalCell(num: number) { + const percentage = (num / this.totalIds * 100).toFixed(1); + return html` + +
+
${percentage}%
+
(${num})
+
+ `; + } + + private renderColTotalCell(colIndex: number) { + let totalColIds = 0; + for (const row of this.matrixCells) { + totalColIds += row[colIndex].ids.length; + } + return this.renderTotalCell(totalColIds); + } + private renderTotalRow(colsWithNonZeroCounts: Set) { // clang-format off return html` - - - - - - - - ${this.colLabels.map( - (colLabel, colIndex) => renderColHeader(colLabel, colIndex))} - - ${this.rowLabels.map( - (rowLabel, rowIndex) => renderRow(rowLabel, rowIndex))} -
${this.renderColumnRotateButton()} - ${this.colTitle} -
- `; + + Total + ${this.colLabels.map( + (colLabel, colIndex) => { + if (this.hideEmptyLabels && + !colsWithNonZeroCounts.has(this.colLabels[colIndex])) { + return null; + } + return this.renderColTotalCell(colIndex); + })} + + `; // clang-format on } - renderColumnRotateButton() { + private renderColumnRotateButton() { const toggleVerticalColumnLabels = () => { this.verticalColumnLabels = !this.verticalColumnLabels; this.requestUpdate(); @@ -192,19 +245,53 @@ export class DataMatrix extends LitElement { // clang-format on } - private updateSelection() { - let ids: string[] = []; - for (const cellInfo of this.matrixCells.flat()) { - if (cellInfo.selected) { - ids = ids.concat(cellInfo.ids); - } + render() { + if (this.matrixCells.length === 0) { + return null; } - const event = new CustomEvent('matrix-selection', { - detail: { - ids, + + const rowsWithNonZeroCounts = new Set(); + const colsWithNonZeroCounts = new Set(); + for (let rowIndex = 0; rowIndex < this.matrixCells.length; rowIndex++) { + const row = this.matrixCells[rowIndex]; + for (let colIndex = 0; colIndex < row.length; colIndex++) { + const cell = row[colIndex]; + if (cell.ids.length > 0) { + rowsWithNonZeroCounts.add(this.rowLabels[rowIndex]); + colsWithNonZeroCounts.add(this.colLabels[colIndex]); + } } - }); - this.dispatchEvent(event); + } + + const totalColumnClasses = classMap({ + 'total-title-cell': true, + 'align-bottom': this.verticalColumnLabels + }); + + // clang-format off + return html` + + + + + + + + ${this.colLabels.map( + (colLabel, colIndex) => this.renderColHeader( + colLabel, colIndex, colsWithNonZeroCounts))} + + + ${this.rowLabels.map( + (rowLabel, rowIndex) => this.renderRow( + rowLabel, rowIndex, rowsWithNonZeroCounts, + colsWithNonZeroCounts))} + ${this.renderTotalRow(colsWithNonZeroCounts)} +
${this.renderColumnRotateButton()} + ${this.colTitle} +
Total
+ `; + // clang-format on } } diff --git a/lit_nlp/client/modules/confusion_matrix_module.ts b/lit_nlp/client/modules/confusion_matrix_module.ts index da4706ce..47621bf0 100644 --- a/lit_nlp/client/modules/confusion_matrix_module.ts +++ b/lit_nlp/client/modules/confusion_matrix_module.ts @@ -221,6 +221,13 @@ export class ConfusionMatrixModule extends LitModule { this.matrixCells = rowLabels.map(rowLabel => { return colLabels.map(colLabel => { + // If the rows and columns are the same feature but the cells are for + // different values of that feature, then by definition no examples can + // go into that cell. Handle this special case as the facetsDict below + // only handles a single value per feature. + if (colName === rowName && colLabel !== rowLabel) { + return {ids: [], selected: false}; + } // Find the bin corresponding to this row/column value combination. const facetsDict = {[colName]: colLabel, [rowName]: rowLabel}; const bin = bins[objToDictKey(facetsDict)]; From 88c50f04bafc13bd794ed1cb7045fb03057b58c2 Mon Sep 17 00:00:00 2001 From: James Wexler Date: Wed, 12 May 2021 10:17:44 -0700 Subject: [PATCH 045/213] Fix broken hotflip unit test. - Fix glue_models get_embedding_table to always return a correct vocab list. In OSS, the tokenizer vocab was a dict and not an OrderedDict so just taking its keys and assuming it is ordered by token ID was a faulty assumption. 
- Fix missing STSB_PATH loading in hotflip_test PiperOrigin-RevId: 373396843 --- lit_nlp/components/hotflip_test.py | 5 ++++- lit_nlp/examples/models/glue_models.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lit_nlp/components/hotflip_test.py b/lit_nlp/components/hotflip_test.py index 07367288..8ff428ff 100644 --- a/lit_nlp/components/hotflip_test.py +++ b/lit_nlp/components/hotflip_test.py @@ -24,9 +24,12 @@ BERT_TINY_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz' # pylint: disable=line-too-long +STSB_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_tiny.tar.gz' # pylint: disable=line-too-long import transformers BERT_TINY_PATH = transformers.file_utils.cached_path(BERT_TINY_PATH, -extract_compressed_file=True) + extract_compressed_file=True) +STSB_PATH = transformers.file_utils.cached_path(STSB_PATH, + extract_compressed_file=True) class ModelBasedHotflipTest(absltest.TestCase): diff --git a/lit_nlp/examples/models/glue_models.py b/lit_nlp/examples/models/glue_models.py index f4479808..cd47cb00 100644 --- a/lit_nlp/examples/models/glue_models.py +++ b/lit_nlp/examples/models/glue_models.py @@ -63,6 +63,8 @@ def _load_model(self, model_name_or_path): """Load model. 
Can be overridden for testing.""" self.tokenizer = transformers.AutoTokenizer.from_pretrained( model_name_or_path) + self.vocab = self.tokenizer.convert_ids_to_tokens( + range(len(self.tokenizer))) model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, num_labels=1 if self.is_regression else len(self.config.labels), @@ -322,8 +324,7 @@ def max_minibatch_size(self): return self.config.inference_batch_size def get_embedding_table(self): - vocab = list(self.tokenizer.vocab.keys()) - return vocab, self.model.bert.embeddings.word_embeddings.numpy() + return self.vocab, self.model.bert.embeddings.word_embeddings.numpy() def predict_minibatch(self, inputs: Iterable[JsonDict]): # Use watch_accessed_variables to save memory by having the tape do nothing From f071b4e686dd764ada8a49739882384494de51eb Mon Sep 17 00:00:00 2001 From: Googler Date: Thu, 13 May 2021 09:36:05 -0700 Subject: [PATCH 046/213] Extend TFXModel to serialize TF examples and pass them as inputs to the model since TFX models expect this as input. Also add dependency for tf_text, which is commonly used for NLP TF models. PiperOrigin-RevId: 373594674 --- lit_nlp/components/tfx_model.py | 38 +++++++++++++++++++++++++--- lit_nlp/components/tfx_model_test.py | 26 ++++++++++++------- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/lit_nlp/components/tfx_model.py b/lit_nlp/components/tfx_model.py index 86d1a617..b621bca5 100644 --- a/lit_nlp/components/tfx_model.py +++ b/lit_nlp/components/tfx_model.py @@ -5,10 +5,30 @@ from lit_nlp.api import types as lit_types import tensorflow as tf +import tensorflow_text as tf_text # pylint: disable=unused-import _SERVING_DEFAULT_SIGNATURE = 'serving_default' +# TODO(b/188036366): Revisit the assumed mapping between input values and +# TF.Examples. 
+def _inputs_to_serialized_example(input_dict: lit_types.JsonDict): + """Converts the input dictionary to a serialized tf example.""" + feature_dict = {} + for k, v in input_dict.items(): + if not isinstance(v, list): + v = [v] + if isinstance(v[0], int): + feature_dict[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v)) + elif isinstance(v[0], float): + feature_dict[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v)) + else: + feature_dict[k] = tf.train.Feature( + bytes_list=tf.train.BytesList(value=[bytes(i, 'utf-8') for i in v])) + result = tf.train.Example(features=tf.train.Features(feature=feature_dict)) + return result.SerializeToString() + + class TFXModel(lit_model.Model): """Wrapper for querying a TFX-generated SavedModel.""" @@ -20,12 +40,22 @@ def __init__(self, path: str, input_spec: lit_types.Spec, self._input_spec = input_spec self._output_spec = output_spec - def predict_minibatch(self, inputs: List[lit_types.JsonDict]) -> Iterator[ - lit_types.JsonDict]: + def predict_minibatch( + self, inputs: List[lit_types.JsonDict]) -> Iterator[lit_types.JsonDict]: for i in inputs: + filtered_inputs = {k: v for k, v in i.items() if k in self._input_spec} result = self._model.signatures[self._signature]( - **{k: tf.reshape(tf.constant(v), [1, -1]) for k, v in i.items()}) - result = {k: tf.squeeze(v).numpy() for k, v in result.items()} + tf.constant([_inputs_to_serialized_example(filtered_inputs)])) + result = { + k: tf.squeeze(v).numpy().tolist() + for k, v in result.items() + if k in self._output_spec + } + for k, v in result.items(): + # If doing Multiclass Prediction for a Binary Classifier. 
+ if (isinstance(self._output_spec[k], lit_types.MulticlassPreds) and + not isinstance(v, list)): + result[k] = [1 - v, v] yield result def input_spec(self) -> lit_types.Spec: diff --git a/lit_nlp/components/tfx_model_test.py b/lit_nlp/components/tfx_model_test.py index 9cb24cc5..ee6d8b47 100644 --- a/lit_nlp/components/tfx_model_test.py +++ b/lit_nlp/components/tfx_model_test.py @@ -3,7 +3,6 @@ from lit_nlp.api import types as lit_types from lit_nlp.components import tfx_model -import numpy as np import tensorflow as tf @@ -12,9 +11,14 @@ class TfxModelTest(tf.test.TestCase): def setUp(self): super(TfxModelTest, self).setUp() self._path = tempfile.mkdtemp() - input_layer = tf.keras.layers.Input(shape=(1,), dtype=tf.float32, - name='input_0') - output_layer = tf.keras.layers.Dense(1, name='output_0')(input_layer) + input_layer = tf.keras.layers.Input( + shape=(1), dtype=tf.string, name='example') + parsed_input = tf.io.parse_example( + tf.reshape(input_layer, [-1]), + {'input_0': tf.io.FixedLenFeature([1], dtype=tf.float32)}) + output_layer = tf.keras.layers.Dense( + 1, name='output_0')( + parsed_input['input_0']) model = tf.keras.Model(input_layer, output_layer) model.compile( optimizer=tf.keras.optimizers.Adam(lr=.001), @@ -23,15 +27,19 @@ def setUp(self): def testTfxModel(self): input_spec = {'input_0': lit_types.Scalar()} - output_spec = {'output_0': lit_types.RegressionScore(parent='input_0')} - lit_model = tfx_model.TFXModel(self._path, - input_spec=input_spec, - output_spec=output_spec) + output_spec = { + 'output_0': + lit_types.MulticlassPreds(vocab=['0', '1'], parent='input_0') + } + lit_model = tfx_model.TFXModel( + self._path, input_spec=input_spec, output_spec=output_spec) result = list(lit_model.predict([{'input_0': 0.5}])) self.assertLen(result, 1) result = result[0] self.assertListEqual(list(result.keys()), ['output_0']) - self.assertIsInstance(result['output_0'], np.float32) + self.assertLen(result['output_0'], 2) + 
self.assertIsInstance(result['output_0'][0], float) + self.assertIsInstance(result['output_0'][1], float) self.assertDictEqual(lit_model.input_spec(), input_spec) self.assertDictEqual(lit_model.output_spec(), output_spec) From 4ca89d712f2cba46231d0586a3d0e6dacc60bbd6 Mon Sep 17 00:00:00 2001 From: James Wexler Date: Tue, 18 May 2021 06:07:58 -0700 Subject: [PATCH 047/213] Fix maximization overlays. - Updated color for global settings and maximized modules overlay. - Clicking outside maximized modules will bring it back to normal size. PiperOrigin-RevId: 374400028 --- lit_nlp/client/core/widget_group.css | 17 ++++++++++--- lit_nlp/client/core/widget_group.ts | 29 +++++++++++++++++----- lit_nlp/client/modules/global_settings.css | 3 +-- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/lit_nlp/client/core/widget_group.css b/lit_nlp/client/core/widget_group.css index 05ecfaf4..9b95d61c 100644 --- a/lit_nlp/client/core/widget_group.css +++ b/lit_nlp/client/core/widget_group.css @@ -12,15 +12,14 @@ } :host([maximized]) { - margin: 15px 45px; + margin: 0; padding: 0; width: calc(100vw - 90px) !important; height: calc(100vh - 130px); position: fixed; z-index: 2; - left: 0; - top: 82px; - box-shadow: rgba(0, 0, 0, 0.14) 0px 2px 2px 0px, rgba(0, 0, 0, 0.2) 0px 3px 1px -2px, rgba(0, 0, 0, 0.12) 0px 1px 5px 0px; + /** Top set to place the maximized module right below the toolbars. 
*/ + top: 76px; } .wrapper { @@ -37,6 +36,16 @@ padding: 0; } +.outside { + width: 100%; + height: 100%; +} + +.outside.maximized { + padding: 15px 45px; + background: rgba(4, 29, 51, .47); +} + .holder { height: calc(100% - 28px); display: flex; diff --git a/lit_nlp/client/core/widget_group.ts b/lit_nlp/client/core/widget_group.ts index 73ebff5e..cc68301e 100644 --- a/lit_nlp/client/core/widget_group.ts +++ b/lit_nlp/client/core/widget_group.ts @@ -146,6 +146,11 @@ export class WidgetGroup extends LitElement { host.style.setProperty('--width', width); host.style.setProperty('--min-width', width); + const outsideClasses = classMap({ + 'outside': true, + 'maximized': this.maximized, + }); + const wrapperClasses = classMap({ 'wrapper': true, 'minimized': this.minimized, @@ -166,15 +171,27 @@ export class WidgetGroup extends LitElement { } else { widgetStyle['height'] = `${100 / configGroup.length}%`; } + // For clicks on the maximized-module darkened background, undo the + // module maximization. + const onBackgroundClick = () => { + this.maximized = false; + }; + // A listener to stop clicks on a maximized module from causing the + // background click listener from firing. + const onWrapperClick = (e: Event) => { + e.stopPropagation(); + }; // clang-format off return html` -
- ${this.renderHeader(configGroup)} -
- ${configGroup.map(config => this.renderModule(config, widgetStyle))} - ${this.renderExpander()} +
+
+ ${this.renderHeader(configGroup)} +
+ ${configGroup.map(config => this.renderModule(config, widgetStyle))} + ${this.renderExpander()} +
-
+
`; // clang-format on } diff --git a/lit_nlp/client/modules/global_settings.css b/lit_nlp/client/modules/global_settings.css index 699236fe..210d13f0 100644 --- a/lit_nlp/client/modules/global_settings.css +++ b/lit_nlp/client/modules/global_settings.css @@ -6,12 +6,11 @@ } #overlay { - background-color: black; + background: rgba(4, 29, 51, .47); width: 100vw; height: 100vh; top: 0; left: 0; - opacity: .2; transition: opacity 250ms; } From 44a0d5e95541f94eb086e98b8070452dfec62b59 Mon Sep 17 00:00:00 2001 From: Ankur Taly Date: Tue, 18 May 2021 09:59:49 -0700 Subject: [PATCH 048/213] Hotflip: Enable "drop tokens" when gradients and embeddings are not available. PiperOrigin-RevId: 374443213 --- lit_nlp/components/hotflip.py | 67 +++++++++++++++++++----------- lit_nlp/components/hotflip_test.py | 43 ++++++++++++++++++- 2 files changed, 85 insertions(+), 25 deletions(-) diff --git a/lit_nlp/components/hotflip.py b/lit_nlp/components/hotflip.py index eee7a37d..a21a704e 100644 --- a/lit_nlp/components/hotflip.py +++ b/lit_nlp/components/hotflip.py @@ -79,10 +79,10 @@ class HotFlip(lit_components.Generator): """ def find_fields( - self, output_spec: Spec, typ: Type[types.LitType], + self, spec: Spec, typ: Type[types.LitType], align_field: Optional[Text] = None) -> List[Text]: # Find fields of provided 'typ'. - fields = utils.find_spec_keys(output_spec, typ) + fields = utils.find_spec_keys(spec, typ) if align_field is None: return fields @@ -90,24 +90,30 @@ def find_fields( # Only return fields that are aligned to fields with name specified by # align_field. 
return [f for f in fields - if getattr(output_spec[f], "align", None) == align_field] + if getattr(spec[f], "align", None) == align_field] def _get_tokens_and_gradients(self, + input_spec: JsonDict, output_spec: JsonDict, output: JsonDict): """Returns a dictionary mapping token fields to tokens and gradients.""" - # Find gradient fields - grad_fields = self.find_fields(output_spec, types.TokenGradients, - None) - if len(grad_fields) == 0: # pylint: disable=g-explicit-length-test + # Find token fields + token_fields = [key + for key in utils.find_spec_keys(input_spec, types.Tokens) + if input_spec[key].is_compatible(output_spec.get(key))] + + if len(token_fields) == 0: # pylint: disable=g-explicit-length-test return {} ret = {} - for grad_field in grad_fields: + for token_field in token_fields: # Get tokens, token gradients and token embeddings. - token_field = output_spec[grad_field].align # pytype: disable=attribute-error tokens = output[token_field] - grads = output[grad_field] + grad_fields = self.find_fields(output_spec, types.TokenGradients, + token_field) + assert len(grad_fields) <= 1, ( + f"Multiple gradients found for {token_field}") + grads = output[grad_fields[0]] if grad_fields else None ret[token_field] = [tokens, grads] return ret @@ -135,7 +141,7 @@ def _subset_exists(self, cand_set, sets): def _gen_token_idxs_to_flip( self, tokens: List[str], - token_grads: np.ndarray, + token_grads: Optional[np.ndarray], max_flips: int, tokens_to_ignore: List[str], drop_tokens: bool = False) -> Iterator[Tuple[int, ...]]: @@ -152,13 +158,14 @@ def _gen_token_idxs_to_flip( # consider combinations by ordering tokens by gradient L2 in order to # prioritize flipping tokens that may have the largest impact on the # prediction. - token_grads_l2 = np.sum(token_grads * token_grads, axis=-1) - # TODO(ataly, bastings): Consider sorting by attributions (either - # Integrated Gradients or Shapley values). 
- token_idxs_sorted_by_grads = np.argsort(token_grads_l2)[::-1] - token_idxs_to_flip = [idx for idx in token_idxs_sorted_by_grads + token_idxs = np.arange(len(tokens)) + if token_grads is not None: + token_grads_l2 = np.sum(token_grads * token_grads, axis=-1) + # TODO(ataly, bastings): Consider sorting by attributions (either + # Integrated Gradients or Shapley values). + token_idxs = np.argsort(token_grads_l2)[::-1] + token_idxs_to_flip = [idx for idx in token_idxs if tokens[idx] not in tokens_to_ignore] - # If the number of tokens considered for flipping is larger than # MAX_FLIPPABLE_TOKENS we only consider the top tokens. token_idxs_to_flip = token_idxs_to_flip[:MAX_FLIPPABLE_TOKENS] @@ -298,7 +305,7 @@ def generate(self, # Get tokens (corresponding to each text input field) and corresponding # gradients. tokens_and_gradients = self._get_tokens_and_gradients( - output_spec, orig_output) + input_spec, output_spec, orig_output) if len(tokens_and_gradients) == 0: # pylint: disable=g-explicit-length-test logging.info("No token or gradient fields found. Cannot use HotFlip. :-(") return [] # Cannot generate examples without tokens or gradients. @@ -309,12 +316,19 @@ def generate(self, tokens, _ = v example[token_field] = tokens - # Get model word embeddings and vocab. - inv_vocab, embedding_matrix = model.get_embedding_table() - assert len(inv_vocab) == embedding_matrix.shape[0], ( - "Vocab/embeddings size mismatch.") - logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab), - embedding_matrix.shape) + if not drop_tokens: + # Get model word embeddings and vocab. + try: + inv_vocab, embedding_matrix = model.get_embedding_table() + except NotImplementedError: + raise NotImplementedError( + "get_embedding_table is not implemented by the model. Cannot" + "generate Hotflips by flipping tokens. Please set %s to True to" + "generate Hotflips by dropping tokens." 
% DROP_TOKENS_KEY) + logging.info("Vocab size: %d, Embedding size: %r", len(inv_vocab), + embedding_matrix.shape) + assert len(inv_vocab) == embedding_matrix.shape[0], ( + "Vocab/embeddings size mismatch.") successful_cfs = [] # TODO(lit-team): use only 1 sequence as input (configurable in UI). @@ -326,6 +340,11 @@ def generate(self, logging.info("Identifying Hotflips for input field: %s", str(text_field)) replacement_tokens = None if not drop_tokens: + assert grads is not None, ( + "Gradients are not exposed by the model. Cannot generate" + "Hotflips by flipping tokens. Please set %s to True to generate" + "Hotflips by dropping tokens." % DROP_TOKENS_KEY) + direction = -1 if is_regression: # We want the replacements to increase the prediction score if the diff --git a/lit_nlp/components/hotflip_test.py b/lit_nlp/components/hotflip_test.py index 8ff428ff..f84278c0 100644 --- a/lit_nlp/components/hotflip_test.py +++ b/lit_nlp/components/hotflip_test.py @@ -32,6 +32,13 @@ extract_compressed_file=True) +class STSBModelWithoutEmbeddings(glue_models.STSBModel): + + def get_embedding_table(self): + raise NotImplementedError('get_embedding_table() not implemented for ' + + self.__class__.__name__) + + class ModelBasedHotflipTest(absltest.TestCase): def setUp(self): @@ -46,6 +53,11 @@ def setUp(self): self.regression_model = glue_models.STSBModel(STSB_PATH) self.regression_config = {hotflip.PREDICTION_KEY: 'score'} + # A wrapped version of the above regression model that does not expose + # emeddings. 
+ self.regression_model_without_embeddings = STSBModelWithoutEmbeddings( + STSB_PATH) + def test_find_fields(self): fields = self.hotflip.find_fields(self.classification_model.output_spec(), lit_types.MulticlassPreds) @@ -86,6 +98,21 @@ def test_hotflip_num_ex_multi_input(self): self.hotflip.generate(ex, self.regression_model, None, self.regression_config), 2) + def test_hotflip_num_ex_without_embeddings(self): + ex = {'sentence1': 'this long movie is terrible.', + 'sentence2': 'this short movie is great.'} + self.regression_config[hotflip.NUM_EXAMPLES_KEY] = 2 + thresh = 2 + self.regression_config[hotflip.REGRESSION_THRESH_KEY] = thresh + with self.assertRaises(NotImplementedError): + self.hotflip.generate(ex, self.regression_model_without_embeddings, None, + self.regression_config) + + self.regression_config[hotflip.DROP_TOKENS_KEY] = True + self.assertLen( + self.hotflip.generate(ex, self.regression_model_without_embeddings, + None, self.regression_config), 2) + def test_hotflip_freeze_tokens(self): ex = {'sentence': 'this long movie is terrible.'} self.classification_config[hotflip.NUM_EXAMPLES_KEY] = 10 @@ -144,6 +171,13 @@ def test_hotflip_drops_multi_input(self): self.assertLessEqual(len(cf['tokens_sentence1']), 6) self.assertLessEqual(len(cf['tokens_sentence2']), 6) + # Test with the wrapped model that does not expose embeddings. + cfs = self.hotflip.generate(ex, self.regression_model_without_embeddings, + None, self.regression_config) + for cf in cfs: + self.assertLessEqual(len(cf['tokens_sentence1']), 6) + self.assertLessEqual(len(cf['tokens_sentence2']), 6) + def test_hotflip_max_flips(self): ex = {'sentence': 'this long movie is terrible.'} ex_output = list(self.classification_model.predict([ex]))[0] @@ -200,7 +234,6 @@ def test_hotflip_only_flip_one_field(self): (cf['sentence2'] == ex['sentence2'])) def test_hotflip_changes_pred_class(self): - # Test with a classification model. 
ex = {'sentence': 'this long movie is terrible.'} ex_output = list(self.classification_model.predict([ex]))[0] pred_class = str(np.argmax(ex_output['probas'])) @@ -225,6 +258,14 @@ def test_hotflip_changes_regression_score(self): for cf_output in cf_outputs: self.assertNotEqual((ex_output['score'] <= thresh), (cf_output['score'] <= thresh)) + # Test with the wrapped model that does not expose embeddings. + self.regression_config[hotflip.DROP_TOKENS_KEY] = True + cfs = self.hotflip.generate(ex, self.regression_model_without_embeddings, + None, self.regression_config) + cf_outputs = self.regression_model.predict(cfs) + for cf_output in cf_outputs: + self.assertNotEqual((ex_output['score'] <= thresh), + (cf_output['score'] <= thresh)) def test_hotflip_fails_without_pred_key(self): ex = {'sentence': 'this long movie is terrible.'} From a72fc478a0f84ce8c5b2e1b0b26f645fae18f09b Mon Sep 17 00:00:00 2001 From: James Wexler Date: Fri, 21 May 2021 13:49:20 -0700 Subject: [PATCH 049/213] Allow for multiple confusion matrices in the module. - Adds controls to add new confusion matrices or delete existing ones. - Adds checkbox to have matrices update on selection changes (the existing and default behavior), or always show matrices across entire dataset. 
PiperOrigin-RevId: 375156835 --- lit_nlp/client/elements/data_matrix.css | 4 + lit_nlp/client/elements/data_matrix.ts | 16 ++ .../modules/confusion_matrix_module.css | 19 ++ .../client/modules/confusion_matrix_module.ts | 193 +++++++++++++----- 4 files changed, 184 insertions(+), 48 deletions(-) diff --git a/lit_nlp/client/elements/data_matrix.css b/lit_nlp/client/elements/data_matrix.css index b66f2505..59555305 100644 --- a/lit_nlp/client/elements/data_matrix.css +++ b/lit_nlp/client/elements/data_matrix.css @@ -74,3 +74,7 @@ opacity: .7; } +.delete-cell { + display: flex; + justify-content: flex-end +} diff --git a/lit_nlp/client/elements/data_matrix.ts b/lit_nlp/client/elements/data_matrix.ts index 25f0f532..5cdc7cd4 100644 --- a/lit_nlp/client/elements/data_matrix.ts +++ b/lit_nlp/client/elements/data_matrix.ts @@ -245,6 +245,21 @@ export class DataMatrix extends LitElement { // clang-format on } + private renderDeleteButton() { + const deleteMatrix = () => { + const event = new CustomEvent('delete-matrix', {}); + this.dispatchEvent(event); + }; + + // clang-format off + return html` + + delete_outline + + `; + // clang-format on + } + render() { if (this.matrixCells.length === 0) { return null; @@ -276,6 +291,7 @@ export class DataMatrix extends LitElement { ${this.colTitle} + ${this.renderDeleteButton()} diff --git a/lit_nlp/client/modules/confusion_matrix_module.css b/lit_nlp/client/modules/confusion_matrix_module.css index db7c3c98..b70b14af 100644 --- a/lit_nlp/client/modules/confusion_matrix_module.css +++ b/lit_nlp/client/modules/confusion_matrix_module.css @@ -8,6 +8,25 @@ max-width: fit-content; } +.matrices-holder { + display: flex; + flex-flow: wrap; +} + +.matrix { + margin-right: 16px; + margin-bottom: 16px; +} + +.flex { + display: flex; +} + +#create-button { + margin-top: 28px; + pointer-events: auto; +} + /* Make the "Rows" and "Columns" labels the same width */ .dropdown-label { display: inline-block; diff --git 
a/lit_nlp/client/modules/confusion_matrix_module.ts b/lit_nlp/client/modules/confusion_matrix_module.ts index 47621bf0..f1869fe9 100644 --- a/lit_nlp/client/modules/confusion_matrix_module.ts +++ b/lit_nlp/client/modules/confusion_matrix_module.ts @@ -67,13 +67,15 @@ export class ConfusionMatrixModule extends LitModule { @observable verticalColumnLabels = false; @observable hideEmptyLabels = false; - // These are not observable, because we don't want to trigger a re-render - // until the matrix cells are updated asynchronously. - selectedRowOption = 0; - selectedColOption = 0; + @observable updateOnSelection = true; + @observable selectedRowOption = 0; + @observable selectedColOption = 0; - // Output state for rendering. Computed asynchronously. - @observable matrixCells: MatrixCell[][] = []; + private lastSelectedRow = -1; + private lastSelectedCol = -1; + + // Map of matrices for rendering. Computed asynchronously. + @observable matrices: {[id: string]: MatrixCell[][]} = {}; constructor() { super(); @@ -84,25 +86,42 @@ export class ConfusionMatrixModule extends LitModule { // Calculate the initial confusion matrix. const getCurrentInputData = () => this.appState.currentInputData; this.react(getCurrentInputData, currentInputData => { - this.calculateMatrix(); + this.matrices = {}; + this.calculateMatrix(this.selectedRowOption, this.selectedColOption); }); const getMarginSettings = () => this.classificationService.allMarginSettings; this.react(getMarginSettings, margins => { - this.calculateMatrix(); + this.updateMatrices(); + }); + const getUpdateOnSelection = () => this.updateOnSelection; + this.react(getUpdateOnSelection, updateOnSelection => { + this.updateMatrices(); }); const getSelectedInputData = () => this.selectionService.selectedInputData; - this.react(getSelectedInputData, selectedInputData => { - // Don't reset if we just clicked a cell from this module. 
- if (this.selectionService.lastUser !== this) { - this.calculateMatrix(); + this.react(getSelectedInputData, async selectedInputData => { + // If the selection is from another module and this is set to update + // on selection changes, then update the matrices. + if (this.selectionService.lastUser !== this && this.updateOnSelection) { + await this.updateMatrices(); + } + // If the selection is from this module then update all matrices that + // weren't the cause of the selection, to reset their selection states + // and recalculate their cells if we are updating on selections. + if (this.selectionService.lastUser === this) { + for (const id of Object.keys(this.matrices)) { + const [row, col] = this.getOptionsFromMatrixId(id); + if (row !== this.lastSelectedRow || col !== this.lastSelectedCol) { + await this.calculateMatrix(row, col); + } + } } }); // Update once on init, to avoid duplicate calls. - this.calculateMatrix(); + this.calculateMatrix(this.selectedRowOption, this.selectedColOption); } private setInitialOptions() { @@ -180,12 +199,19 @@ export class ConfusionMatrixModule extends LitModule { return options; } + private async updateMatrices() { + for (const id of Object.keys(this.matrices)) { + const [row, col] = this.getOptionsFromMatrixId(id); + await this.calculateMatrix(row, col); + } + } + /** * Set the matrix cell information based on the selected axes and examples. 
*/ - private async calculateMatrix() { - const rowOption = this.options[this.selectedRowOption]; - const colOption = this.options[this.selectedColOption]; + private async calculateMatrix(row: number, col: number) { + const rowOption = this.options[row]; + const colOption = this.options[col]; const rowLabels = rowOption.labelList; const colLabels = colOption.labelList; @@ -193,7 +219,11 @@ export class ConfusionMatrixModule extends LitModule { const rowName = rowOption.name; const colName = colOption.name; - const data = this.selectionService.selectedOrAllInputData; + // When updating on selection, use the selected data if a selection exists. + // Otherwise use the whole dataset. + const data = this.updateOnSelection ? + this.selectionService.selectedOrAllInputData : + this.appState.currentInputData; // If there is no data loaded, then do not attempt to create a confusion // matrix. @@ -219,7 +249,9 @@ export class ConfusionMatrixModule extends LitModule { const bins = this.groupService.groupExamplesByFeatures( data, [rowName, colName], getFeatFunc); - this.matrixCells = rowLabels.map(rowLabel => { + + const id = this.getMatrixId(row, col); + const matrixCells = rowLabels.map(rowLabel => { return colLabels.map(colLabel => { // If the rows and columns are the same feature but the cells are for // different values of that feature, then by definition no examples can @@ -235,17 +267,55 @@ export class ConfusionMatrixModule extends LitModule { return {ids, selected: false}; }); }); + this.matrices[id] = matrixCells; + } + + private getMatrixId(row: number, col: number) { + return `${row}:${col}`; + } + + private getOptionsFromMatrixId(id: string) { + return id.split(":").map(numStr => +numStr); + } + + private canCreateMatrix(row: number, col: number) { + // Can create a matrix only if the rows and columns are for different fields + // and this matrix isn't already created.
+ if (row === col) { + return false; + } + const id = this.getMatrixId(row, col); + return this.matrices[id] == null; + } + + @computed + get matrixCreateTooltip() { + if (this.selectedRowOption === this.selectedColOption) { + return 'Must set different row and column options'; + } + if (this.matrices[ + this.getMatrixId(this.selectedRowOption, this.selectedColOption)] != + null) { + return 'Matrix for current row and column options already exists'; + } + return ''; } render() { + const renderMatrices = () => { + return Object.keys(this.matrices).map(id => { + const [row, col] = this.getOptionsFromMatrixId(id); + return this.renderMatrix(row, col, this.matrices[id]); + }); + }; // clang-format off return html`
${this.renderControls()}
-
- ${this.renderMatrix()} +
+ ${renderMatrices()}
`; @@ -255,40 +325,59 @@ export class ConfusionMatrixModule extends LitModule { private renderControls() { const rowChange = (e: Event) => { this.selectedRowOption = +((e.target as HTMLSelectElement).value); - this.calculateMatrix(); }; const colChange = (e: Event) => { this.selectedColOption = +((e.target as HTMLSelectElement).value); - this.calculateMatrix(); + }; + const toggleUpdateOnSelection = () => { + this.updateOnSelection = !this.updateOnSelection; }; const toggleHideCheckbox = () => { this.hideEmptyLabels = !this.hideEmptyLabels; - this.calculateMatrix(); + }; + const onCreateMatrix = () => { + this.calculateMatrix(this.selectedRowOption, this.selectedColOption); }; return html` -