Skip to content

Commit

Permalink
Add docs, polish interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Jan 13, 2025
1 parent 3fef1d8 commit 73e4bfe
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 24 deletions.
6 changes: 3 additions & 3 deletions benchmarks/bench_regex_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ def setup(self, pattern_name):
def time_regex_to_guide(self, pattern_name):
Index(self.pattern, self.vocabulary)

def time_regex_to_guide_parallel(self, pattern_name):
def time_regex_to_guide_threads(self, pattern_name):
# Default GIL switch interval is 5ms (0.005), which isn't helpful for cpu heavy tasks,
# this parallel case should be relatively close in runtime to one thread, but it is not,
# because of the GIL.
core_count = psutil.cpu_count(logical=False)
with ThreadPoolExecutor(max_workers=core_count) as executor:
list(executor.map(self._from_regex, [pattern_name] * core_count))

def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name):
def time_regex_to_guide_threads_with_custom_switch_interval(self, pattern_name):
# Note: after moving to full rust implementation for index and guide creation, this experiment
# is no longer shows the drastic difference as it once showed when python was heavily involved,
# due to speedup up to ~100 times.
# due to average speedup ~10 times.

# This test is to show, that if GIL's switch interval is set to be longer, then the parallel
# test's runtime on physical cores will be much closer to the one-threaded case.
Expand Down
81 changes: 71 additions & 10 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,61 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
/// The ID of the initial state in the automaton, processing begins from this state.
initial_state: StateId,
/// A collection of states considered as terminal states.
final_states: HashSet<StateId>,
/// A mapping of state transitions, defined by tokens ids and their corresponding state changes:
/// - The outer map's keys are the state IDs.
/// - The inner map's keys are token IDs.
/// - The inner map's values are state IDs, indicating transitions to the next state.
transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
/// The token ID reserved for the "end-of-sequence" token.
eos_token_id: TokenId,
}

/// The `Index` structure is designed to efficiently map tokens from a given vocabulary
/// to state transitions within a finite-state automaton.
///
/// ## Usage:
/// The `Index` is typically constructed by combining a vocabulary and regular expressions.
/// Once built, it can be used to efficiently evaluate token sequences or to validate input data.
///
/// ## Example:
/// ```rust
/// use outlines_core::prelude::*;
///
/// # fn run() -> Result<(), outlines_core::Error> {
/// let regex = "0|[1-9][0-9]*";
/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?;
/// let index = Index::new(regex, &vocabulary)?;
///
/// let initial_state = index.initial_state();
/// println!("Initial state is {}", initial_state);
/// println!("Is initial state a final state? {}", index.is_final_state(&initial_state));
///
/// let allowed_tokens = index.allowed_tokens(&initial_state).unwrap();
/// println!("Allowed tokens at initial state are {:?}", allowed_tokens);
///
/// let token_id = allowed_tokens.first().unwrap();
/// println!("Next state for the token_id {} is {:?}", token_id, index.next_state(&initial_state, token_id));
///
/// println!("Final states are {:?}", index.final_states());
/// println!("Index has exactly {} transitions", index.transitions().len());
/// # Ok(())
/// # }
///
/// ```
///
/// ## Performance:
/// - **Complexity**:
/// The `Index` can accommodate large vocabularies and complex regular expressions.
/// However, its size may grow significantly with the complexity of the input.
/// - **Construction Cost**:
/// Building the `Index` involves processing the vocabulary and regular expressions,
/// which may require a considerable amount of time and computational resources.
impl Index {
/// Builds an `Index` from regular expression and vocabulary tokens.
pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
let eos_token_id = vocabulary.eos_token_id();
let dfa = DFA::new(regex).map_err(Box::new)?;
Expand Down Expand Up @@ -93,31 +141,37 @@ impl Index {
}
}

pub fn allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
/// Lists allowed tokens for a give state ID or `None` if it is not found in `Index`.
pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
self.transitions
.get(&state)
.get(state)
.map_or_else(|| None, |res| Some(res.keys().cloned().collect()))
}

pub fn next_state(&self, state: StateId, token_id: TokenId) -> Option<StateId> {
if token_id == self.eos_token_id {
/// Returns transition state for a given state and token id or `None` otherwise.
pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
if token_id == &self.eos_token_id {
return None;
}
Some(*self.transitions.get(&state)?.get(&token_id)?)
Some(*self.transitions.get(state)?.get(token_id)?)
}

/// Returns the ID of the initial state in the automaton.
pub fn initial_state(&self) -> StateId {
self.initial_state
}

pub fn is_final(&self, state: StateId) -> bool {
self.final_states.contains(&state)
/// Checks if state is in final states set or not.
pub fn is_final_state(&self, state: &StateId) -> bool {
self.final_states.contains(state)
}

/// Returns set of final states.
pub fn final_states(&self) -> &HashSet<StateId> {
&self.final_states
}

/// Returns state transitions map of tokens ids and their corresponding transition states.
pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
&self.transitions
}
Expand Down Expand Up @@ -146,10 +200,11 @@ mod tests {
.try_insert(token, token_id as u32)
.expect("Insert failed");
}

let index = Index::new(regex, &vocabulary).expect("Index failed");
assert_eq!(index.initial_state(), 40);
let initial_state = index.initial_state();
assert_eq!(initial_state, 40);
assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
assert!(!index.is_final_state(&initial_state));

let expected = HashMap::from_iter([
(24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
Expand All @@ -158,6 +213,12 @@ mod tests {
(56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
]);
assert_eq!(index.transitions(), &expected);

let allowed_tokens = index
.allowed_tokens(&initial_state)
.expect("No allowed tokens");
let token_id = allowed_tokens.first().expect("No first tokens");
assert_eq!(index.next_state(&initial_state, token_id), Some(48));
}

#[test]
Expand All @@ -172,7 +233,7 @@ mod tests {

let index = Index::new(regex, &vocabulary).expect("Index failed");
let allowed = index
.allowed_tokens(index.initial_state())
.allowed_tokens(&index.initial_state())
.expect("No allowed tokens");
assert!(allowed.contains(&101));
}
Expand Down
2 changes: 2 additions & 0 deletions src/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub use tokenizers::FromPretrainedParameters;

pub use super::{
index::Index,
primitives::{StateId, Token, TokenId},
Expand Down
6 changes: 3 additions & 3 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ impl PyIndex {
}

fn get_allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
self.0.allowed_tokens(state)
self.0.allowed_tokens(&state)
}

fn get_next_state(&self, state: StateId, token_id: TokenId) -> Option<StateId> {
self.0.next_state(state, token_id)
self.0.next_state(&state, &token_id)
}

fn is_final_state(&self, state: StateId) -> bool {
self.0.is_final(state)
self.0.is_final_state(&state)
}

fn get_final_states(&self) -> HashSet<StateId> {
Expand Down
35 changes: 27 additions & 8 deletions src/vocabulary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use bincode::{Decode, Encode};
use rustc_hash::FxHashMap as HashMap;

use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
use tokenizers::{NormalizerWrapper, Tokenizer};

use crate::prelude::*;
use crate::{Error, Result};
Expand All @@ -19,18 +19,37 @@ mod processor;
///
/// ### Create a vocabulary from a pretrained model.
/// ```rust
/// # use outlines_core::prelude::*;
/// #
/// use outlines_core::prelude::*;
///
/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None);
/// ```
///
/// ### Create an empty vocabulary and manually insert tokens.
/// ### Create a vocabulary from a pretrained model with some additional parameters.
/// ``` rust
/// use outlines_core::prelude::*;
///
/// let params = FromPretrainedParameters {
/// revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e".to_string(),
/// ..Default::default()
/// };
/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", Some(params));
///
/// ```
///
/// ### Create an empty vocabulary and manually insert some tokens.
/// ```rust
/// # use outlines_core::prelude::*;
/// #
/// use outlines_core::prelude::*;
///
/// let eos_token_id = 1;
/// let mut vocabulary = Vocabulary::new(eos_token_id);
///
/// vocabulary.try_insert("token", 0).expect("New token inserted");
/// assert_eq!(vocabulary.token_to_ids("token"), Some(&vec![0]));
/// assert_eq!(vocabulary.tokens_to_ids().len(), 1);
/// assert_eq!(vocabulary.eos_token_id(), eos_token_id);
///
/// vocabulary.remove("token");
/// assert_eq!(vocabulary.token_to_ids("token"), None);
/// ```
#[derive(Clone, Debug, Default, PartialEq, Encode, Decode)]
pub struct Vocabulary {
Expand All @@ -57,7 +76,7 @@ impl Vocabulary {
Ok(())
}

/// Removes a token from the vocabulary.
/// Removes given token from the vocabulary.
pub fn remove(&mut self, token: impl Into<Token>) {
let token = token.into();
self.tokens.remove(&token);
Expand Down Expand Up @@ -119,7 +138,7 @@ impl Vocabulary {
&self.tokens
}

/// Per provided token returns vector of `TokenId`s if available in the vocabulary.
/// Returns all token ids per provided token if available in the vocabulary.
pub fn token_to_ids(&self, token: impl AsRef<[u8]>) -> Option<&Vec<TokenId>> {
self.tokens.get(token.as_ref())
}
Expand Down
2 changes: 2 additions & 0 deletions tests/fsm/test_statistical.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Callable, List, Optional

import numpy as np
import pytest
from outlines_core.fsm import Guide, Index, Vocabulary
from pytest import approx
from scipy.stats import ks_2samp


@pytest.mark.skip("Needs fixing")
def test_generate_length():
class NextToken:
def __init__(
Expand Down

0 comments on commit 73e4bfe

Please sign in to comment.