Skip to content

Commit

Permalink
update functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Mar 5, 2024
1 parent 2a5896d commit 25ae764
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
)

from ..orchestrator import BaseOrchestrator, classproperty
from .functions import assistant_tool, function_request
from .functions import BASE_URL, assistant_tool, function_request, get_headers

BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")
DEBUG = os.getenv("DEBUG", default=False)
BASE_HISTORY = {
"type": "ai",
Expand Down
32 changes: 32 additions & 0 deletions llm_demo/orchestrator/vertexai_function_calling/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@

import os

import aiohttp
from vertexai.preview import generative_models # type: ignore

BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")
CREDENTIALS = None

search_airports_func = generative_models.FunctionDeclaration(
name="airports_search",
description="Use this tool to list all airports matching search criteria. Takes at least one of country, city, name, or all of the above criteria. This function could also be used to search for airport information such as iata code.",
Expand Down Expand Up @@ -88,6 +92,34 @@
)


def get_id_token():
global CREDENTIALS
if CREDENTIALS is None:
CREDENTIALS, _ = google.auth.default()
if not hasattr(CREDENTIALS, "id_token"):
# Use Compute Engine default credential
CREDENTIALS = compute_engine.IDTokenCredentials(
request=Request(),
target_audience=BASE_URL,
use_metadata_identity_endpoint=True,
)
if not CREDENTIALS.valid:
CREDENTIALS.refresh(Request())
if hasattr(CREDENTIALS, "id_token"):
return CREDENTIALS.id_token
else:
return CREDENTIALS.token


def get_headers(client: aiohttp.ClientSession):
"""Helper method to generate ID tokens for authenticated requests"""
headers = client.headers
if not "http://" in BASE_URL:
# Append ID Token to make authenticated requests to Cloud Run services
headers["Authorization"] = f"Bearer {get_id_token()}"
return headers


def function_request(function_call_name: str) -> str:
functions_url = {
"airports_search": "airports/search",
Expand Down

0 comments on commit 25ae764

Please sign in to comment.