# streamlit_app.py
# Forked from neo4j-field/ps-genai-agents.
import asyncio
import json
import os
import sys
from typing import Any, Dict, List, Optional

import streamlit as st
from dotenv import load_dotenv
from langchain_neo4j import Neo4jGraph
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from neo4j import GraphDatabase

from ps_genai_agents.retrievers.cypher_examples import (
    Neo4jVectorSearchCypherExampleRetriever,
)
from ps_genai_agents.ui.components import chat, display_chat_history, sidebar
from ps_genai_agents.workflows.multi_agent import (
    create_text2cypher_with_visualization_workflow,
)
# Call load_dotenv() once at import time and print which way its result went:
# a truthy result prints the success message, a falsy one the failure message.
print(
    "Env Loaded Successfully!" if load_dotenv() else "Unable to Load Environment."
)
def get_args() -> Dict[str, Any]:
    """
    Parse the command line arguments to configure the application.

    The first positional argument (``sys.argv[1]``), when present, is treated
    as the path to a JSON configuration file, which is read and returned.
    When no argument is given, an empty configuration is returned.

    Returns
    -------
    Dict[str, Any]
        The parsed configuration, or an empty dictionary if no path was given.

    Raises
    ------
    ValueError
        If the path does not end in ".json" (case-insensitive), or if the
        top-level JSON value in the file is not an object.
    """
    cli_args = sys.argv
    if len(cli_args) <= 1:
        return dict()

    config_path: str = cli_args[1]

    # Raise explicitly instead of using `assert`: assertions are removed when
    # Python runs with -O, which would silently disable this validation.
    if not config_path.lower().endswith(".json"):
        raise ValueError(f"provided file is not JSON | {config_path}")

    # Explicit encoding so the result does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # The caller looks values up with `.get`, so anything other than a JSON
    # object is rejected here rather than failing later with AttributeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"configuration file must contain a JSON object | {config_path}"
        )

    return config
def initialize_state(
    cypher_query_yaml_file_path: str,
    scope_description: str,
    example_questions: Optional[List[str]] = None,
) -> None:
    """
    Initialize the application state.

    Does nothing when ``st.session_state`` already contains an "agent" entry.
    Otherwise it builds the embedder, the Neo4j driver, the Cypher example
    retriever, the LLM, the graph and the workflow, and stores the following
    keys in ``st.session_state``: "llm", "graph", "agent", "messages" and
    "example_questions".

    Neo4j connection settings are read from the ``NEO4J_URI``,
    ``NEO4J_USERNAME`` and ``NEO4J_PASSWORD`` environment variables.

    Parameters
    ----------
    cypher_query_yaml_file_path : str
        Not used by this function; kept in the signature so existing callers
        continue to work.
    scope_description : str
        Passed to the workflow factory as ``scope_description``.
    example_questions : Optional[List[str]], optional
        Stored in the session state under "example_questions". When omitted
        (None), a new empty list is stored. By default None.
    """
    # Everything below is set up once per session; later calls are no-ops.
    if "agent" in st.session_state:
        return

    # A None sentinel replaces the former `list()` default: that default was a
    # single list object shared by every call and then stored in session
    # state, so a mutation in one place would have been visible in all others.
    if example_questions is None:
        example_questions = []

    embedder = OpenAIEmbeddings(model="text-embedding-ada-002")
    neo4j_driver = GraphDatabase.driver(
        uri=os.getenv("NEO4J_URI", ""),
        auth=(os.getenv("NEO4J_USERNAME", ""), os.getenv("NEO4J_PASSWORD", "")),
    )

    vector_index_name = "cypher_query_vector_index"
    cypher_example_retriever = Neo4jVectorSearchCypherExampleRetriever(
        embedder=embedder,
        neo4j_driver=neo4j_driver,
        vector_index_name=vector_index_name,
    )

    st.session_state["llm"] = ChatOpenAI(model="gpt-4o", temperature=0.0)
    st.session_state["graph"] = Neo4jGraph(
        url=os.environ.get("NEO4J_URI"),
        username=os.environ.get("NEO4J_USERNAME"),
        password=os.environ.get("NEO4J_PASSWORD"),
        enhanced_schema=True,
        # NOTE(review): presumably a timeout of 0 makes the driver verify a
        # pooled connection every time it is reused - confirm in driver docs.
        driver_config={"liveness_check_timeout": 0},
    )

    st.session_state["agent"] = create_text2cypher_with_visualization_workflow(
        llm=st.session_state["llm"],
        graph=st.session_state["graph"],
        cypher_example_retriever=cypher_example_retriever,
        scope_description=scope_description,
        max_cypher_generation_attempts=3,
        attempt_cypher_execution_on_final_attempt=True,
        llm_cypher_validation=False,
    )

    st.session_state["messages"] = list()
    st.session_state["example_questions"] = example_questions
async def run_app(title: str = "Neo4j GenAI Demo") -> None:
    """
    Run the Streamlit application.

    Draws the page title, the sidebar and the chat history, records any newly
    submitted chat input as the current question, and then awaits ``chat`` on
    the current question if one is present in the session state.

    Parameters
    ----------
    title : str, optional
        Text passed to ``st.title``. By default "Neo4j GenAI Demo".
    """
    st.title(title)
    sidebar()
    display_chat_history()

    # Prompt for user input; a truthy submission becomes the current question.
    question = st.chat_input()
    if question:
        st.session_state["current_question"] = question

    # Nothing to answer unless a current question exists in the session state.
    if "current_question" not in st.session_state:
        return

    pending_question = st.session_state.get("current_question", "")
    await chat(str(pending_question))
if __name__ == "__main__":
    # Script entry point: read the optional JSON config named on the command
    # line, set up the session state, then run the async app to completion.
    args = get_args()
    initialize_state(
        cypher_query_yaml_file_path=args.get("cypher_query_yaml_file_path", ""),
        scope_description=args.get("scope_description", ""),
        example_questions=args.get("example_questions", list()),
    )
    # Pass a title only when the config provides one. The previous
    # `args.get("title", "")` always passed a value, so a config without a
    # "title" key produced an empty heading and run_app's default went unused.
    if "title" in args:
        asyncio.run(run_app(title=args["title"]))
    else:
        asyncio.run(run_app())