diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index d077cd2..f5a42af 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -38,7 +38,7 @@ def time_regex_to_guide_parallel(self, pattern_name): def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name): # Note: after moving to full rust implementation for index and guide creation, this experiment # is no longer shows the drastic difference as it once showed when python was heavily involved, - # due to on average speedup ~100 times. + # due to speedup up to ~100 times. # This test is to show, that if GIL's switch interval is set to be longer, then the parallel # test's runtime on physical cores will be much closer to the one-threaded case. diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index 77b0823..a2d7921 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -53,6 +53,9 @@ class Vocabulary: def insert(self, token: Union[str, bytes], token_id: int): """Inserts new token with token_id or extends list of token_ids if token already present.""" ... + def remove(self, token: Union[str, bytes]): + """Removes a token from vocabulary.""" + ... def get_eos_token_id(self) -> Optional[int]: """Gets the end of sentence token id.""" ... diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 5655baa..5f1fac3 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -259,6 +259,21 @@ impl PyVocabulary { ))) } + fn remove(&mut self, py: Python<'_>, token: Py) -> PyResult<()> { + if let Ok(t) = token.extract::(py) { + self.0.remove(t); + return Ok(()); + } + if let Ok(t) = token.extract::(py) { + self.0.remove(t); + return Ok(()); + } + Err(PyErr::new::(format!( + "Expected a token of type str or bytes, got {:?}", + type_name!(token) + ))) + } + fn get_eos_token_id(&self) -> TokenId { self.0.eos_token_id() } diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 71f2c42..821abd5 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -57,6 +57,12 @@ impl Vocabulary { Ok(()) } + /// Removes a token from the vocabulary. + pub fn remove(&mut self, token: impl Into) { + let token = token.into(); + self.tokens.remove(&token); + } + /// Creates the vocabulary of pre-trained model from Hugging Face Hub. pub fn from_pretrained( model: &str, @@ -251,6 +257,15 @@ mod tests { .try_insert("six".to_string(), 6) .expect("Insert failed"); assert_eq!(vocabulary.token_to_ids("six"), Some(&vec![6])); + + vocabulary.remove(b"four"); + assert_eq!(vocabulary.token_to_ids("four"), None); + + vocabulary.remove(b"five".to_vec()); + assert_eq!(vocabulary.token_to_ids("five"), None); + + vocabulary.remove("six".to_string()); + assert_eq!(vocabulary.token_to_ids("six"), None); } #[test] diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index f4879d6..e44e2da 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -12,12 +12,8 @@ def vocabulary(): return Vocabulary(eos_token_id, tokens) -def test_basic_vocabulary_interface(): - eos_token_id = 3 - tokens = {"1": [1], "a": [2]} - vocabulary = Vocabulary(eos_token_id, tokens) - - assert vocabulary.get_eos_token_id() == eos_token_id +def test_basic_vocabulary_interface(vocabulary): + assert vocabulary.get_eos_token_id() == 3 assert vocabulary.get("1") == vocabulary.get(b"1") == [1] assert len(vocabulary) == 2 @@ -29,6 +25,17 @@ def test_basic_vocabulary_interface(): assert vocabulary.get("b") == vocabulary.get(b"b") == [4, 5] assert len(vocabulary) == 3 + vocabulary.remove("b") + assert vocabulary.get("b") is None + + # second remove doesn't fail too + vocabulary.remove("b") + assert vocabulary.get("b") is None + + assert vocabulary.get("a") == [2] + vocabulary.remove(b"a") + assert vocabulary.get("a") is None + def test_string_and_bytes_as_tokens(): eos_token_id = 3