From d08d1d32e2e407de1abba4fe91930f889dcc7add Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Thu, 9 Jan 2025 15:00:56 +0000
Subject: [PATCH] Stabilize Index interfaces

---
 python/outlines_core/fsm/outlines_core_rs.pyi |  5 ++-
 src/index.rs                                  | 44 +++++++++----------
 src/python_bindings/mod.rs                    |  4 +-
 tests/fsm/test_index.py                       | 25 +++++++++++
 4 files changed, 53 insertions(+), 25 deletions(-)

diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index a9c8ccd..37fca77 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -66,6 +66,9 @@ class Vocabulary:
     def __eq__(self, other: object) -> bool:
         """Compares whether two vocabularies are the same."""
         ...
+    def __len__(self) -> int:
+        """Returns length of Vocabulary's tokens, excluding EOS token."""
+        ...
     def __deepcopy__(self, memo: dict) -> "Vocabulary":
         """Makes a deep copy of the Vocabulary."""
         ...
@@ -83,7 +86,7 @@ class Index:
     def is_final_state(self, state: int) -> bool:
         """Determines whether the current state is a final state."""
         ...
-    def final_states(self) -> List[int]:
+    def get_final_states(self) -> List[int]:
         """Get all final states."""
         ...
     def get_transitions(self) -> Dict[int, Dict[int, int]]:
diff --git a/src/index.rs b/src/index.rs
index 407e471..f6c433b 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -10,14 +10,14 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
 
 #[derive(Clone, Debug, PartialEq, Encode, Decode)]
 pub struct Index {
-    initial: StateId,
-    finals: HashSet<StateId>,
-    states_to_token_subsets: HashMap<StateId, HashMap<TokenId, StateId>>,
+    initial_state: StateId,
+    final_states: HashSet<StateId>,
+    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
     eos_token_id: TokenId,
 }
 
 impl Index {
-    pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
+    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
         let eos_token_id = vocabulary.eos_token_id();
         let dfa = DFA::new(regex).map_err(Box::new)?;
         let start_state = match dfa.universal_start_state(Anchored::Yes) {
@@ -83,9 +83,9 @@ impl Index {
 
         if is_valid {
             Ok(Self {
-                initial: start_state.as_u32(),
-                finals: final_states,
-                states_to_token_subsets: transitions,
+                initial_state: start_state.as_u32(),
+                final_states,
+                transitions,
                 eos_token_id,
             })
         } else {
@@ -93,40 +93,40 @@ impl Index {
         }
     }
 
-    pub(crate) fn allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
-        self.states_to_token_subsets
+    pub fn allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
+        self.transitions
             .get(&state)
             .map_or_else(|| None, |res| Some(res.keys().cloned().collect()))
     }
 
-    pub(crate) fn next_state(&self, state: StateId, token_id: TokenId) -> Option<StateId> {
+    pub fn next_state(&self, state: StateId, token_id: TokenId) -> Option<StateId> {
         if token_id == self.eos_token_id {
             return None;
         }
-        Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?)
+        Some(*self.transitions.get(&state)?.get(&token_id)?)
     }
 
-    pub(crate) fn initial(&self) -> StateId {
-        self.initial
+    pub fn initial_state(&self) -> StateId {
+        self.initial_state
     }
 
-    pub(crate) fn is_final(&self, state: StateId) -> bool {
-        self.finals.contains(&state)
+    pub fn is_final(&self, state: StateId) -> bool {
+        self.final_states.contains(&state)
     }
 
-    pub(crate) fn final_states(&self) -> &HashSet<StateId> {
-        &self.finals
+    pub fn final_states(&self) -> &HashSet<StateId> {
+        &self.final_states
     }
 
-    pub(crate) fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
-        &self.states_to_token_subsets
+    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
+        &self.transitions
     }
 }
 
 impl std::fmt::Display for Index {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         writeln!(f, "Index object with transitions:")?;
-        for (state_id, token_ids) in self.states_to_token_subsets.iter() {
+        for (state_id, token_ids) in self.transitions.iter() {
             writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
         }
         Ok(())
@@ -148,7 +148,7 @@ mod tests {
         }
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
-        assert_eq!(index.initial(), 40);
+        assert_eq!(index.initial_state(), 40);
         assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
 
         let expected = HashMap::from_iter([
@@ -172,7 +172,7 @@ mod tests {
 
         let index = Index::new(regex, &vocabulary).expect("Index failed");
         let allowed = index
-            .allowed_tokens(index.initial())
+            .allowed_tokens(index.initial_state())
             .expect("No allowed tokens");
         assert!(allowed.contains(&101));
     }
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index be2bf97..a7b7a50 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -133,7 +133,7 @@ impl PyIndex {
         self.0.is_final(state)
     }
 
-    fn final_states(&self) -> HashSet<StateId> {
+    fn get_final_states(&self) -> HashSet<StateId> {
         self.0.final_states().clone()
     }
 
@@ -142,7 +142,7 @@ impl PyIndex {
     }
 
     fn get_initial_state(&self) -> StateId {
-        self.0.initial()
+        self.0.initial_state()
     }
 
     fn __repr__(&self) -> String {
         format!("{:#?}", self.0)
diff --git a/tests/fsm/test_index.py b/tests/fsm/test_index.py
index 799b468..5b56088 100644
--- a/tests/fsm/test_index.py
+++ b/tests/fsm/test_index.py
@@ -18,6 +18,31 @@ def index() -> Index:
     return Index(regex, vocabulary)
 
 
+def test_basic_interface(index):
+    init_state = index.get_initial_state()
+    assert init_state == 12
+    assert index.is_final_state(init_state) is False
+
+    allowed_tokens = index.get_allowed_tokens(init_state)
+    assert allowed_tokens == [1, 2]
+
+    next_state = index.get_next_state(init_state, allowed_tokens[-1])
+    assert next_state == 20
+    assert index.is_final_state(next_state) is True
+    assert index.get_final_states() == {20}
+
+    expected_transitions = {
+        12: {
+            1: 20,
+            2: 20,
+        },
+        20: {
+            3: 20,
+        },
+    }
+    assert index.get_transitions() == expected_transitions
+
+
 def test_pickling(index):
     serialized = pickle.dumps(index)
     deserialized = pickle.loads(serialized)
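
Not part of the patch itself: a minimal usage sketch of the renamed Python-side interface, mirroring test_basic_interface above. It assumes a Vocabulary instance is already bound to `vocabulary` (its construction is out of scope here), the regex is arbitrary, the import path simply follows the .pyi location shown in the diff, and the concrete state and token ids depend entirely on the vocabulary and regex used.

    # Sketch only: `vocabulary` is assumed to be a pre-built outlines_core Vocabulary.
    from outlines_core.fsm.outlines_core_rs import Index

    index = Index(r"0|[1-9][0-9]*", vocabulary)

    state = index.get_initial_state()                # starting DFA state
    allowed = index.get_allowed_tokens(state)        # token ids permitted in `state`
    state = index.get_next_state(state, allowed[0])  # follow one transition

    accepted = index.is_final_state(state)           # does `state` accept?
    finals = index.get_final_states()                # renamed from final_states() here
    table = index.get_transitions()                  # {state: {token_id: next_state}}

On the Rust side the same patch makes Index::new and the accessors pub and renames initial()/finals/states_to_token_subsets to initial_state()/final_states/transitions, so the Python names above map directly onto the Rust getters.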