Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: add orchestration interface #226

Merged
merged 9 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 0 additions & 169 deletions langchain_tools_demo/agent.py

This file was deleted.

6 changes: 5 additions & 1 deletion langchain_tools_demo/int.tests.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@ steps:
gcloud run deploy ${_SERVICE} \
--source . \
--region ${_REGION} \
--no-allow-unauthenticated
--no-allow-unauthenticated \
--update-env-vars ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE}

- id: "Test Frontend"
name: "gcr.io/cloud-builders/gcloud:latest"
entrypoint: /bin/bash
env: # Set env var expected by app
args:
- "-c"
- |
export URL=$(gcloud run services describe ${_SERVICE} --region ${_REGION} --format 'value(status.url)')
export ID_TOKEN=$(gcloud auth print-identity-token --audiences $$URL)
export ORCHESTRATION_TYPE=${_ORCHESTRATION_TYPE}

# Test `/` route
curl -c cookies.txt -si --fail --show-error -H "Authorization: Bearer $$ID_TOKEN" $$URL
Expand Down Expand Up @@ -77,3 +80,4 @@ substitutions:
_GCR_HOSTNAME: ${_REGION}-docker.pkg.dev
_SERVICE: demo-service-${BUILD_ID}
_REGION: us-central1
_ORCHESTRATION_TYPE: langchain-tools
84 changes: 27 additions & 57 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import uuid
from contextlib import asynccontextmanager
from typing import Any, Optional
from typing import Optional

import uvicorn
from fastapi import APIRouter, Body, FastAPI, HTTPException, Request
from fastapi.responses import PlainTextResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
message_to_dict,
messages_from_dict,
messages_to_dict,
)
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

from agent import init_agent, user_agents
from orchestrator import BaseOrchestrator, createOrchestrator

BASE_HISTORY: list[BaseMessage] = [
AIMessage(content="I am an SFO Airport Assistant, ready to assist you.")
]
routes = APIRouter()
templates = Jinja2Templates(directory="templates")

Expand All @@ -49,19 +36,16 @@ async def lifespan(app: FastAPI):
print("Loading application...")
yield
# FastAPI app shutdown event
close_client_tasks = [
asyncio.create_task(a.client.close()) for a in user_agents.values()
]

asyncio.gather(*close_client_tasks)
app.state.orchestration_type.close_clients()


@routes.get("/")
@routes.post("/")
async def index(request: Request):
"""Render the default template."""
# Agent setup
agent = await get_agent(request.session, user_id_token=None)
orchestrator = request.app.state.orchestration_type
await orchestrator.user_session_create(request.session)
return templates.TemplateResponse(
"index.html",
{
Expand All @@ -81,7 +65,8 @@ async def login_google(
if user_id_token is None:
raise HTTPException(status_code=401, detail="No user credentials found")
# create new request session
_ = await get_agent(request.session, str(user_id_token))
orchestrator = request.app.state.orchestration_type
orchestrator.set_user_session_header(request.session["uuid"], str(user_id_token))
print("Logged in to Google.")

# Redirect to source URL
Expand All @@ -101,34 +86,12 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
)

# Add user message to chat history
request.session["history"].append(message_to_dict(HumanMessage(content=prompt)))
user_agent = await get_agent(request.session, user_id_token=None)
try:
print(prompt)
# Send prompt to LLM
response = await user_agent.agent.ainvoke({"input": prompt})
# Return assistant response
request.session["history"].append(
message_to_dict(AIMessage(content=response["output"]))
)
return markdown(response["output"])
except Exception as err:
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")


async def get_agent(session: dict[str, Any], user_id_token: Optional[str]):
global user_agents
if "uuid" not in session:
session["uuid"] = str(uuid.uuid4())
id = session["uuid"]
if "history" not in session:
session["history"] = messages_to_dict(BASE_HISTORY)
if id not in user_agents:
user_agents[id] = await init_agent(messages_from_dict(session["history"]))
user_agent = user_agents[id]
if user_id_token is not None:
user_agent.client.headers["User-Id-Token"] = f"Bearer {user_id_token}"
return user_agent
request.session["history"].append({"type": "human", "data": {"content": prompt}})
orchestrator = request.app.state.orchestration_type
output = await orchestrator.user_session_invoke(request.session["uuid"], prompt)
# Return assistant response
request.session["history"].append({"type": "ai", "data": {"content": output}})
return markdown(output)


@routes.post("/reset")
Expand All @@ -139,19 +102,25 @@ async def reset(request: Request):
raise HTTPException(status_code=400, detail=f"No session to reset.")

uuid = request.session["uuid"]
global user_agents
if uuid not in user_agents.keys():
raise HTTPException(status_code=500, detail=f"Current agent not found")
orchestrator = request.app.state.orchestration_type
if not orchestrator.user_session_exist(uuid):
raise HTTPException(status_code=500, detail=f"Current user session not found")

await user_agents[uuid].client.close()
del user_agents[uuid]
await orchestrator.user_session_reset(uuid)
request.session.clear()


def init_app(client_id: Optional[str], secret_key: Optional[str]) -> FastAPI:
def init_app(
orchestration_type: Optional[str],
client_id: Optional[str],
secret_key: Optional[str],
) -> FastAPI:
# FastAPI setup
if orchestration_type is None:
raise HTTPException(status_code=500, detail="Orchestrator not found")
app = FastAPI(lifespan=lifespan)
app.state.client_id = client_id
app.state.orchestration_type = createOrchestrator(orchestration_type)
app.include_router(routes)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(SessionMiddleware, secret_key=secret_key)
Expand All @@ -161,9 +130,10 @@ def init_app(client_id: Optional[str], secret_key: Optional[str]) -> FastAPI:
if __name__ == "__main__":
PORT = int(os.getenv("PORT", default=8081))
HOST = os.getenv("HOST", default="0.0.0.0")
ORCHESTRATION_TYPE = os.getenv("ORCHESTRATION_TYPE")
CLIENT_ID = os.getenv("CLIENT_ID")
SECRET_KEY = os.getenv("SECRET_KEY")
app = init_app(client_id=CLIENT_ID, secret_key=SECRET_KEY)
app = init_app(ORCHESTRATION_TYPE, client_id=CLIENT_ID, secret_key=SECRET_KEY)
if app is None:
raise TypeError("app not instantiated")
uvicorn.run(app, host=HOST, port=PORT)
18 changes: 18 additions & 0 deletions langchain_tools_demo/orchestrator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import langchain_tools
from .orchestrator import BaseOrchestrator, createOrchestrator

# Public API of the `orchestrator` package. The name must be the lowercase
# dunder `__all__`; an upper-case `__ALL__` is just an ordinary variable that
# the import system ignores, so `from orchestrator import *` would not honor it.
__all__ = ["BaseOrchestrator", "createOrchestrator", "langchain_tools"]
17 changes: 17 additions & 0 deletions langchain_tools_demo/orchestrator/langchain_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .langchain_tools_orchestrator import LangChainToolsOrchestrator

# Public API of the `langchain_tools` subpackage. Spelled as lowercase
# `__all__` so the import system actually recognizes the export list
# (an upper-case `__ALL__` has no special meaning to Python).
__all__ = ["LangChainToolsOrchestrator"]
Loading