Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mmr reranking #180

Merged
merged 10 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add mmr options to search
  • Loading branch information
prasmussen15 committed Oct 7, 2024
commit fe35c8bfb530b6a634c0b579941da32b6847e611
4 changes: 2 additions & 2 deletions graphiti_core/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from openai import AsyncOpenAI
from openai.types import EmbeddingModel

from .client import EmbedderClient, EmbedderConfig
from ..helpers import normalize_l2
from .client import EmbedderClient, EmbedderConfig

DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'

Expand All @@ -43,7 +43,7 @@ def __init__(self, config: OpenAIEmbedderConfig | None = None):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
return normalize_l2(result.data[0].embedding[: self.config.embedding_dim])
4 changes: 2 additions & 2 deletions graphiti_core/embedder/voyage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import voyageai # type: ignore
from pydantic import Field

from .client import EmbedderClient, EmbedderConfig
from ..helpers import normalize_l2
from .client import EmbedderClient, EmbedderConfig

DEFAULT_EMBEDDING_MODEL = 'voyage-3'

Expand All @@ -42,7 +42,7 @@ def __init__(self, config: VoyageAIEmbedderConfig | None = None):
self.client = voyageai.AsyncClient(api_key=config.api_key)

async def create(
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embed(input, model=self.config.embedding_model)
return normalize_l2(result.embeddings[0][: self.config.embedding_dim])
43 changes: 40 additions & 3 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
edge_fulltext_search,
edge_similarity_search,
episode_mentions_reranker,
maximal_marginal_relevance,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
Expand Down Expand Up @@ -117,12 +118,14 @@ async def edge_search(
if config is None:
return []

query_vector = await embedder.create(input=[query])

search_results: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
edge_fulltext_search(driver, query, None, None, group_ids, 2 * limit),
edge_similarity_search(
driver, await embedder.create(input=[query]), None, None, group_ids, 2 * limit
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
),
]
)
Expand All @@ -135,6 +138,15 @@ async def edge_search(
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

reranked_uuids = rrf(search_result_uuids)
elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = [
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
for result in search_results
for edge in result
]
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
Expand Down Expand Up @@ -175,12 +187,14 @@ async def node_search(
if config is None:
return []

query_vector = await embedder.create(input=[query])

search_results: list[list[EntityNode]] = list(
await asyncio.gather(
*[
node_fulltext_search(driver, query, group_ids, 2 * limit),
node_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
),
]
)
Expand All @@ -192,6 +206,15 @@ async def node_search(
reranked_uuids: list[str] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = [
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
for result in search_results
for node in result
]
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids)
elif config.reranker == NodeReranker.node_distance:
Expand All @@ -217,12 +240,14 @@ async def community_search(
if config is None:
return []

query_vector = await embedder.create(input=[query])

search_results: list[list[CommunityNode]] = list(
await asyncio.gather(
*[
community_fulltext_search(driver, query, group_ids, 2 * limit),
community_similarity_search(
driver, await embedder.create(input=[query]), group_ids, 2 * limit
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
),
]
)
Expand All @@ -236,6 +261,18 @@ async def community_search(
reranked_uuids: list[str] = []
if config.reranker == CommunityReranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = [
(
community.uuid,
community.name_embedding if community.name_embedding is not None else [0.0] * 1024,
)
for result in search_results
for community in result
]
reranked_uuids = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda
)

reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

Expand Down
10 changes: 10 additions & 0 deletions graphiti_core/search/search_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_utils import DEFAULT_MIN_SCORE, DEFAULT_MMR_LAMBDA

DEFAULT_SEARCH_LIMIT = 10

Expand All @@ -43,31 +44,40 @@ class EdgeReranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
mmr = 'mmr'


class NodeReranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'
episode_mentions = 'episode_mentions'
mmr = 'mmr'


class CommunityReranker(Enum):
rrf = 'reciprocal_rank_fusion'
mmr = 'mmr'


class EdgeSearchConfig(BaseModel):
search_methods: list[EdgeSearchMethod]
reranker: EdgeReranker = Field(default=EdgeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)


class NodeSearchConfig(BaseModel):
search_methods: list[NodeSearchMethod]
reranker: NodeReranker = Field(default=NodeReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)


class CommunitySearchConfig(BaseModel):
search_methods: list[CommunitySearchMethod]
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)


class SearchConfig(BaseModel):
Expand Down
40 changes: 40 additions & 0 deletions graphiti_core/search/search_config_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@
),
)

# Performs a hybrid search with mmr reranking over edges, nodes, and communities
COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
),
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
),
)

# performs a hybrid search over edges with rrf reranking
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
edge_config=EdgeSearchConfig(
Expand All @@ -51,6 +67,14 @@
)
)

# performs a hybrid search over edges with mmr reranking
EDGE_HYBRID_SEARCH_mmr = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
)
)

# performs a hybrid search over edges with node distance reranking
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
edge_config=EdgeSearchConfig(
Expand All @@ -75,6 +99,14 @@
)
)

# performs a hybrid search over nodes with mmr reranking
NODE_HYBRID_SEARCH_MMR = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
)
)

# performs a hybrid search over nodes with node distance reranking
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
node_config=NodeSearchConfig(
Expand All @@ -98,3 +130,11 @@
reranker=CommunityReranker.rrf,
)
)

# performs a hybrid search over communities with mmr reranking
COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
)
)
Loading