From 7888c6677081049111f1c3d51943dea2c9351c59 Mon Sep 17 00:00:00 2001 From: Nada Hussein Date: Thu, 27 Jul 2023 12:41:31 -0700 Subject: [PATCH] Change predict_with_metadata() calls to predict() in preparation for removal of predict_with_metadata() method across LIT. PiperOrigin-RevId: 551611787 --- lit_nlp/components/index.py | 6 +- .../minimal_targeted_counterfactuals.py | 9 ++- lit_nlp/components/similarity_searcher.py | 10 +--- lit_nlp/components/thresholder_int_test.py | 2 +- lit_nlp/lib/caching_test.py | 55 +++++-------------- 5 files changed, 24 insertions(+), 58 deletions(-) diff --git a/lit_nlp/components/index.py b/lit_nlp/components/index.py index 4122b5d1..7a98d6c4 100644 --- a/lit_nlp/components/index.py +++ b/lit_nlp/components/index.py @@ -130,7 +130,7 @@ def _fill_indices(self, model_name, dataset_name): """Create all indices for a single model.""" model = self._models.get(model_name) assert model is not None, "Invalid model name." - examples = self.datasets[dataset_name].indexed_examples + examples = self.datasets[dataset_name].examples model_embeddings_names = utils.find_spec_keys(model.output_spec(), lit_types.Embeddings) lookup_key = self._get_lookup_key(model_name, dataset_name) @@ -158,7 +158,7 @@ def _fill_indices(self, model_name, dataset_name): # Cold start: Get embeddings for non-initialized settings. if self._initialize_new_indices: for res_ix, (result, example) in enumerate( - zip(model.predict_with_metadata(examples), examples)): + zip(model.predict(examples), examples)): for emb_name in embeddings_to_index: index_key = self._get_index_key(model_name, dataset_name, emb_name) # Initialize saving in the first iteration. @@ -170,7 +170,7 @@ def _fill_indices(self, model_name, dataset_name): # Each item has an incrementing ID res_ix. self._indices[index_key].add_item(res_ix, result[emb_name]) # Add item to lookup table. - self._example_lookup[lookup_key][res_ix] = example["data"] + self._example_lookup[lookup_key][res_ix] = example # Create the trees from the indices - using 10 as recommended by doc. for emb_name in embeddings_to_index: diff --git a/lit_nlp/components/minimal_targeted_counterfactuals.py b/lit_nlp/components/minimal_targeted_counterfactuals.py index 990fdf78..a43b98f5 100644 --- a/lit_nlp/components/minimal_targeted_counterfactuals.py +++ b/lit_nlp/components/minimal_targeted_counterfactuals.py @@ -286,13 +286,12 @@ def _filter_ds_examples( 'Only indexed datasets are currently supported by the TabularMTC' 'generator.') - indexed_examples = list(dataset.indexed_examples) + examples = list(dataset.examples) filtered_examples = [] - preds = model.predict_with_metadata( - indexed_examples, dataset_name=dataset_name) + preds = model.predict(examples) # Find all DS examples that are flips with respect to the reference example. - for indexed_example, pred in zip(indexed_examples, preds): + for example, pred in zip(examples, preds): flip = cf_utils.is_prediction_flip( cf_output=pred, orig_output=reference_output, @@ -300,7 +299,7 @@ def _filter_ds_examples( pred_key=pred_key, regression_thresh=regression_thresh) if flip: - candidate_example = dict(indexed_example['data']) + candidate_example = dict(example) self._find_dataset_parent_and_set( model_output_spec=model.output_spec(), pred_key=pred_key, diff --git a/lit_nlp/components/similarity_searcher.py b/lit_nlp/components/similarity_searcher.py index 970f51c0..acd17336 100644 --- a/lit_nlp/components/similarity_searcher.py +++ b/lit_nlp/components/similarity_searcher.py @@ -34,12 +34,9 @@ def __init__(self, indexer: index.Indexer): self.index = indexer def _get_embedding(self, example: types.Input, model: lit_model.Model, - dataset: lit_dataset.IndexedDataset, embedding_name: str, - dataset_name: str): + embedding_name: str): """Calls the model on the example to get the embedding.""" - model_input = dataset.index_inputs([example]) - model_output = model.predict_with_metadata( - model_input, dataset_name=dataset_name) + model_output = model.predict([example]) embedding = list(model_output)[0][embedding_name] return embedding @@ -66,8 +63,7 @@ def generate( # pytype: disable=signature-mismatch # overriding-parameter-type model_name = config['model_name'] dataset_name = config['dataset_name'] embedding_name = config['Embedding Field'] - embedding = self._get_embedding(example, model, dataset, embedding_name, - dataset_name) + embedding = self._get_embedding(example, model, embedding_name) neighbors = self._find_nn(model_name, dataset_name, embedding_name, embedding) return neighbors diff --git a/lit_nlp/components/thresholder_int_test.py b/lit_nlp/components/thresholder_int_test.py index 12a1ea4f..21660fbe 100644 --- a/lit_nlp/components/thresholder_int_test.py +++ b/lit_nlp/components/thresholder_int_test.py @@ -66,7 +66,7 @@ def setUpClass(cls): indexed_examples=_INDEXED_EXAMPLES, ) cls.model_outputs = list( - cls.model.predict_with_metadata(_INDEXED_EXAMPLES, dataset_name='test') + cls.model.predict(_EXAMPLES) ) def setUp(self): diff --git a/lit_nlp/lib/caching_test.py b/lit_nlp/lib/caching_test.py index 5832751d..f4f21a1a 100644 --- a/lit_nlp/lib/caching_test.py +++ b/lit_nlp/lib/caching_test.py @@ -33,25 +33,14 @@ def test_preds_cache(self): self.assertIsNone(None, cache.get(("a", "2"))) self.assertEqual("test", cache.get(("a", "1"))) - def test_caching_model_wrapper_no_dataset_skip_cache(self): - model = testing_utils.IdentityRegressionModelForTesting() - wrapper = caching.CachingModelWrapper(model, "test") - examples = [{"data": {"val": 1}, "id": "my_id"}] - results = list(wrapper.predict_with_metadata(examples)) - self.assertEqual(1, model.count) - self.assertEqual({"score": 1}, results[0]) - results = list(wrapper.predict_with_metadata(examples)) - self.assertEqual(2, model.count) - self.assertEqual({"score": 1}, results[0]) - def test_caching_model_wrapper_use_cache(self): model = testing_utils.IdentityRegressionModelForTesting() wrapper = caching.CachingModelWrapper(model, "test") - examples = [{"data": {"val": 1, "_id": "id_to_cache"}, "id": "id_to_cache"}] - results = list(wrapper.predict_with_metadata(examples)) + examples = [{"val": 1, "_id": "id_to_cache"}] + results = wrapper.predict(examples) self.assertEqual(1, model.count) self.assertEqual({"score": 1}, results[0]) - results = list(wrapper.predict_with_metadata(examples)) + results = wrapper.predict(examples) self.assertEqual(1, model.count) self.assertEqual({"score": 1}, results[0]) self.assertEmpty(wrapper._cache._pred_locks) @@ -59,53 +48,35 @@ def test_caching_model_wrapper_use_cache(self): def test_caching_model_wrapper_not_cached(self): model = testing_utils.IdentityRegressionModelForTesting() wrapper = caching.CachingModelWrapper(model, "test") - examples = [{"data": {"val": 1}, "id": "my_id"}] - results = list(wrapper.predict_with_metadata(examples)) + examples = [{"val": 1, "_id": "my_id"}] + results = wrapper.predict(examples) self.assertEqual(1, model.count) self.assertEqual({"score": 1}, results[0]) - examples = [{"data": {"val": 2}, "id": "other_id"}] - results = list(wrapper.predict_with_metadata(examples)) + examples = [{"val": 2, "_id": "other_id"}] + results = wrapper.predict(examples) self.assertEqual(2, model.count) self.assertEqual({"score": 2}, results[0]) - def test_caching_model_wrapper_mixed_list(self): + def test_caching_model_wrapper_uses_cached_subset(self): model = testing_utils.IdentityRegressionModelForTesting() wrapper = caching.CachingModelWrapper(model, "test") examples = [ - { - "data": { - "val": 0, - "_id": "zeroth_id" - }, - "id": "zeroth_id" - }, - { - "data": { - "val": 1, - "_id": "first_id" - }, - "id": "first_id" - }, - { - "data": { - "val": 2, - "_id": "second_id" - }, - "id": "second_id" - }, + {"val": 0, "_id": "zeroth_id"}, + {"val": 1, "_id": "first_id"}, + {"val": 2, "_id": "second_id"}, ] subset = examples[:1] # Run the CachingModelWrapper over a subset of examples - results = list(wrapper.predict_with_metadata(subset)) + results = wrapper.predict(subset) self.assertEqual(1, model.count) self.assertEqual({"score": 0}, results[0]) # Now, run the CachingModelWrapper over all of the examples. This should # only pass the examples that were not in subset to the wrapped model, and # the total number of inputs processed by the wrapped model should be 3 - results = list(wrapper.predict_with_metadata(examples)) + results = wrapper.predict(examples) self.assertEqual(3, model.count) self.assertEqual({"score": 0}, results[0]) self.assertEqual({"score": 1}, results[1])