Skip to content

Commit

Permalink
feat: Add new agents docs embbeddings functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 1, 2024
1 parent 9c8d1a6 commit d3d60c8
Show file tree
Hide file tree
Showing 5 changed files with 435 additions and 2 deletions.
124 changes: 124 additions & 0 deletions agents-api/agents_api/embed_models_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
from typing import TypedDict, Any
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from agents_api.clients.model import openai_client
from agents_api.clients.embed import embed


def normalize_l2(x):
x = np.array(x)
if x.ndim == 1:
norm = np.linalg.norm(x)
if norm == 0:
return x
return x / norm
else:
norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
return np.where(norm == 0, x, x / norm)


class ModelNotSupportedError(Exception):
def __init__(self, model_name):
super().__init__(f"model {model_name} is not supporrted")


class EmbeddingInput(TypedDict):
instruction: str | None
text: str


@dataclass
class EmbeddingModel:
embedding_provider: str
embedding_model_name: str
original_embedding_dimensions: int
output_embedding_dimensions: int
context_window: int
tokenizer: Any

@classmethod
def from_model_name(cls, model_name: str):
try:
return _embedding_model_registry[model_name]
except KeyError:
raise ModelNotSupportedError(model_name)

def preprocess(self, inputs: list[EmbeddingInput]):
"""Maybe use this function from embed() to truncate (if needed) or raise an error"""
pass

async def embed(
self, inputs: list[EmbeddingInput]
) -> list[np.NDArray | list[float]]:
embeddings: list[np.NDArray | list[float]] = []
input = [f"{input.get('instruction', '')} {input['text']}" for input in inputs]

if self.embedding_provider == "julep":
embeddings = await embed(input)
elif self.embedding_provider == "openai":
embeddings = (
await openai_client.embeddings.create(
input=input, model=self.embedding_model_name
)
.data[0]
.embedding
)

return self.normalize(embeddings)

def normalize(
self, embeddings: list[np.NDArray | list[float]]
) -> list[np.NDArray | list[float]]:
return [
(
e
if len(e) <= self.original_embedding_dimensions
else normalize_l2(e[: self.original_embedding_dimensions])
)
for e in embeddings
]


_embedding_model_registry = {
"text-embeddins-3-small": EmbeddingModel(
embedding_provider="openai",
embedding_model_name="text-embeddins-3-small",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("text-embeddins-3-small"),
),
"text-embeddins-3-large": EmbeddingModel(
embedding_provider="openai",
embedding_model_name="text-embeddins-3-large",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("text-embeddins-3-large"),
),
"Alibaba-NLP/gte-large-en-v1.5": EmbeddingModel(
embedding_provider="julep",
embedding_model_name="Alibaba-NLP/gte-large-en-v1.5",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("Alibaba-NLP/gte-large-en-v1.5"),
),
"BAAI/bge-m3": EmbeddingModel(
embedding_provider="julep",
embedding_model_name="BAAI/bge-m3",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("BAAI/bge-m3"),
),
"BAAI/llm-embedder": EmbeddingModel(
embedding_provider="julep",
embedding_model_name="BAAI/llm-embedder",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("BAAI/llm-embedder"),
),
}
10 changes: 8 additions & 2 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
PatchToolRequest,
PatchAgentRequest,
)
from agents_api.env import embedding_model_id
from agents_api.embed_models_registry import EmbeddingModel


class AgentList(BaseModel):
Expand Down Expand Up @@ -319,9 +321,13 @@ async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedRes
)

indices, snippets = list(zip(*enumerate(content)))
embeddings = await embed(
model = EmbeddingModel.from_model_name(embedding_model_id)
embeddings = await model.embed(
[
snippet_embed_instruction + request.title + "\n\n" + snippet
{
"instruction": snippet_embed_instruction,
"text": request.title + "\n\n" + snippet,
}
for snippet in snippets
]
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# /usr/bin/env python3

MIGRATION_ID = "change_embeddings_dimensions"
CREATED_AT = 1714566760.731964


change_dimensions = {
"up": """
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
additional_info_id: doc_id,
}
:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; 1024>? default null,
}
""",
"down": """
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
additional_info_id: doc_id,
}
:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; 768>? default null,
}
""",
}

information_snippets_hnsw_index = dict(
up="""
::hnsw drop information_snippets:embedding_space
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: 1024,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: false,
keep_pruned_connections: false,
}
""",
down="""
::hnsw drop information_snippets:embedding_space
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: 768,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: false,
keep_pruned_connections: false,
}
""",
)


queries_to_run = [
change_dimensions,
information_snippets_hnsw_index,
]


def up(client):
for q in queries_to_run:
client.run(q["up"])


def down(client):
for q in reversed(queries_to_run):
client.run(q["down"])
Loading

0 comments on commit d3d60c8

Please sign in to comment.