feat: Add new agents docs embeddings functionality
1 parent 9c8d1a6, commit d3d60c8
Showing 5 changed files with 435 additions and 2 deletions.
@@ -0,0 +1,124 @@
import tiktoken
import numpy as np
from typing import TypedDict, Any
from dataclasses import dataclass
from transformers import AutoTokenizer
from agents_api.clients.model import openai_client
from agents_api.clients.embed import embed


def normalize_l2(x):
    """L2-normalize a single vector (1-D) or a batch of vectors (2-D);
    zero vectors pass through unchanged."""
    x = np.array(x)
    if x.ndim == 1:
        norm = np.linalg.norm(x)
        if norm == 0:
            return x
        return x / norm
    norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
    return np.where(norm == 0, x, x / norm)


class ModelNotSupportedError(Exception):
    def __init__(self, model_name):
        super().__init__(f"model {model_name} is not supported")


class EmbeddingInput(TypedDict):
    instruction: str | None
    text: str


@dataclass
class EmbeddingModel:
    embedding_provider: str
    embedding_model_name: str
    original_embedding_dimensions: int
    output_embedding_dimensions: int
    context_window: int
    tokenizer: Any

    @classmethod
    def from_model_name(cls, model_name: str) -> "EmbeddingModel":
        try:
            return _embedding_model_registry[model_name]
        except KeyError:
            raise ModelNotSupportedError(model_name)

    def preprocess(self, inputs: list[EmbeddingInput]):
        """Maybe use this function from embed() to truncate (if needed) or raise an error"""
        pass

    async def embed(
        self, inputs: list[EmbeddingInput]
    ) -> list[np.ndarray | list[float]]:
        embeddings: list[np.ndarray | list[float]] = []
        texts = [f"{inp.get('instruction') or ''} {inp['text']}" for inp in inputs]

        if self.embedding_provider == "julep":
            embeddings = await embed(texts)
        elif self.embedding_provider == "openai":
            # Await the response before touching .data; chaining attribute
            # access on the coroutine itself would fail at runtime.
            response = await openai_client.embeddings.create(
                input=texts, model=self.embedding_model_name
            )
            embeddings = [item.embedding for item in response.data]

        return self.normalize(embeddings)

    def normalize(
        self, embeddings: list[np.ndarray | list[float]]
    ) -> list[np.ndarray | list[float]]:
        # Truncate to the output dimensionality and re-normalize, so models
        # with larger native dimensions still fit the 1024-dim vector column.
        return [
            (
                e
                if len(e) <= self.output_embedding_dimensions
                else normalize_l2(e[: self.output_embedding_dimensions])
            )
            for e in embeddings
        ]


_embedding_model_registry = {
    # OpenAI tokenizers are not on the HF hub; tiktoken provides their
    # encodings. text-embedding-3-small natively returns 1536-dim vectors,
    # truncated to 1024 to fit the vector column.
    "text-embedding-3-small": EmbeddingModel(
        embedding_provider="openai",
        embedding_model_name="text-embedding-3-small",
        original_embedding_dimensions=1536,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=tiktoken.encoding_for_model("text-embedding-3-small"),
    ),
    # text-embedding-3-large natively returns 3072-dim vectors, truncated to 1024.
    "text-embedding-3-large": EmbeddingModel(
        embedding_provider="openai",
        embedding_model_name="text-embedding-3-large",
        original_embedding_dimensions=3072,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=tiktoken.encoding_for_model("text-embedding-3-large"),
    ),
    "Alibaba-NLP/gte-large-en-v1.5": EmbeddingModel(
        embedding_provider="julep",
        embedding_model_name="Alibaba-NLP/gte-large-en-v1.5",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=AutoTokenizer.from_pretrained("Alibaba-NLP/gte-large-en-v1.5"),
    ),
    "BAAI/bge-m3": EmbeddingModel(
        embedding_provider="julep",
        embedding_model_name="BAAI/bge-m3",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=AutoTokenizer.from_pretrained("BAAI/bge-m3"),
    ),
    "BAAI/llm-embedder": EmbeddingModel(
        embedding_provider="julep",
        embedding_model_name="BAAI/llm-embedder",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=AutoTokenizer.from_pretrained("BAAI/llm-embedder"),
    ),
}
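
For orientation, a minimal usage sketch of the module above, assuming it is importable and the julep embed() client is reachable; the model choice, the sample inputs, and the asyncio entry point are illustrative, not part of this commit:

import asyncio

async def main():
    # Hypothetical example: embed two snippets with a registered model.
    model = EmbeddingModel.from_model_name("BAAI/bge-m3")
    inputs: list[EmbeddingInput] = [
        {"instruction": "Encode this passage for retrieval:", "text": "Julep is a platform for building agents."},
        {"instruction": None, "text": "Document snippets are embedded for retrieval."},
    ]
    vectors = await model.embed(inputs)
    # Every vector fits the 1024-dim column created by the migration below.
    assert all(len(v) <= model.output_embedding_dimensions for v in vectors)

asyncio.run(main())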
agents-api/migrations/migrate_1714566760_change_embeddings_dimensions.py
111 additions & 0 deletions
@@ -0,0 +1,111 @@
#!/usr/bin/env python3

MIGRATION_ID = "change_embeddings_dimensions"
CREATED_AT = 1714566760.731964


# Copy every snippet row, then replace the schema with a 1024-dim embedding
# column ("up"); "down" restores the previous 768-dim column.
change_dimensions = {
    "up": """
    ?[
        doc_id,
        snippet_idx,
        title,
        snippet,
        embed_instruction,
        embedding,
    ] :=
        *information_snippets{
            snippet_idx,
            title,
            snippet,
            embed_instruction,
            embedding,
            additional_info_id: doc_id,
        }
    :replace information_snippets {
        doc_id: Uuid,
        snippet_idx: Int,
        =>
        title: String,
        snippet: String,
        embed_instruction: String default 'Encode this passage for retrieval: ',
        embedding: <F32; 1024>? default null,
    }
    """,
    "down": """
    ?[
        doc_id,
        snippet_idx,
        title,
        snippet,
        embed_instruction,
        embedding,
    ] :=
        *information_snippets{
            snippet_idx,
            title,
            snippet,
            embed_instruction,
            embedding,
            additional_info_id: doc_id,
        }
    :replace information_snippets {
        doc_id: Uuid,
        snippet_idx: Int,
        =>
        title: String,
        snippet: String,
        embed_instruction: String default 'Encode this passage for retrieval: ',
        embedding: <F32; 768>? default null,
    }
    """,
}

# The HNSW index is dropped and re-created so its dimensionality matches the
# new embedding column.
information_snippets_hnsw_index = dict(
    up="""
    ::hnsw drop information_snippets:embedding_space
    ::hnsw create information_snippets:embedding_space {
        fields: [embedding],
        filter: !is_null(embedding),
        dim: 1024,
        distance: Cosine,
        m: 64,
        ef_construction: 256,
        extend_candidates: false,
        keep_pruned_connections: false,
    }
    """,
    down="""
    ::hnsw drop information_snippets:embedding_space
    ::hnsw create information_snippets:embedding_space {
        fields: [embedding],
        filter: !is_null(embedding),
        dim: 768,
        distance: Cosine,
        m: 64,
        ef_construction: 256,
        extend_candidates: false,
        keep_pruned_connections: false,
    }
    """,
)


queries_to_run = [
    change_dimensions,
    information_snippets_hnsw_index,
]


def up(client):
    for q in queries_to_run:
        client.run(q["up"])


def down(client):
    for q in reversed(queries_to_run):
        client.run(q["down"])