From a6a88da1b3ef0b22ea46c3574cf202c39d653766 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Wed, 15 Jan 2025 21:05:34 +0000 Subject: [PATCH] No Vocabulary is insufficient for Index --- src/error.rs | 2 -- src/index.rs | 65 +++++++++++++++++++++++++++++++++------------------- 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/src/error.rs b/src/error.rs index e5781e8..ecb3bbb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,8 +5,6 @@ pub type Result = std::result::Result; #[derive(Error, Debug)] pub enum Error { // Index Errors - #[error("The vocabulary does not allow to build an index that matches the input")] - InsufficientVocabulary, #[error("Failed to build DFA {0}")] IndexDfaError(#[from] Box), #[error("Index failed since anchored universal start state doesn't exist")] diff --git a/src/index.rs b/src/index.rs index ba15235..42185c3 100644 --- a/src/index.rs +++ b/src/index.rs @@ -14,15 +14,44 @@ pub struct Index { initial_state: StateId, /// A collection of states considered as terminal states. final_states: HashSet, - /// A mapping of state transitions, defined by tokens ids and their corresponding state changes: - /// - The outer map's keys are the state IDs. - /// - The inner map's keys are token IDs. - /// - The inner map's values are state IDs, indicating transitions to the next state. + /// A mapping of state transitions, defined by tokens ids and their corresponding state changes. + /// + /// ### Example + /// ``` + /// transitions = { + /// 1: {10: 2, 15: 3}, + /// 2: {20: 4, 25: 3}, + /// 3: {30: 4}, + /// 4: {40: 4}, + /// } + /// +--------------------------------------+ + /// | State 1 | + /// | Initial State | + /// +--------------------------------------+ + /// | | + /// + | + /// Token ID 10 | + /// +-----------------------+ | + /// | State 2 | | + /// +-----------------------+ | + /// | | | + /// | + + + /// | Token ID 25 Token ID 15 + /// | +------------------------+ + /// | | State 3 | + /// | +------------------------+ + /// | | + /// + + + /// Token ID 20 Token ID 30 + /// +--------------------------------------+ + /// | State 4 | + /// | Final state | + /// +--------------------------------------+ + /// ``` transitions: HashMap>, /// The token ID reserved for the "end-of-sequence" token. eos_token_id: TokenId, } - /// The `Index` structure is designed to efficiently map tokens from a given vocabulary /// to state transitions within a finite-state automaton. /// @@ -122,30 +151,19 @@ impl Index { .insert(eos_token_id, final_state); } - // Check if there is at least one valid mapping - let is_valid = transitions.values().any(|mapping| { - mapping - .values() - .any(|end_state| final_states.contains(end_state)) - }); - - if is_valid { - Ok(Self { - initial_state: start_state.as_u32(), - final_states, - transitions, - eos_token_id, - }) - } else { - Err(Error::InsufficientVocabulary) - } + Ok(Self { + initial_state: start_state.as_u32(), + final_states, + transitions, + eos_token_id, + }) } /// Lists allowed tokens for a give state ID or `None` if it is not found in `Index`. pub fn allowed_tokens(&self, state: &StateId) -> Option> { self.transitions .get(state) - .map_or_else(|| None, |res| Some(res.keys().cloned().collect())) + .map(|res| res.keys().cloned().collect()) } /// Returns transition state for a given state and token id or `None` otherwise. @@ -259,7 +277,6 @@ mod tests { } let index = Index::new(regex, &vocabulary).expect("Index failed"); - assert_eq!(index.final_states(), &HashSet::from_iter([208, 128])); let expected = HashMap::from_iter([