Skip to content

Commit

Permalink
LlamaIndex SDK code
Browse files Browse the repository at this point in the history
  • Loading branch information
twishabansal committed Nov 6, 2024
1 parent 52ebb43 commit 734fdd1
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 0 deletions.
58 changes: 58 additions & 0 deletions sdks/llamaindex/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
[project]
name = "toolbox_llamaindex_sdk"
version="0.0.1"
description = "Python SDK for interacting with the Toolbox service with LlamaIndex"
license = {file = "LICENSE"}
requires-python = ">=3.9"
authors = [
{name = "Google LLC", email = "googleapis-packages@google.com"}
]
dependencies = [
"aiohttp",
"PyYAML",
"llama-index",
"pydantic",
"pytest-asyncio",
]

classifiers = [
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]

[project.urls]
Homepage = "https://github.com/googleapis/genai-toolbox"
Repository = "https://github.com/googleapis/genai-toolbox.git"
"Bug Tracker" = "https://github.com/googleapis/genai-toolbox/issues"

[project.optional-dependencies]
test = [
"black[jupyter]",
"isort",
"mypy",
"pytest-asyncio",
"pytest",
"pytest-cov",
"Pillow"
]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.black]
target-version = ['py39']

[tool.isort]
profile = "black"

[tool.mypy]
python_version = "3.9"
warn_unused_configs = true
disallow_incomplete_defs = true
4 changes: 4 additions & 0 deletions sdks/llamaindex/src/toolbox_llamaindex_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .client import ToolboxClient
# import utils

__all__ = ["ToolboxClient"]
107 changes: 107 additions & 0 deletions sdks/llamaindex/src/toolbox_llamaindex_sdk/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Optional

from aiohttp import ClientSession
from llama_index.core.tools import FunctionTool

from .utils import ManifestSchema, _invoke_tool, _load_yaml, _schema_to_model


class ToolboxClient:
    """Client that loads Toolbox service tools as LlamaIndex ``FunctionTool`` objects."""

    def __init__(self, url: str, session: ClientSession):
        """
        Initializes the ToolboxClient for the Toolbox service at the given URL.

        Args:
            url: The base URL of the Toolbox service. Trailing slashes are
                stripped so endpoint paths are never built with "//".
            session: The HTTP client session.
        """
        # Every endpoint below is built as f"{self._url}/api/...", so a base
        # URL ending in "/" would otherwise produce ".../​/api/..." paths.
        self._url: str = url.rstrip("/")
        self._session = session

    async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema:
        """
        Fetches and parses the YAML manifest for the given tool from the Toolbox service.

        Args:
            tool_name: The name of the tool to load.

        Returns:
            The parsed Toolbox manifest.
        """
        url = f"{self._url}/api/tool/{tool_name}"
        return await _load_yaml(url, self._session)

    async def _load_toolset_manifest(
        self, toolset_name: Optional[str] = None
    ) -> ManifestSchema:
        """
        Fetches and parses the YAML manifest from the Toolbox service.

        Args:
            toolset_name: The name of the toolset to load.
                Default: None. If not provided, then all the available tools are loaded.

        Returns:
            The parsed Toolbox manifest.
        """
        # With no toolset name the path ends in "/api/toolset/".
        url = f"{self._url}/api/toolset/{toolset_name or ''}"
        return await _load_yaml(url, self._session)

    def _generate_tool(self, tool_name: str, manifest: ManifestSchema) -> FunctionTool:
        """
        Creates a FunctionTool object and a dynamically generated BaseModel for the given tool.

        Args:
            tool_name: The name of the tool to generate.
            manifest: The parsed Toolbox manifest.

        Returns:
            The generated tool.

        Raises:
            KeyError: If ``tool_name`` is not present in ``manifest.tools``.
        """
        tool_schema = manifest.tools[tool_name]
        tool_model = _schema_to_model(
            model_name=tool_name, schema=tool_schema.parameters
        )

        async def _tool_func(**kwargs) -> dict:
            # Forward the keyword arguments as the invocation payload.
            return await _invoke_tool(self._url, self._session, tool_name, kwargs)

        return FunctionTool.from_defaults(
            async_fn=_tool_func,
            name=tool_name,
            description=tool_schema.description,
            fn_schema=tool_model,
        )

    async def load_tool(self, tool_name: str) -> FunctionTool:
        """
        Loads the tool, with the given tool name, from the Toolbox service.

        Args:
            tool_name: The name of the tool to load.

        Returns:
            A tool loaded from the Toolbox.
        """
        manifest: ManifestSchema = await self._load_tool_manifest(tool_name)
        return self._generate_tool(tool_name, manifest)

    async def load_toolset(
        self, toolset_name: Optional[str] = None
    ) -> list[FunctionTool]:
        """
        Loads tools from the Toolbox service, optionally filtered by toolset name.

        Args:
            toolset_name: The name of the toolset to load.
                Default: None. If not provided, then all the tools are loaded.

        Returns:
            A list of all tools loaded from the Toolbox.
        """
        manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name)
        return [
            self._generate_tool(tool_name, manifest) for tool_name in manifest.tools
        ]
118 changes: 118 additions & 0 deletions sdks/llamaindex/src/toolbox_llamaindex_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import Any, Type, Optional

import yaml
from aiohttp import ClientSession
from pydantic import BaseModel, Field, create_model


class ParameterSchema(BaseModel):
    # Describes a single tool parameter as listed in a Toolbox manifest.
    name: str  # Parameter name; becomes the field name in the generated model.
    type: str  # Schema type name (e.g. "string", "integer"); see _parse_type.
    description: str  # Description attached to the generated model field.


class ToolSchema(BaseModel):
    # Describes one tool in a Toolbox manifest.
    description: str  # Description text of the tool.
    parameters: list[ParameterSchema]  # Parameters the tool accepts.


class ManifestSchema(BaseModel):
    # Top-level shape of the manifest parsed by _load_yaml.
    serverVersion: str  # Version string reported in the manifest.
    tools: dict[str, ToolSchema]  # Tool schemas keyed by name.


async def _load_yaml(url: str, session: ClientSession) -> ManifestSchema:
    """
    Asynchronously fetches and parses the YAML manifest at the given URL.

    Args:
        url: The full URL to fetch the YAML from.
        session: The HTTP client session.

    Returns:
        The parsed Toolbox manifest.

    Raises:
        TypeError: If the response body does not parse to a YAML mapping
            (for example, when the body is empty).
    """
    async with session.get(url) as response:
        # Fail on HTTP error statuses before trying to parse the body.
        response.raise_for_status()
        parsed_yaml = yaml.safe_load(await response.text())

    # ``ManifestSchema(**x)`` requires a mapping. An empty or scalar body would
    # otherwise fail with an opaque "argument after ** must be a mapping" error,
    # so report the problem explicitly (same exception type as before).
    if not isinstance(parsed_yaml, dict):
        raise TypeError(
            f"Expected a YAML mapping from {url}, "
            f"got {type(parsed_yaml).__name__}"
        )
    return ManifestSchema(**parsed_yaml)


def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]:
    """
    Builds a Pydantic BaseModel class from a tool's parameter schema.

    Args:
        model_name: The name given to the created model class.
        schema: The parameter definitions (from the YAML manifest) to convert.

    Returns:
        A Pydantic BaseModel class with one field per parameter.
    """
    # TODO: Remove the hardcoded optional types once optional fields are supported by Toolbox.
    model_fields = {
        param.name: (
            Optional[_parse_type(param.type)],
            Field(description=param.description),
        )
        for param in schema
    }
    return create_model(model_name, **model_fields)


def _parse_type(type_: str) -> Any:
"""
Converts a schema type to a JSON type.
Args:
type_: The type name to convert.
Returns:
A valid JSON type.
"""

if type_ == "string":
return str
elif type_ == "integer":
return int
elif type_ == "number":
return float
elif type_ == "boolean":
return bool
elif type_ == "array":
return list
else:
raise ValueError(f"Unsupported schema type: {type_}")


async def _invoke_tool(
    url: str, session: ClientSession, tool_name: str, data: dict
) -> dict:
    """
    Asynchronously calls the Toolbox service's invoke endpoint for a tool.

    Args:
        url: The base URL of the Toolbox service.
        session: The HTTP client session.
        tool_name: The name of the tool to invoke.
        data: The input data for the tool.

    Returns:
        A dictionary containing the parsed JSON response from the tool invocation.
    """
    endpoint = f"{url}/api/tool/{tool_name}/invoke"
    # None values are replaced with "" before the payload is sent.
    payload = _convert_none_to_empty_string(data)
    async with session.post(endpoint, json=payload) as response:
        response.raise_for_status()
        return await response.json()


# TODO: Remove this temporary fix once optional fields are supported by Toolbox.
def _convert_none_to_empty_string(input_dict):
new_dict = {}
for key, value in input_dict.items():
if value is None:
new_dict[key] = ""
else:
new_dict[key] = value
return new_dict

0 comments on commit 734fdd1

Please sign in to comment.