From d15b59725477e1b8b59275a5d5ae5d5b8caa9637 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:41:45 -0700 Subject: [PATCH] Update search method to return EntityEdge objects --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/getzep/graphiti?shareId=XXXX-XXXX-XXXX-XXXX). --- graphiti_core/graphiti.py | 6 ++---- tests/test_graphiti_int.py | 12 ++++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 6ff5b52b..1296cc12 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -534,7 +534,7 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu Returns ------- list - A list of facts (strings) that are relevant to the search query. + A list of EntityEdge objects that are relevant to the search query. Notes ----- @@ -564,9 +564,7 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu ) ).edges - facts = [edge.fact for edge in edges] - - return facts + return edges async def _search( self, diff --git a/tests/test_graphiti_int.py b/tests/test_graphiti_int.py index 2ab6c8f0..61e724be 100644 --- a/tests/test_graphiti_int.py +++ b/tests/test_graphiti_int.py @@ -76,17 +76,17 @@ async def test_graphiti_init(): logger = setup_logging() graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) - facts = await graphiti.search('Freakenomics guest') + edges = await graphiti.search('Freakenomics guest') - logger.info('\nQUERY: Freakenomics guest\n' + format_context(facts)) + logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges])) - facts = await graphiti.search('tania tetlow\n') + edges = await graphiti.search('tania tetlow\n') - logger.info('\nQUERY: Tania Tetlow\n' + format_context(facts)) + logger.info('\nQUERY: Tania Tetlow\n' + format_context([edge.fact for edge in edges])) - facts = await graphiti.search('issues with higher ed') + edges = await graphiti.search('issues with higher ed') - logger.info('\nQUERY: issues with higher ed\n' + format_context(facts)) + logger.info('\nQUERY: issues with higher ed\n' + format_context([edge.fact for edge in edges])) graphiti.close()