Skip to content

Commit

Permalink
Change predict_with_metadata() calls to predict() in preparation for removal of the predict_with_metadata() method across LIT.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 551611787
  • Loading branch information
nadah09 authored and LIT team committed Jul 27, 2023
1 parent bc6f82b commit 7888c66
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 58 deletions.
6 changes: 3 additions & 3 deletions lit_nlp/components/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _fill_indices(self, model_name, dataset_name):
"""Create all indices for a single model."""
model = self._models.get(model_name)
assert model is not None, "Invalid model name."
examples = self.datasets[dataset_name].indexed_examples
examples = self.datasets[dataset_name].examples
model_embeddings_names = utils.find_spec_keys(model.output_spec(),
lit_types.Embeddings)
lookup_key = self._get_lookup_key(model_name, dataset_name)
Expand Down Expand Up @@ -158,7 +158,7 @@ def _fill_indices(self, model_name, dataset_name):
# Cold start: Get embeddings for non-initialized settings.
if self._initialize_new_indices:
for res_ix, (result, example) in enumerate(
zip(model.predict_with_metadata(examples), examples)):
zip(model.predict(examples), examples)):
for emb_name in embeddings_to_index:
index_key = self._get_index_key(model_name, dataset_name, emb_name)
# Initialize saving in the first iteration.
Expand All @@ -170,7 +170,7 @@ def _fill_indices(self, model_name, dataset_name):
# Each item has an incrementing ID res_ix.
self._indices[index_key].add_item(res_ix, result[emb_name])
# Add item to lookup table.
self._example_lookup[lookup_key][res_ix] = example["data"]
self._example_lookup[lookup_key][res_ix] = example

# Create the trees from the indices - using 10 as recommended by doc.
for emb_name in embeddings_to_index:
Expand Down
9 changes: 4 additions & 5 deletions lit_nlp/components/minimal_targeted_counterfactuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,21 +286,20 @@ def _filter_ds_examples(
'Only indexed datasets are currently supported by the TabularMTC'
'generator.')

indexed_examples = list(dataset.indexed_examples)
examples = list(dataset.examples)
filtered_examples = []
preds = model.predict_with_metadata(
indexed_examples, dataset_name=dataset_name)
preds = model.predict(examples)

# Find all DS examples that are flips with respect to the reference example.
for indexed_example, pred in zip(indexed_examples, preds):
for example, pred in zip(examples, preds):
flip = cf_utils.is_prediction_flip(
cf_output=pred,
orig_output=reference_output,
output_spec=model.output_spec(),
pred_key=pred_key,
regression_thresh=regression_thresh)
if flip:
candidate_example = dict(indexed_example['data'])
candidate_example = dict(example)
self._find_dataset_parent_and_set(
model_output_spec=model.output_spec(),
pred_key=pred_key,
Expand Down
10 changes: 3 additions & 7 deletions lit_nlp/components/similarity_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,9 @@ def __init__(self, indexer: index.Indexer):
self.index = indexer

def _get_embedding(self, example: types.Input, model: lit_model.Model,
dataset: lit_dataset.IndexedDataset, embedding_name: str,
dataset_name: str):
embedding_name: str):
"""Calls the model on the example to get the embedding."""
model_input = dataset.index_inputs([example])
model_output = model.predict_with_metadata(
model_input, dataset_name=dataset_name)
model_output = model.predict([example])
embedding = list(model_output)[0][embedding_name]
return embedding

Expand All @@ -66,8 +63,7 @@ def generate( # pytype: disable=signature-mismatch # overriding-parameter-type
model_name = config['model_name']
dataset_name = config['dataset_name']
embedding_name = config['Embedding Field']
embedding = self._get_embedding(example, model, dataset, embedding_name,
dataset_name)
embedding = self._get_embedding(example, model, embedding_name)
neighbors = self._find_nn(model_name, dataset_name, embedding_name,
embedding)
return neighbors
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/thresholder_int_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def setUpClass(cls):
indexed_examples=_INDEXED_EXAMPLES,
)
cls.model_outputs = list(
cls.model.predict_with_metadata(_INDEXED_EXAMPLES, dataset_name='test')
cls.model.predict(_EXAMPLES)
)

def setUp(self):
Expand Down
55 changes: 13 additions & 42 deletions lit_nlp/lib/caching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,79 +33,50 @@ def test_preds_cache(self):
self.assertIsNone(None, cache.get(("a", "2")))
self.assertEqual("test", cache.get(("a", "1")))

def test_caching_model_wrapper_no_dataset_skip_cache(self):
# Exercises CachingModelWrapper when inputs carry no dataset context: the
# repeated predict_with_metadata() call hits the wrapped model again
# (model.count goes 1 -> 2), i.e. results are NOT served from the cache.
# NOTE(review): this span is diff residue from a commit page; indentation
# has been flattened by the scrape, and the surrounding commit deletes this
# test together with predict_with_metadata() — do not treat as live code.
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")
examples = [{"data": {"val": 1}, "id": "my_id"}]
results = list(wrapper.predict_with_metadata(examples))
# First call runs the model once; the identity regression model echoes the
# input value back as the score.
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
results = list(wrapper.predict_with_metadata(examples))
# Count advances to 2: the second identical call was not cached
# (the skip-cache path under test).
self.assertEqual(2, model.count)
self.assertEqual({"score": 1}, results[0])

def test_caching_model_wrapper_use_cache(self):
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")
examples = [{"data": {"val": 1, "_id": "id_to_cache"}, "id": "id_to_cache"}]
results = list(wrapper.predict_with_metadata(examples))
examples = [{"val": 1, "_id": "id_to_cache"}]
results = wrapper.predict(examples)
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
results = list(wrapper.predict_with_metadata(examples))
results = wrapper.predict(examples)
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
self.assertEmpty(wrapper._cache._pred_locks)

def test_caching_model_wrapper_not_cached(self):
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")
examples = [{"data": {"val": 1}, "id": "my_id"}]
results = list(wrapper.predict_with_metadata(examples))
examples = [{"val": 1, "_id": "my_id"}]
results = wrapper.predict(examples)
self.assertEqual(1, model.count)
self.assertEqual({"score": 1}, results[0])
examples = [{"data": {"val": 2}, "id": "other_id"}]
results = list(wrapper.predict_with_metadata(examples))
examples = [{"val": 2, "_id": "other_id"}]
results = wrapper.predict(examples)
self.assertEqual(2, model.count)
self.assertEqual({"score": 2}, results[0])

def test_caching_model_wrapper_mixed_list(self):
def test_caching_model_wrapper_uses_cached_subset(self):
model = testing_utils.IdentityRegressionModelForTesting()
wrapper = caching.CachingModelWrapper(model, "test")

examples = [
{
"data": {
"val": 0,
"_id": "zeroth_id"
},
"id": "zeroth_id"
},
{
"data": {
"val": 1,
"_id": "first_id"
},
"id": "first_id"
},
{
"data": {
"val": 2,
"_id": "second_id"
},
"id": "second_id"
},
{"val": 0, "_id": "zeroth_id"},
{"val": 1, "_id": "first_id"},
{"val": 2, "_id": "second_id"},
]
subset = examples[:1]

# Run the CachingModelWrapper over a subset of examples
results = list(wrapper.predict_with_metadata(subset))
results = wrapper.predict(subset)
self.assertEqual(1, model.count)
self.assertEqual({"score": 0}, results[0])

# Now, run the CachingModelWrapper over all of the examples. This should
# only pass the examples that were not in subset to the wrapped model, and
# the total number of inputs processed by the wrapped model should be 3
results = list(wrapper.predict_with_metadata(examples))
results = wrapper.predict(examples)
self.assertEqual(3, model.count)
self.assertEqual({"score": 0}, results[0])
self.assertEqual({"score": 1}, results[1])
Expand Down

0 comments on commit 7888c66

Please sign in to comment.