From 5b8808d61c0ded3a9741064a91a78a527b76f9b8 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Wed, 18 Dec 2024 15:48:36 +0000
Subject: [PATCH] Use bytes as Token type, more tests for Index

---
 src/index.rs          | 55 +++++++++++++++++++++++---
 src/primitives.rs     |  2 +-
 src/vocabulary/mod.rs | 91 +++++++++++++++++++++++++------------------
 3 files changed, 105 insertions(+), 43 deletions(-)

diff --git a/src/index.rs b/src/index.rs
index a915766c..3df6e742 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -138,7 +138,7 @@ impl Index {
             }
 
             let mut next_state = current_state;
-            for transition_byte in token.as_bytes() {
+            for transition_byte in token {
                 next_state = dfa.next_state(next_state, *transition_byte);
                 if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                     continue 'token_loop;
                 }
             }
@@ -230,19 +230,64 @@ mod tests {
             .insert("blah", 0)
             .insert("1a", 1)
             .insert("2", 2)
-            .insert("0", 3)
-            .insert("", 4);
+            .insert("0", 3);
 
         let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
         assert_eq!(index.initial(), 40);
         assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
 
-        let expected: HashMap<u32, HashMap<u32, u32>> = HashMap::from_iter([
+        let expected = HashMap::from_iter([
             (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
             (48, HashMap::from_iter([(4, 48)])),
             (40, HashMap::from_iter([(3, 48), (2, 56)])),
             (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
         ]);
-        assert_eq!(&expected, index.transitions());
+        assert_eq!(index.transitions(), &expected);
+    }
+
+    #[test]
+    fn index_from_regex_initial_in_allowed() {
+        let regex = "`\\n(\\.\\n)?`\\n";
+        let vocabulary = Vocabulary::new(Some(104))
+            .insert("\n", 103)
+            .insert(".", 102)
+            .insert("`", 101);
+
+        let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
+        let allowed = index
+            .allowed_tokens(index.initial())
+            .expect("No allowed tokens");
+        assert!(allowed.contains(&101));
+    }
+
+    #[test]
+    fn index_from_regex_multibyte() {
+        let regex = "😇| [😈-😍][😇-😎]*";
+        let vocabulary = Vocabulary::new(Some(8))
+            .insert(" 😍", 5)
+            .insert("blah", 0)
+            .insert("😇", 2)
+            .insert("😈a", 1)
+            .insert("😍", 3)
+            .insert(vec![32, 240, 159, 152], 7)
+            .insert(vec![32, 240, 159, 152, 141], 6)
+            .insert(vec![240, 159, 152, 141], 4);
+
+        let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
+
+        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));
+
+        let expected = HashMap::from_iter([
+            (
+                208,
+                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
+            ),
+            (
+                80,
+                HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
+            ),
+            (128, HashMap::from_iter([(8, 128)])),
+        ]);
+        assert_eq!(index.transitions(), &expected);
     }
 }
diff --git a/src/primitives.rs b/src/primitives.rs
index e12bf036..0976f76d 100644
--- a/src/primitives.rs
+++ b/src/primitives.rs
@@ -2,7 +2,7 @@
 pub type TransitionKey = u32;
 
 /// Token content.
-pub type Token = String;
+pub type Token = Vec<u8>;
 
 /// Token identifier.
 pub type TokenId = u32;
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 13156ade..613d735f 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -1,4 +1,4 @@
-use rustc_hash::FxHashMap as HashMap;
+use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
 
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
@@ -90,11 +90,7 @@ impl Vocabulary {
             });
         };
         for (token, token_id) in tokenizer.get_vocab(false) {
-            let token_bytes = processor.process(token)?;
-            // TODO: lossy is temp:
-            // - in python in was handled by byte_symbol function
-            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
-            let processed_token = String::from_utf8_lossy(&token_bytes);
+            let processed_token = processor.process(token)?;
             vocabulary = vocabulary.insert(processed_token, token_id);
         }
 
@@ -107,7 +103,7 @@
     }
 
     /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
-    pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
+    pub fn token_to_ids(&self, token: &Token) -> Option<&Vec<TokenId>> {
         self.tokens.get(token)
     }
 
@@ -214,6 +210,18 @@ impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
     }
 }
 
+impl From<HashMap<String, Vec<TokenId>>> for Vocabulary {
+    fn from(tokens: HashMap<String, Vec<TokenId>>) -> Vocabulary {
+        Vocabulary {
+            eos_token_id: None,
+            tokens: tokens
+                .into_iter()
+                .map(|(k, v)| (k.as_bytes().to_vec(), v))
+                .collect::<HashMap<Token, Vec<TokenId>>>(),
+        }
+    }
+}
+
 impl<T, I> FromIterator<(T, I)> for Vocabulary
 where
     T: Into<Token>,
@@ -237,10 +245,10 @@ mod tests {
             .insert("0", 3);
 
         assert_eq!(vocabulary.len(), 4);
-        assert_eq!(vocabulary["blah"], &[0]);
-        assert_eq!(vocabulary["1a"], &[1]);
-        assert_eq!(vocabulary["2"], &[2]);
-        assert_eq!(vocabulary["0"], &[3]);
+        assert_eq!(vocabulary["blah".as_bytes()], &[0]);
+        assert_eq!(vocabulary["1a".as_bytes()], &[1]);
+        assert_eq!(vocabulary["2".as_bytes()], &[2]);
+        assert_eq!(vocabulary["0".as_bytes()], &[3]);
     }
@@ -253,10 +261,10 @@ mod tests {
         ]);
 
         assert_eq!(vocabulary.len(), 4);
-        assert_eq!(vocabulary["blah"], &[0]);
-        assert_eq!(vocabulary["1a"], &[1]);
-        assert_eq!(vocabulary["2"], &[2]);
-        assert_eq!(vocabulary["0"], &[3]);
+        assert_eq!(vocabulary["blah".as_bytes()], &[0]);
+        assert_eq!(vocabulary["1a".as_bytes()], &[1]);
+        assert_eq!(vocabulary["2".as_bytes()], &[2]);
+        assert_eq!(vocabulary["0".as_bytes()], &[3]);
     }
@@ -268,7 +276,7 @@
     #[test]
     fn new_empty_vocabulary_from_hashmap() {
-        let map = HashMap::default();
+        let map: HashMap<Token, Vec<TokenId>> = HashMap::default();
         let vocabulary = Vocabulary::from(map);
         assert!(vocabulary.eos_token_id.is_none());
         assert!(vocabulary.tokens.is_empty());
     }
@@ -276,7 +284,7 @@
     #[test]
     fn new_vocabulary_from_iterator() {
-        let token: Token = "abc".to_string();
+        let token: Token = "abc".as_bytes().to_vec();
         let id: Vec<TokenId> = vec![1];
         let it = vec![(token, id)];
         let vocabulary = Vocabulary::from_iter(it);
@@ -330,11 +338,12 @@
         );
 
         let token = "Ġal";
-        assert!(vocabulary.token_to_ids(token).is_none());
-        assert!(tokenizer.token_to_id(token).is_some());
+        let btoken = token.as_bytes().to_vec();
+        assert!(vocabulary.token_to_ids(&btoken).is_none());
+        assert!(tokenizer.token_to_id(&token).is_some());
 
         for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] {
-            let v_ids = vocabulary.token_to_ids(v_token);
+            let v_ids = vocabulary.token_to_ids(&v_token.as_bytes().to_vec());
             assert!(v_ids.is_some());
             for v_id in v_ids.unwrap() {
                 let t_token = tokenizer
                     .id_to_token(*v_id)
                     .expect("Token id not found in tokenizer");
                 assert_eq!(&t_token, t_token_expected);
             }
@@ -361,24 +370,32 @@
         assert_eq!(
             tokenizer.id_to_token(v_eos).expect("Token not found"),
             "</s>"
         );
-
-        for (v_token, t_token_expected) in [
-            ("abc", "abc"),
-            (" al", "▁al"),
-            (" O", "▁O"),
-            ("   ", "▁▁▁"),
-            // TODO: won't pass since first we need to change token's type to bytes
-            // ("<0xFF>", "ÿ"),
-            // ("<0x20>", "▁"),
-        ] {
-            let v_ids = vocabulary.token_to_ids(v_token);
+
+        let tests: &[(Vec<u8>, &[&str])] = &[
+            ("abc".as_bytes().to_vec(), &["abc"]),
+            (" al".as_bytes().to_vec(), &["▁al"]),
+            (" O".as_bytes().to_vec(), &["▁O"]),
+            ("   ".as_bytes().to_vec(), &["▁▁▁"]),
+            (" ".as_bytes().to_vec(), &["▁", "<0x20>"]),
+            ("a".as_bytes().to_vec(), &["a", "<0x61>"]),
+            (vec![0xFF], &["<0xFF>"]),
+            (vec![0x20], &["▁", "<0x20>"]),
+        ];
+        for (v_token, t_tokens_expected) in tests {
+            let v_ids = vocabulary.token_to_ids(&v_token);
             assert!(v_ids.is_some());
-            for v_id in v_ids.unwrap() {
-                let t_token = tokenizer
-                    .id_to_token(*v_id)
-                    .expect("Token id not found in tokenizer");
-                assert_eq!(&t_token, t_token_expected);
-            }
+
+            let t_tokens = v_ids
+                .unwrap()
+                .iter()
+                .map(|v_id| {
+                    tokenizer
+                        .id_to_token(*v_id)
+                        .expect("Token id not found in tokenizer")
+                })
+                .collect::<HashSet<String>>();
+            let expected = HashSet::from_iter(t_tokens_expected.iter().map(|s| s.to_string()));
+            assert_eq!(t_tokens, expected);
         }
     }
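
A minimal usage sketch of the byte-based `Token` API exercised by the new tests above. `Vocabulary::new`, `insert`, `token_to_ids`, `Index::from_regex`, `initial`, and `allowed_tokens` are used as in those tests; the test name, regex, token contents, and ids below are made up for illustration, and `insert` is assumed to accept anything convertible into `Token` (i.e. `Vec<u8>`), as the mixed `&str`/`vec![...]` calls in the tests suggest.

    #[test]
    fn byte_tokens_round_trip() {
        // Tokens can be supplied as &str or as raw bytes, since Token = Vec<u8>.
        let vocabulary = Vocabulary::new(Some(3))
            .insert("ab", 0)
            .insert(vec![0xF0, 0x9F, 0x98, 0x87], 1) // the UTF-8 bytes of "😇"
            .insert("c", 2);

        // Lookups are keyed by byte content rather than by &str.
        assert_eq!(vocabulary.token_to_ids(&"ab".as_bytes().to_vec()), Some(&vec![0]));
        assert_eq!(vocabulary.token_to_ids(&vec![0xF0, 0x9F, 0x98, 0x87]), Some(&vec![1]));

        // The index walks every token byte by byte through the regex DFA, so
        // multi-byte (e.g. emoji) tokens are handled the same way as ASCII ones.
        let index = Index::from_regex("ab?c*", &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(index.initial())
            .expect("No allowed tokens");
        // Every byte of "ab" stays on a live DFA path, so its id should be allowed here.
        assert!(allowed.contains(&0));
    }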