diff --git a/src/error.rs b/src/error.rs index a8e5864..53a8728 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,7 +12,7 @@ pub enum Error { #[error("Failed to build DFA {0}")] IndexDfaError(#[from] Box), #[error("Index failed since anchored universal start state doesn't exist")] - IndexNoAnchoredUniversalStartState, + DfaHasNoStartState, #[error(transparent)] TokenizersError(#[from] tokenizers::Error), #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: /~https://github.com/dottxt-ai/outlines-core/issues")] diff --git a/src/index.rs b/src/index.rs index 127ea4c..a915766 100644 --- a/src/index.rs +++ b/src/index.rs @@ -7,24 +7,24 @@ use bincode::{Decode, Encode}; use regex_automata::dfa::{dense::DFA, Automaton}; use regex_automata::util::primitives::StateID as AutomataStateId; use regex_automata::Anchored; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; #[derive(Debug)] pub struct FSMInfo { pub(crate) initial: State, - pub(crate) finals: FxHashSet, - pub(crate) transitions: FxHashMap<(State, TransitionKey), State>, + pub(crate) finals: HashSet, + pub(crate) transitions: HashMap<(State, TransitionKey), State>, pub(crate) alphabet_anything_value: TransitionKey, - pub(crate) alphabet_symbol_mapping: FxHashMap, + pub(crate) alphabet_symbol_mapping: HashMap, } impl FSMInfo { pub fn new( initial: State, - finals: FxHashSet, - transitions: FxHashMap<(State, TransitionKey), State>, + finals: HashSet, + transitions: HashMap<(State, TransitionKey), State>, alphabet_anything_value: TransitionKey, - alphabet_symbol_mapping: FxHashMap, + alphabet_symbol_mapping: HashMap, ) -> Self { Self { initial, @@ -39,8 +39,8 @@ impl FSMInfo { #[derive(Debug, Encode, Decode)] pub struct Index { initial: u32, - finals: FxHashSet, - states_to_token_subsets: FxHashMap>, + finals: HashSet, + states_to_token_subsets: HashMap>, eos_token_id: u32, } @@ -49,11 +49,11 @@ impl Index { fsm_info: &FSMInfo, vocabulary: &Vocabulary, eos_token_id: u32, - frozen_tokens: FxHashSet, + frozen_tokens: HashSet, ) -> Result { - let mut states_to_token_subsets: FxHashMap> = FxHashMap::default(); - let mut seen: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter([fsm_info.initial]); + let mut states_to_token_subsets: HashMap> = HashMap::default(); + let mut seen: HashSet = HashSet::default(); + let mut next_states: HashSet = HashSet::from_iter([fsm_info.initial]); let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, @@ -111,26 +111,25 @@ impl Index { pub(crate) fn from_regex(regex: &str, vocabulary: &Vocabulary) -> Result { let eos_token_id = match vocabulary.eos_token_id() { Some(s) => s, + // TODO: this error will be removed once eos_token_id for vocabulary won't be optional None => return Err(Error::IndexEosTokenIdNotAvailable), }; - let dfa = DFA::builder().build(regex).map_err(Box::new)?; + let dfa = DFA::new(regex).map_err(Box::new)?; let start_state = match dfa.universal_start_state(Anchored::Yes) { Some(s) => s, - None => return Err(Error::IndexNoAnchoredUniversalStartState), + None => return Err(Error::DfaHasNoStartState), }; - let mut index: FxHashMap> = FxHashMap::default(); - let mut seen: FxHashSet = FxHashSet::default(); - let mut final_states: FxHashSet = FxHashSet::default(); - let mut next_states: FxHashSet = FxHashSet::from_iter([start_state]); + let mut transitions: HashMap> = HashMap::default(); + let mut final_states: HashSet = HashSet::default(); - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - seen.insert(start_state); + let mut seen: HashSet = HashSet::from_iter([start_state]); + let mut next_states: Vec = vec![start_state]; - if dfa.is_match_state(dfa.next_eoi_state(start_state)) { - final_states.insert(start_state.as_u32()); + while let Some(current_state) = next_states.pop() { + if dfa.is_match_state(dfa.next_eoi_state(current_state)) { + final_states.insert(current_state.as_u32()); } 'token_loop: for (token, ids) in vocabulary.tokens_to_ids().iter() { @@ -138,7 +137,7 @@ impl Index { continue; } - let mut next_state = start_state; + let mut next_state = current_state; for transition_byte in token.as_bytes() { next_state = dfa.next_state(next_state, *transition_byte); if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) { @@ -146,40 +145,33 @@ impl Index { } } - if dfa.is_match_state(next_state) { - // Token either matched or matched except the last character. - // Check what happens if the input suddenly ends after reaching this state. - // If the automata still matches, then token is exactly matched, if not - // then token didn't match. - let next_eoi_state = dfa.next_eoi_state(next_state); - let token_matched = dfa.is_match_state(next_eoi_state); - if !token_matched { - continue; + let is_intermediate_state = !dfa.is_match_state(next_state); + let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state)); + if is_intermediate_state || is_full_match_state { + for token_id in ids { + transitions + .entry(current_state.as_u32()) + .or_default() + .insert(*token_id, next_state.as_u32()); } } - - for token_id in ids { - let mapping = index.entry(start_state.as_u32()).or_default(); - mapping.insert(*token_id, next_state.as_u32()); - - if !seen.contains(&next_state) { - next_states.insert(next_state); - } + if !seen.contains(&next_state) { + seen.insert(next_state); + next_states.push(next_state); } } } - let start_state = start_state.as_u32(); - - // Populate `index` with mappings from `final_states` to `eos_token_id` + // Populate `transitions` with mappings from `final_states` to `eos_token_id` for &final_state in &final_states { - index + transitions .entry(final_state) .or_default() .insert(eos_token_id, final_state); } + // Check if there is at least one valid mapping - let is_valid = index.values().any(|mapping| { + let is_valid = transitions.values().any(|mapping| { mapping .values() .any(|end_state| final_states.contains(end_state)) @@ -187,9 +179,9 @@ impl Index { if is_valid { Ok(Self { - initial: start_state, + initial: start_state.as_u32(), finals: final_states, - states_to_token_subsets: index, + states_to_token_subsets: transitions, eos_token_id, }) } else { @@ -218,11 +210,11 @@ impl Index { self.finals.contains(&state) } - pub(crate) fn final_states(&self) -> &FxHashSet { + pub(crate) fn final_states(&self) -> &HashSet { &self.finals } - pub(crate) fn transitions(&self) -> &FxHashMap> { + pub(crate) fn transitions(&self) -> &HashMap> { &self.states_to_token_subsets } } @@ -243,10 +235,14 @@ mod tests { let index = Index::from_regex(regex, &vocabulary).expect("Index failed"); assert_eq!(index.initial(), 40); - assert_eq!(index.final_states(), &FxHashSet::from_iter([24, 48, 56])); - assert_eq!( - "{24: {3: 24, 4: 24, 2: 24}, 48: {4: 48}, 40: {3: 48, 2: 56}, 56: {3: 24, 4: 56, 2: 24}}", - format!("{:?}", index.transitions()) - ); + assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56])); + + let expected: HashMap> = HashMap::from_iter([ + (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])), + (48, HashMap::from_iter([(4, 48)])), + (40, HashMap::from_iter([(3, 48), (2, 56)])), + (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])), + ]); + assert_eq!(&expected, index.transitions()); } } diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index a3ed0d1..5c76f30 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -10,8 +10,8 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3::wrap_pyfunction; -use serde_json::Value; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use serde_json::Value; #[pyclass(name = "FSMInfo")] pub struct PyFSMInfo {