Update overlap refinery to default to Tokenizers
bhavnicksm committed Feb 4, 2025
1 parent bfd2a47 commit 81573fd
Showing 2 changed files with 63 additions and 9 deletions.
4 changes: 1 addition & 3 deletions src/chonkie/chunker/base.py
@@ -143,9 +143,7 @@ def _encode(self, text: str) -> List[int]:
     def _encode_batch(self, texts: List[str]) -> List[List[int]]:
         """Encode a batch of texts using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
-            return self.tokenizer.batch_encode_plus(texts, add_special_tokens=False)[
-                "input_ids"
-            ]
+            return self.tokenizer.batch_encode_plus(texts, add_special_tokens=False)["input_ids"]
         elif self._tokenizer_backend == "tokenizers":
             return [
                 t.ids
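
For reference, a minimal sketch (not part of this diff) of what the "transformers" branch of _encode_batch relies on, assuming a Hugging Face AutoTokenizer: batch_encode_plus returns a mapping whose "input_ids" entry holds one token-id list per input text.

# Sketch under the assumption of a Hugging Face "transformers" tokenizer;
# the model name is illustrative only.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tok.batch_encode_plus(["first chunk", "second chunk"], add_special_tokens=False)
print(batch["input_ids"])  # a list of token-id lists, one per input text, no special tokens added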
68 changes: 62 additions & 6 deletions src/chonkie/refinery/overlap.py
@@ -46,13 +46,69 @@ def __init__(
         # If tokenizer provided, we can do exact token counting
         if tokenizer is not None:
             self.tokenizer = tokenizer
+            self._tokenizer_backend = self._get_tokenizer_backend()
             self.approximate = approximate
         else:
             # Without tokenizer, must use approximate method
             self.approximate = True
 
         # Average number of characters per token
         self._AVG_CHAR_PER_TOKEN = 7
 
+    def _get_tokenizer_backend(self) -> str:
+        """Get the tokenizer backend."""
+        if "tokenizers" in str(type(self.tokenizer)):
+            return "tokenizers"
+        elif "tiktoken" in str(type(self.tokenizer)):
+            return "tiktoken"
+        elif "transformers" in str(type(self.tokenizer)):
+            return "transformers"
+        else:
+            raise ValueError(f"Unsupported tokenizer backend: {str(type(self.tokenizer))}")
+
+    def _encode(self, text: str) -> List[int]:
+        """Encode text using the tokenizer backend."""
+        if self._tokenizer_backend == "tokenizers":
+            return self.tokenizer.encode(text).ids
+        elif self._tokenizer_backend == "tiktoken":
+            return self.tokenizer.encode(text)
+        elif self._tokenizer_backend == "transformers":
+            return self.tokenizer.encode(text, add_special_tokens=False)
+        else:
+            raise ValueError(f"Unsupported tokenizer backend: {self._tokenizer_backend}")
+
+    def _decode(self, tokens: List[int]) -> str:
+        """Decode tokens using the tokenizer backend."""
+        if self._tokenizer_backend == "tokenizers":
+            return self.tokenizer.decode(tokens)
+        elif self._tokenizer_backend == "tiktoken":
+            return self.tokenizer.decode(tokens)
+        elif self._tokenizer_backend == "transformers":
+            return self.tokenizer.decode(tokens, skip_special_tokens=True)
+        else:
+            raise ValueError(f"Unsupported tokenizer backend: {self._tokenizer_backend}")
+
+    def _batch_encode(self, texts: List[str]) -> List[List[int]]:
+        """Batch encode texts using the tokenizer backend."""
+        if self._tokenizer_backend == "tokenizers":
+            return [t.ids for t in self.tokenizer.encode_batch(texts)]
+        elif self._tokenizer_backend == "tiktoken":
+            return self.tokenizer.encode_batch(texts)
+        elif self._tokenizer_backend == "transformers":
+            return self.tokenizer.batch_encode_plus(texts, add_special_tokens=False)["input_ids"]
+        else:
+            raise ValueError(f"Unsupported tokenizer backend: {self._tokenizer_backend}")
+
+    def _batch_decode(self, tokens: List[List[int]]) -> List[str]:
+        """Batch decode tokens using the tokenizer backend."""
+        if self._tokenizer_backend == "tokenizers":
+            return self.tokenizer.decode_batch(tokens)
+        elif self._tokenizer_backend == "tiktoken":
+            return self.tokenizer.decode_batch(tokens)
+        elif self._tokenizer_backend == "transformers":
+            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
+        else:
+            raise ValueError(f"Unsupported tokenizer backend: {self._tokenizer_backend}")
+
     def _get_refined_chunks(
         self, chunks: List[Chunk], inplace: bool = True
@@ -141,10 +197,10 @@ def _prefix_overlap_token_exact(self, chunk: Chunk) -> Optional[Context]:
         text_portion = chunk.text[-char_window:]
 
         # Get exact token boundaries
-        tokens = self.tokenizer.encode(text_portion) #TODO: should be self._encode; need a unified tokenizer interface
+        tokens = self._encode(text_portion) #TODO: should be self._encode; need a unified tokenizer interface
         context_tokens = min(self.context_size, len(tokens))
         context_tokens_ids = tokens[-context_tokens:]
-        context_text = self.tokenizer.decode(context_tokens_ids) #TODO: should be self._decode; need a unified tokenizer interface
+        context_text = self._decode(context_tokens_ids) #TODO: should be self._decode; need a unified tokenizer interface
 
         # Find where context text starts in chunk
         try:
@@ -175,10 +231,10 @@ def _suffix_overlap_token_exact(self, chunk: Chunk) -> Optional[Context]:
         text_portion = chunk.text[:char_window]
 
         # Get exact token boundaries
-        tokens = self.tokenizer.encode(text_portion)
+        tokens = self._encode(text_portion)
         context_tokens = min(self.context_size, len(tokens))
         context_tokens_ids = tokens[:context_tokens]
-        context_text = self.tokenizer.decode(context_tokens_ids)
+        context_text = self._decode(context_tokens_ids)
 
         # Find where context text starts in chunk
         try:
@@ -403,7 +459,7 @@ def _refine_prefix(self, chunks: List[Chunk]) -> List[Chunk]:
                 if hasattr(self, "tokenizer") and not self.approximate:
                     # Use exact token count if we have a tokenizer
                     refined_chunks[i].token_count = len(
-                        self.tokenizer.encode(refined_chunks[i].text)
+                        self._encode(refined_chunks[i].text)
                     )
                 else:
                     # Otherwise use approximate by adding context tokens plus one for space
@@ -453,7 +509,7 @@ def _refine_suffix(self, chunks: List[Chunk]) -> List[Chunk]:
                 if hasattr(self, "tokenizer") and not self.approximate:
                     # Use exact token count if we have a tokenizer
                     refined_chunks[i].token_count = len(
-                        self.tokenizer.encode(refined_chunks[i].text)
+                        self._encode(refined_chunks[i].text)
                     )
                 else:
                     # Otherwise use approximate by adding context tokens
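
For reference, a minimal sketch (not part of this diff) of how the new backend detection and the "tokenizers" branches behave, assuming the Hugging Face tokenizers package is installed: the refinery only checks whether the backend name appears in the tokenizer's type string, then wraps encode(...).ids and decode(...).

# Sketch assuming a Hugging Face `tokenizers` Tokenizer; the model name is illustrative only.
from tokenizers import Tokenizer

tok = Tokenizer.from_pretrained("gpt2")
print(str(type(tok)))                          # "<class 'tokenizers.Tokenizer'>" -> backend resolved as "tokenizers"
ids = tok.encode("overlap context text").ids   # what _encode() returns for this backend
print(tok.decode(ids))                         # what _decode() returns for this backend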
