Use bytes as Token type, more tests for Index
torymur committed Dec 18, 2024
1 parent 52b1093 commit 5b8808d
Showing 3 changed files with 105 additions and 43 deletions.
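
The gist of the change: token content is now stored as raw bytes (Vec<u8>) rather than String, so vocabulary entries that are not valid UTF-8 on their own (byte-fallback tokens, fragments of multi-byte characters) can be represented directly. A minimal sketch of the builder-style construction exercised by the new tests; the Vocabulary calls mirror the test code below, and the exact signatures are inferred rather than verified:

// Inside the crate's test modules `Vocabulary` is already in scope; imports are omitted here.
let vocabulary = Vocabulary::new(Some(8))    // 8 is presumably the eos token id
    .insert("😇", 2)                         // &str tokens are stored as their UTF-8 bytes
    .insert(vec![32, 240, 159, 152], 7);     // raw bytes: a space plus a truncated emoji, not valid UTF-8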
55 changes: 50 additions & 5 deletions src/index.rs
@@ -138,7 +138,7 @@ impl Index {
}

let mut next_state = current_state;
for transition_byte in token.as_bytes() {
for transition_byte in token {
next_state = dfa.next_state(next_state, *transition_byte);
if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
continue 'token_loop;
@@ -230,19 +230,64 @@ mod tests {
.insert("blah", 0)
.insert("1a", 1)
.insert("2", 2)
.insert("0", 3)
.insert("<eos>", 4);
.insert("0", 3);

let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
assert_eq!(index.initial(), 40);
assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));

let expected: HashMap<u32, HashMap<u32, u32>> = HashMap::from_iter([
let expected = HashMap::from_iter([
(24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
(48, HashMap::from_iter([(4, 48)])),
(40, HashMap::from_iter([(3, 48), (2, 56)])),
(56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
]);
assert_eq!(&expected, index.transitions());
assert_eq!(index.transitions(), &expected);
}

#[test]
fn index_from_regex_initial_in_allowed() {
let regex = "`\\n(\\.\\n)?`\\n";
let vocabulary = Vocabulary::new(Some(104))
.insert("\n", 103)
.insert(".", 102)
.insert("`", 101);

let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
let allowed = index
.allowed_tokens(index.initial())
.expect("No allowed tokens");
assert!(allowed.contains(&101));
}

#[test]
fn index_from_regex_multibyte() {
let regex = "😇| [😈-😍][😇-😎]*";
let vocabulary = Vocabulary::new(Some(8))
.insert(" 😍", 5)
.insert("blah", 0)
.insert("😇", 2)
.insert("😈a", 1)
.insert("😍", 3)
.insert(vec![32, 240, 159, 152], 7)
.insert(vec![32, 240, 159, 152, 141], 6)
.insert(vec![240, 159, 152, 141], 4);

let index = Index::from_regex(regex, &vocabulary).expect("Index failed");

assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

let expected = HashMap::from_iter([
(
208,
HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
),
(
80,
HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
),
(128, HashMap::from_iter([(8, 128)])),
]);
assert_eq!(index.transitions(), &expected);
}
}
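
The index.rs hunk above now walks each token's bytes directly (a Token is already a byte sequence), stepping the DFA one byte at a time and skipping the token as soon as a dead or quit state is reached. A rough standalone sketch of that stepping pattern, assuming regex-automata 0.4's dense DFA API rather than the crate's actual Index internals:

use regex_automata::{
    dfa::{dense, Automaton},
    Input,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a DFA for a toy pattern and advance it byte by byte,
    // mirroring how the loop above consumes a token.
    let dfa = dense::DFA::new(r"[0-9]+")?;
    let mut state = dfa.start_state_forward(&Input::new(""))?;

    let token: &[u8] = b"42"; // a token is just a byte sequence now
    let mut alive = true;
    for &byte in token {
        state = dfa.next_state(state, byte);
        if dfa.is_dead_state(state) || dfa.is_quit_state(state) {
            alive = false; // no match is reachable from here, so this token would be skipped
            break;
        }
    }

    // Feeding "end of input" tells us whether the consumed prefix is itself a match.
    let is_final = alive && dfa.is_match_state(dfa.next_eoi_state(state));
    println!("alive: {alive}, final: {is_final}");
    Ok(())
}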
2 changes: 1 addition & 1 deletion src/primitives.rs
@@ -2,7 +2,7 @@
pub type TransitionKey = u32;

/// Token content.
pub type Token = String;
pub type Token = Vec<u8>;

/// Token identifier.
pub type TokenId = u32;
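
For context on why Vec<u8> rather than String: tokenizer vocabularies routinely contain entries that are not valid UTF-8 on their own, which a String cannot hold. A small self-contained illustration (not part of the crate):

type Token = Vec<u8>; // as defined above

fn main() {
    // An ordinary token is just its UTF-8 bytes.
    let plain: Token = "abc".as_bytes().to_vec();

    // The first three bytes of the four-byte emoji "😈" form a perfectly
    // reasonable byte-level token, but they are not valid UTF-8 on their own.
    let fragment: Token = vec![0xF0, 0x9F, 0x98];

    assert!(String::from_utf8(plain).is_ok());
    assert!(String::from_utf8(fragment).is_err());
}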
91 changes: 54 additions & 37 deletions src/vocabulary/mod.rs
@@ -1,4 +1,4 @@
use rustc_hash::FxHashMap as HashMap;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
@@ -90,11 +90,7 @@ impl Vocabulary {
});
};
for (token, token_id) in tokenizer.get_vocab(false) {
let token_bytes = processor.process(token)?;
// TODO: lossy is temp:
// - in python it was handled by the byte_symbol function
// - interface needs to be redefined to treat Token type as bytes: Vec<u8>
let processed_token = String::from_utf8_lossy(&token_bytes);
let processed_token = processor.process(token)?;
vocabulary = vocabulary.insert(processed_token, token_id);
}

Expand All @@ -107,7 +103,7 @@ impl Vocabulary {
}

/// Per provided token returns vector of `TokenId`s if available in the vocabulary.
pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
pub fn token_to_ids(&self, token: &Token) -> Option<&Vec<TokenId>> {
self.tokens.get(token)
}
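
An aside on the new lookup signature: callers now pass a byte key (&Token, i.e. &Vec<u8>) instead of a &str. A short sketch mirroring the tests further down; the builder signature is inferred, not verified:

// Hypothetical lookup against a freshly built vocabulary.
let vocabulary = Vocabulary::new(None).insert(" 😍", 5);
let token: Token = " 😍".as_bytes().to_vec();
assert_eq!(vocabulary.token_to_ids(&token), Some(&vec![5]));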

@@ -214,6 +210,18 @@ impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
}
}

impl From<HashMap<String, Vec<TokenId>>> for Vocabulary {
fn from(tokens: HashMap<String, Vec<TokenId>>) -> Vocabulary {
Vocabulary {
eos_token_id: None,
tokens: tokens
.into_iter()
.map(|(k, v)| (k.as_bytes().to_vec(), v))
.collect::<HashMap<Token, Vec<TokenId>>>(),
}
}
}

impl<T, I> FromIterator<(T, I)> for Vocabulary
where
T: Into<Token>,
@@ -237,10 +245,10 @@ mod tests {
.insert("0", 3);

assert_eq!(vocabulary.len(), 4);
assert_eq!(vocabulary["blah"], &[0]);
assert_eq!(vocabulary["1a"], &[1]);
assert_eq!(vocabulary["2"], &[2]);
assert_eq!(vocabulary["0"], &[3]);
assert_eq!(vocabulary["blah".as_bytes()], &[0]);
assert_eq!(vocabulary["1a".as_bytes()], &[1]);
assert_eq!(vocabulary["2".as_bytes()], &[2]);
assert_eq!(vocabulary["0".as_bytes()], &[3]);
}

#[test]
Expand All @@ -253,10 +261,10 @@ mod tests {
]);

assert_eq!(vocabulary.len(), 4);
assert_eq!(vocabulary["blah"], &[0]);
assert_eq!(vocabulary["1a"], &[1]);
assert_eq!(vocabulary["2"], &[2]);
assert_eq!(vocabulary["0"], &[3]);
assert_eq!(vocabulary["blah".as_bytes()], &[0]);
assert_eq!(vocabulary["1a".as_bytes()], &[1]);
assert_eq!(vocabulary["2".as_bytes()], &[2]);
assert_eq!(vocabulary["0".as_bytes()], &[3]);
}

#[test]
@@ -268,15 +276,15 @@

#[test]
fn new_empty_vocabulary_from_hashmap() {
let map = HashMap::default();
let map: HashMap<Token, Vec<TokenId>> = HashMap::default();
let vocabulary = Vocabulary::from(map);
assert!(vocabulary.eos_token_id.is_none());
assert!(vocabulary.tokens.is_empty());
}

#[test]
fn new_vocabulary_from_iterator() {
let token: Token = "abc".to_string();
let token: Token = "abc".as_bytes().to_vec();
let id: Vec<TokenId> = vec![1];
let it = vec![(token, id)];
let vocabulary = Vocabulary::from_iter(it);
@@ -330,11 +338,12 @@ mod tests {
);

let token = "Ġal";
assert!(vocabulary.token_to_ids(token).is_none());
assert!(tokenizer.token_to_id(token).is_some());
let btoken = token.as_bytes().to_vec();
assert!(vocabulary.token_to_ids(&btoken).is_none());
assert!(tokenizer.token_to_id(&token).is_some());

for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] {
let v_ids = vocabulary.token_to_ids(v_token);
let v_ids = vocabulary.token_to_ids(&v_token.as_bytes().to_vec());
assert!(v_ids.is_some());
for v_id in v_ids.unwrap() {
let t_token = tokenizer
@@ -361,24 +370,32 @@ mod tests {
tokenizer.id_to_token(v_eos).expect("Token not found"),
"</s>"
);

for (v_token, t_token_expected) in [
("abc", "abc"),
(" al", "▁al"),
(" O", "▁O"),
(" ", "▁▁▁"),
// TODO: won't pass since first we need to change token's type to bytes
// ("<0xFF>", "ÿ"),
// ("<0x20>", "▁"),
] {
let v_ids = vocabulary.token_to_ids(v_token);

let tests: &[(Vec<u8>, &[&str])] = &[
("abc".as_bytes().to_vec(), &["abc"]),
(" al".as_bytes().to_vec(), &["▁al"]),
(" O".as_bytes().to_vec(), &["▁O"]),
(" ".as_bytes().to_vec(), &["▁▁▁"]),
(" ".as_bytes().to_vec(), &["▁", "<0x20>"]),
("a".as_bytes().to_vec(), &["a", "<0x61>"]),
(vec![0xFF], &["<0xFF>"]),
(vec![0x20], &["▁", "<0x20>"]),
];
for (v_token, t_tokens_expected) in tests {
let v_ids = vocabulary.token_to_ids(&v_token);
assert!(v_ids.is_some());
for v_id in v_ids.unwrap() {
let t_token = tokenizer
.id_to_token(*v_id)
.expect("Token id not found in tokenizer");
assert_eq!(&t_token, t_token_expected);
}

let t_tokens = v_ids.unwrap()
.iter()
.map(|v_id| {
tokenizer
.id_to_token(*v_id)
.expect("Token id not found in tokenizer")
})
.collect::<HashSet<String>>();
let expected = HashSet::from_iter(t_tokens_expected.iter().map(|s| s.to_string()));
assert_eq!(t_tokens, expected)
}
}

