Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tasks): Enable all fields of ExecutionInput #396

Merged
merged 9 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-06-17T07:05:47+00:00
# timestamp: 2024-06-21T03:13:24+00:00

from __future__ import annotations

Expand Down Expand Up @@ -1140,18 +1140,21 @@ class PatchToolRequest(BaseModel):
class Execution(BaseModel):
id: UUID
task_id: UUID
created_at: UUID
arguments: Dict[str, Any]
"""
JSON Schema of parameters
"""
status: Annotated[
str,
Field(pattern="^(queued|starting|running|awaiting_input|succeeded|failed)$"),
]
"""
Execution Status
"""
arguments: Dict[str, Any]
"""
JSON of parameters
"""
user_id: UUID | None = None
session_id: UUID | None = None
created_at: AwareDatetime
updated_at: AwareDatetime


class ExecutionTransition(BaseModel):
Expand Down Expand Up @@ -1259,4 +1262,5 @@ class Task(BaseModel):
ID of the Task
"""
created_at: AwareDatetime
updated_at: AwareDatetime | None = None
agent_id: UUID
78 changes: 73 additions & 5 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from pydantic import BaseModel, Field, UUID4, computed_field

from ...autogen.openapi_model import (
User,
Agent,
Session,
Tool,
FunctionDef,
PromptWorkflowStep,
EvaluateWorkflowStep,
YieldWorkflowStep,
Expand All @@ -14,6 +19,9 @@
Execution,
)

from ...models.execution.get_execution_input import get_execution_input_query
from ..utils.cozo import uuid_int_list_to_uuid4

WorkflowStep = (
PromptWorkflowStep
| EvaluateWorkflowStep
Expand Down Expand Up @@ -53,6 +61,24 @@ class TaskSpec(BaseModel):


class TaskProtocol(SerializableTask):
@classmethod
def from_cozo_data(cls, task_data: dict[str, Any]) -> "SerializableTask":

workflows = task_data.pop("workflows")
assert len(workflows) > 0

main_wf_idx, main_wf = next(
(i, wf) for i, wf in enumerate(workflows) if wf["name"] == "main"
)

task_data["main"] = main_wf["steps"]
workflows.pop(main_wf_idx)

for workflow in workflows:
task_data[workflow["name"]] = workflow["steps"]

return cls(**task_data)

@computed_field
@property
def spec(self) -> TaskSpec:
Expand All @@ -79,17 +105,59 @@ def spec(self) -> TaskSpec:
)


# FIXME: Enable all of these
class ExecutionInput(BaseModel):
developer_id: UUID4
execution: Execution
task: TaskProtocol
# agent: Agent
# user: User | None
# session: Session | None
# tools: list[Tool]
agent: Agent
user: User | None
session: Session | None
tools: list[Tool]
arguments: dict[str, Any]

@classmethod
def fetch(
cls, *, developer_id: UUID4, task_id: UUID4, execution_id: UUID4, client: Any
) -> "ExecutionInput":
[data] = get_execution_input_query(
task_id=task_id,
execution_id=execution_id,
client=client,
).to_dict(orient="records")

# FIXME: Need to manually convert id from list of int to UUID4
# because cozo has a bug with UUID4
# See: /~https://github.com/cozodb/cozo/issues/269
for kind in ["task", "execution", "agent", "user", "session"]:
if not data[kind]:
continue

for key in data[kind]:
if key == "id" or key.endswith("_id") and data[kind][key] is not None:
data[kind][key] = uuid_int_list_to_uuid4(data[kind][key])

agent = Agent(**data["agent"])
task = TaskProtocol.from_cozo_data(data["task"])
execution = Execution(**data["execution"])
user = User(**data["user"]) if data["user"] else None
session = Session(**data["session"]) if data["session"] else None
tools = [
Tool(type="function", id=function["id"], function=FunctionDef(**function))
for function in data["tools"]
]
arguments = execution.arguments

return cls(
developer_id=developer_id,
execution=execution,
task=task,
agent=agent,
user=user,
session=session,
tools=tools,
arguments=arguments,
)


class StepContext(ExecutionInput):
definition: WorkflowStep
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/common/utils/cozo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""This module provides utility functions for interacting with the Cozo API client, including data mutation processes."""

from types import SimpleNamespace
from uuid import UUID

from pycozo import Client

Expand All @@ -17,3 +18,7 @@
cozo_process_mutate_data = _fake_client._process_mutate_data = lambda data: (
Client._process_mutate_data(_fake_client, data)
)

uuid_int_list_to_uuid4 = lambda data: UUID(
bytes=b"".join([i.to_bytes(1, "big") for i in data])
)
15 changes: 9 additions & 6 deletions agents-api/agents_api/models/execution/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def create_execution_query(
agent_id: UUID,
task_id: UUID,
execution_id: UUID,
session_id: UUID | None = None,
status: Literal[
"queued", "starting", "running", "awaiting_input", "succeeded", "failed"
] = "queued",
Expand All @@ -22,16 +23,17 @@ def create_execution_query(

query = """
{
?[task_id, execution_id, status, arguments] <- [[
to_uuid($task_id),
to_uuid($execution_id),
$status,
$arguments
]]
?[task_id, execution_id, session_id, status, arguments] :=
task_id = to_uuid($task_id),
execution_id = to_uuid($execution_id),
session_id = if(is_null($session_id), null, to_uuid($session_id)),
status = $status,
arguments = $arguments

:insert executions {
task_id,
execution_id,
session_id,
status,
arguments
}
Expand All @@ -42,6 +44,7 @@ def create_execution_query(
{
"task_id": str(task_id),
"execution_id": str(execution_id),
"session_id": str(session_id) if session_id is not None else None,
"status": status,
"arguments": arguments,
},
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/models/execution/get_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ def get_execution_query(
) -> tuple[str, dict]:
query = """
{
?[status, arguments, created_at, updated_at] := *executions {
?[status, arguments, session_id, created_at, updated_at] := *executions {
task_id: to_uuid($task_id),
execution_id: to_uuid($execution_id),
status,
arguments,
session_id,
created_at,
updated_at,
}
Expand Down
Loading