-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
52ebb43
commit 734fdd1
Showing
4 changed files
with
287 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
[project] | ||
name = "toolbox_llamaindex_sdk" | ||
version="0.0.1" | ||
description = "Python SDK for interacting with the Toolbox service with Llamaindex" | ||
license = {file = "LICENSE"} | ||
requires-python = ">=3.9" | ||
authors = [ | ||
{name = "Google LLC", email = "googleapis-packages@google.com"} | ||
] | ||
dependencies = [ | ||
"aiohttp", | ||
"PyYAML", | ||
"llama-index", | ||
"pydantic", | ||
"pytest-asyncio", | ||
] | ||
|
||
classifiers = [ | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
] | ||
|
||
[project.urls] | ||
Homepage = "/~https://github.com/googleapis/genai-toolbox" | ||
Repository = "/~https://github.com/googleapis/genai-toolbox.git" | ||
"Bug Tracker" = "/~https://github.com/googleapis/genai-toolbox/issues" | ||
|
||
[project.optional-dependencies] | ||
test = [ | ||
"black[jupyter]", | ||
"isort", | ||
"mypy", | ||
"pytest-asyncio", | ||
"pytest", | ||
"pytest-cov", | ||
"Pillow" | ||
] | ||
|
||
[build-system] | ||
requires = ["setuptools"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.black] | ||
target-version = ['py39'] | ||
|
||
[tool.isort] | ||
profile = "black" | ||
|
||
[tool.mypy] | ||
python_version = "3.9" | ||
warn_unused_configs = true | ||
disallow_incomplete_defs = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .client import ToolboxClient | ||
# import utils | ||
|
||
__all__ = ["ToolboxClient"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Optional | ||
|
||
from aiohttp import ClientSession | ||
from llama_index.core.tools import FunctionTool | ||
|
||
from .utils import ManifestSchema, _invoke_tool, _load_yaml, _schema_to_model | ||
|
||
|
||
class ToolboxClient: | ||
def __init__(self, url: str, session: ClientSession): | ||
""" | ||
Initializes the ToolboxClient for the Toolbox service at the given URL. | ||
Args: | ||
url: The base URL of the Toolbox service. | ||
session: The HTTP client session. | ||
""" | ||
self._url: str = url | ||
self._session = session | ||
|
||
async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema: | ||
""" | ||
Fetches and parses the YAML manifest for the given tool from the Toolbox service. | ||
Args: | ||
tool_name: The name of the tool to load. | ||
Returns: | ||
The parsed Toolbox manifest. | ||
""" | ||
url = f"{self._url}/api/tool/{tool_name}" | ||
return await _load_yaml(url, self._session) | ||
|
||
async def _load_toolset_manifest( | ||
self, toolset_name: Optional[str] = None | ||
) -> ManifestSchema: | ||
""" | ||
Fetches and parses the YAML manifest from the Toolbox service. | ||
Args: | ||
toolset_name: The name of the toolset to load. | ||
Default: None. If not provided, then all the available tools are loaded. | ||
Returns: | ||
The parsed Toolbox manifest. | ||
""" | ||
url = f"{self._url}/api/toolset/{toolset_name or ''}" | ||
return await _load_yaml(url, self._session) | ||
|
||
def _generate_tool(self, tool_name: str, manifest: ManifestSchema) -> FunctionTool: | ||
""" | ||
Creates a FunctionTool object and a dynamically generated BaseModel for the given tool. | ||
Args: | ||
tool_name: The name of the tool to generate. | ||
manifest: The parsed Toolbox manifest. | ||
Returns: | ||
The generated tool. | ||
""" | ||
tool_schema = manifest.tools[tool_name] | ||
tool_model = _schema_to_model( | ||
model_name=tool_name, schema=tool_schema.parameters | ||
) | ||
|
||
async def _tool_func(**kwargs) -> dict: | ||
return await _invoke_tool(self._url, self._session, tool_name, kwargs) | ||
|
||
return FunctionTool.from_defaults( | ||
async_fn=_tool_func, | ||
name=tool_name, | ||
description=tool_schema.description, | ||
fn_schema=tool_model, | ||
) | ||
|
||
async def load_tool(self, tool_name: str) -> FunctionTool: | ||
""" | ||
Loads the tool, with the given tool name, from the Toolbox service. | ||
Args: | ||
toolset_name: The name of the toolset to load. | ||
Default: None. If not provided, then all the tools are loaded. | ||
Returns: | ||
A tool loaded from the Toolbox | ||
""" | ||
manifest: ManifestSchema = await self._load_tool_manifest(tool_name) | ||
return self._generate_tool(tool_name, manifest) | ||
|
||
async def load_toolset( | ||
self, toolset_name: Optional[str] = None | ||
) -> list[FunctionTool]: | ||
""" | ||
Loads tools from the Toolbox service, optionally filtered by toolset name. | ||
Args: | ||
toolset_name: The name of the toolset to load. | ||
Default: None. If not provided, then all the tools are loaded. | ||
Returns: | ||
A list of all tools loaded from the Toolbox. | ||
""" | ||
tools: list[FunctionTool] = [] | ||
manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name) | ||
for tool_name in manifest.tools: | ||
tools.append(self._generate_tool(tool_name, manifest)) | ||
return tools |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from typing import Any, Type, Optional | ||
|
||
import yaml | ||
from aiohttp import ClientSession | ||
from pydantic import BaseModel, Field, create_model | ||
|
||
|
||
class ParameterSchema(BaseModel): | ||
name: str | ||
type: str | ||
description: str | ||
|
||
|
||
class ToolSchema(BaseModel): | ||
description: str | ||
parameters: list[ParameterSchema] | ||
|
||
|
||
class ManifestSchema(BaseModel): | ||
serverVersion: str | ||
tools: dict[str, ToolSchema] | ||
|
||
|
||
async def _load_yaml(url: str, session: ClientSession) -> ManifestSchema: | ||
""" | ||
Asynchronously fetches and parses the YAML data from the given URL. | ||
Args: | ||
url: The base URL to fetch the YAML from. | ||
session: The HTTP client session | ||
Returns: | ||
The parsed Toolbox manifest. | ||
""" | ||
async with session.get(url) as response: | ||
response.raise_for_status() | ||
parsed_yaml = yaml.safe_load(await response.text()) | ||
return ManifestSchema(**parsed_yaml) | ||
|
||
|
||
def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]: | ||
""" | ||
Converts a schema (from the YAML manifest) to a Pydantic BaseModel class. | ||
Args: | ||
model_name: The name of the model to create. | ||
schema: The schema to convert. | ||
Returns: | ||
A Pydantic BaseModel class. | ||
""" | ||
field_definitions = {} | ||
for field in schema: | ||
field_definitions[field.name] = ( | ||
# TODO: Remove the hardcoded optional types once optional fields are supported by Toolbox. | ||
Optional[_parse_type(field.type)], | ||
Field(description=field.description), | ||
) | ||
|
||
return create_model(model_name, **field_definitions) | ||
|
||
|
||
def _parse_type(type_: str) -> Any: | ||
""" | ||
Converts a schema type to a JSON type. | ||
Args: | ||
type_: The type name to convert. | ||
Returns: | ||
A valid JSON type. | ||
""" | ||
|
||
if type_ == "string": | ||
return str | ||
elif type_ == "integer": | ||
return int | ||
elif type_ == "number": | ||
return float | ||
elif type_ == "boolean": | ||
return bool | ||
elif type_ == "array": | ||
return list | ||
else: | ||
raise ValueError(f"Unsupported schema type: {type_}") | ||
|
||
|
||
async def _invoke_tool( | ||
url: str, session: ClientSession, tool_name: str, data: dict | ||
) -> dict: | ||
""" | ||
Asynchronously makes an API call to the Toolbox service to invoke a tool. | ||
Args: | ||
url: The base URL of the Toolbox service. | ||
session: The HTTP client session. | ||
tool_name: The name of the tool to invoke. | ||
data: The input data for the tool. | ||
Returns: | ||
A dictionary containing the parsed JSON response from the tool invocation. | ||
""" | ||
url = f"{url}/api/tool/{tool_name}/invoke" | ||
async with session.post(url, json=_convert_none_to_empty_string(data)) as response: | ||
response.raise_for_status() | ||
json_response = await response.json() | ||
return json_response | ||
|
||
|
||
# TODO: Remove this temporary fix once optional fields are supported by Toolbox. | ||
def _convert_none_to_empty_string(input_dict): | ||
new_dict = {} | ||
for key, value in input_dict.items(): | ||
if value is None: | ||
new_dict[key] = "" | ||
else: | ||
new_dict[key] = value | ||
return new_dict |