Skip to content

Commit

Permalink
tested
Browse files Browse the repository at this point in the history
  • Loading branch information
duwenxin99 committed Dec 15, 2023
1 parent e038d3e commit 2f2455e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 121 deletions.
75 changes: 71 additions & 4 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@
# limitations under the License.

import os
from datetime import date, timedelta

import aiohttp
import dateutil.parser as dparser
import google.auth.transport.requests # type: ignore
import google.oauth2.id_token # type: ignore
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent import AgentExecutor
from langchain.globals import set_verbose # type: ignore
from langchain.llms.vertexai import VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import ChatPromptTemplate

from tools import convert_date, initialize_tools
import aiohttp
from tools import initialize_tools

set_verbose(bool(os.getenv("DEBUG", default=False)))
BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")

# aiohttp context
connector = None
client_agents = {}


class ClientAgent:
Expand All @@ -37,12 +45,71 @@ def __init__(self, client, agent) -> None:
self.agent = agent


def get_id_token(url: str) -> str:
"""Helper method to generate ID tokens for authenticated requests"""
# Use Application Default Credentials on Cloud Run
if os.getenv("K_SERVICE"):
auth_req = google.auth.transport.requests.Request()
return google.oauth2.id_token.fetch_id_token(auth_req, url)
else:
# Use gcloud credentials locally
import subprocess

return (
subprocess.run(
["gcloud", "auth", "print-identity-token"],
stdout=subprocess.PIPE,
check=True,
)
.stdout.strip()
.decode()
)


def convert_date(date_string: str) -> str:
"""Convert date into appropriate date string"""
if date_string == "tomorrow":
converted = date.today() + timedelta(1)
elif date_string == "yesterday":
converted = date.today() - timedelta(1)
elif date_string != "null" and date_string != "today" and date_string is not None:
converted = dparser.parse(date_string, fuzzy=True).date()
else:
converted = date.today()

return converted.strftime("%Y-%m-%d")


def get_header() -> dict:
headers = {}
if "http://" in BASE_URL:
return headers
else:
# Append ID Token to make authenticated requests to Cloud Run services
headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"}
return headers


async def get_connector():
global connector
if connector is None:
connector = aiohttp.TCPConnector()
connector = aiohttp.TCPConnector(limit=100)
return connector


async def handle_error_response(response):
if response.status != 200:
return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}"


async def create_client_session() -> aiohttp.ClientSession:
return aiohttp.ClientSession(
connector=await get_connector(),
headers=get_header(),
raise_for_status=handle_error_response,
)


# Agent
async def init_agent() -> ClientAgent:
"""Load an agent executor with tools and LLM"""
Expand All @@ -51,7 +118,7 @@ async def init_agent() -> ClientAgent:
memory = ConversationBufferMemory(
memory_key="chat_history", input_key="input", output_key="output"
)
client = aiohttp.ClientSession(connector=await get_connector())
client = await create_client_session()
tools = await initialize_tools(client)
agent = initialize_agent(
tools,
Expand Down
12 changes: 6 additions & 6 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import uuid
from contextlib import asynccontextmanager

import asyncio
import uvicorn
from contextlib import asynccontextmanager
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -26,7 +26,7 @@
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

from agent import ClientAgent, init_agent
from agent import ClientAgent, client_agents, init_agent


@asynccontextmanager
Expand All @@ -36,8 +36,8 @@ async def lifespan(app: FastAPI):
yield
# FastAPI app shutdown event
close_client_tasks = []
for ca in client_agents:
tasks += asyncio.ensure_task(ca.session.close())
for c in client_agents.items:
tasks += asyncio.ensure_task(c.session.close())

asyncio.gather(close_client_tasks)

Expand Down Expand Up @@ -86,7 +86,7 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
client_agents[request.session["uuid"]] = client_agent
try:
# Send prompt to LLM
response = await agent.ainvoke({"input": prompt})
response = await client_agent.agent.ainvoke({"input": prompt})
request.session["messages"] += [
{"role": "assistant", "content": response["output"]}
]
Expand Down
Loading

0 comments on commit 2f2455e

Please sign in to comment.