
Commit

wip
Signed-off-by: Diwank Tomer <diwank@julep.ai>
Diwank Tomer committed Aug 16, 2024
1 parent c6456bc commit b50be03
Showing 34 changed files with 394 additions and 215 deletions.
51 changes: 20 additions & 31 deletions agents-api/agents_api/activities/task_steps/__init__.py
@@ -1,11 +1,10 @@
import asyncio
from typing import Literal
from uuid import uuid4

from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
CreateTransitionRequest,
EvaluateStep,
IfElseWorkflowStep,
InputChatMLMessage,
@@ -18,7 +17,7 @@
)
from ...common.protocol.tasks import (
StepContext,
TransitionInfo,
StepOutcome,
)
from ...common.utils.template import render_template
from ...models.execution.create_execution_transition import (
@@ -27,9 +26,7 @@


@activity.defn
async def prompt_step(context: StepContext) -> dict:
assert isinstance(context.definition, PromptStep)

async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
# Get context data
context_data: dict = context.model_dump()

@@ -39,6 +36,7 @@ async def prompt_step(context: StepContext) -> dict:
if isinstance(context.definition.prompt, str)
else context.definition.prompt
)

template_messages: list[InputChatMLMessage] = prompt
messages = await asyncio.gather(
*[
@@ -61,7 +59,10 @@
**settings,
)

return response.model_dump()
return StepOutcome(
output=response.model_dump(),
next=None,
)
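
For orientation, a sketch (not the repository's actual workflow code) of how a Temporal workflow might invoke this activity and consume the StepOutcome; the timeout value is illustrative:

from datetime import timedelta

from temporalio import workflow

# Inside a @workflow.defn class method:
outcome: StepOutcome = await workflow.execute_activity(
    prompt_step,
    context,
    start_to_close_timeout=timedelta(minutes=1),
)
response_dict = outcome.output  # the model response, dumped to a dict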


@activity.defn
@@ -103,10 +104,9 @@ async def tool_call_step(context: StepContext) -> dict:


@activity.defn
async def if_else_step(context: StepContext) -> dict:
assert isinstance(context.definition, IfElseWorkflowStep)

async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict:
context_data: dict = context.model_dump()

next_workflow = (
context.definition.then
if simple_eval(context.definition.if_, names=context_data)
@@ -118,38 +118,27 @@

@activity.defn
async def transition_step(
context: StepContext,
transition_info: TransitionInfo,
execution_status: Literal[
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
"cancelled",
] = "awaiting_input",
context: StepContext[None],
transition_info: CreateTransitionRequest,
):
activity.heartbeat("Running transition step")

# Get transition info
transition_data = transition_info.model_dump(by_alias=False)
need_to_wait = transition_info.type == "wait"

# Get task token if it's a waiting step
if transition_info.type == "awaiting_input":
if need_to_wait:
task_token = activity.info().task_token
transition_data["__task_token"] = task_token
transition_info.task_token = task_token

# Create transition
activity.heartbeat("Creating transition in db")
create_execution_transition_query(
developer_id=context.developer_id,
execution_id=context.execution.id,
transition_id=uuid4(),
update_execution_status=True,
task_id=context.task.id,
**transition_data,
data=transition_info,
update_execution_status=True,
)

# Raise if it's a waiting step
if execution_status == "awaiting_input":
if need_to_wait:
activity.heartbeat("Starting to wait")
activity.raise_complete_async()
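
For reference, a minimal sketch of how a waiting step could later be resumed from outside the workflow, assuming the task token stored on the transition above and temporalio's async-activity completion API; the server address is illustrative:

from temporalio.client import Client

async def resume_waiting_step(task_token: bytes, user_input: dict) -> None:
    # Connect to the Temporal server
    client = await Client.connect("localhost:7233")
    # Look up the paused activity by the stored token
    handle = client.get_async_activity_handle(task_token=task_token)
    # Completing the activity resumes the workflow with the user's input
    await handle.complete(user_input)
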
17 changes: 15 additions & 2 deletions agents-api/agents_api/autogen/Executions.py
@@ -84,8 +84,10 @@ class Transition(BaseModel):
]
execution_id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
output: Annotated[dict[str, Any], Field(json_schema_extra={"readOnly": True})]
current: Annotated[list, Field(json_schema_extra={"readOnly": True})]
next: Annotated[list | None, Field(json_schema_extra={"readOnly": True})]
current: Annotated[TransitionTarget, Field(json_schema_extra={"readOnly": True})]
next: Annotated[
TransitionTarget | None, Field(json_schema_extra={"readOnly": True})
]
id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
metadata: dict[str, Any] | None = None
created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
@@ -98,6 +100,17 @@ class Transition(BaseModel):
"""


class TransitionTarget(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
workflow: Annotated[str, Field(pattern="^[^\\W0-9]\\w*$")]
"""
Valid python identifier names
"""
step: int


class UpdateExecutionRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
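A quick sketch of the new TransitionTarget model in use; the values are illustrative:

from agents_api.autogen.Executions import TransitionTarget

# "workflow" must be a valid Python identifier per the pattern above
target = TransitionTarget(workflow="main", step=0)
# A name like "2nd-flow" would be rejected by ^[^\W0-9]\w*$
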
125 changes: 67 additions & 58 deletions agents-api/agents_api/autogen/openapi_model.py
@@ -20,31 +20,31 @@
from .Tools import *
from .Users import *

# Generic models
# --------------

DataT = TypeVar("DataT", bound=BaseModel)


class ListResponse(BaseModel, Generic[DataT]):
items: list[DataT]
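
For illustration, a parameterized use of this generic wrapper (Agent is available via the wildcard imports above):

agents_page = ListResponse[Agent](items=[])  # an empty, typed page of agents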


# Aliases
# -------

CreateToolRequest = UpdateToolRequest
CreateOrUpdateAgentRequest = UpdateAgentRequest
CreateOrUpdateUserRequest = UpdateUserRequest
CreateOrUpdateSessionRequest = CreateSessionRequest
CreateOrUpdateTaskRequest = CreateTaskRequest
ChatResponse = ChunkChatResponse | MessageChatResponse

CreateTransitionRequest = create_partial_model(
Transition,
# The following fields are optional
"id",
"execution_id",
"created_at",
"updated_at",
"metadata",
)

ChatMLRole = Literal[
"user",
"assistant",
"system",
"function",
"function_response",
"function_call",
"auto",
]
# Custom types (not generated correctly)
# --------------------------------------

# TODO: Remove these when auto-population is fixed

ChatMLContent = (
list[ChatMLTextContentPart | ChatMLImageContentPart]
@@ -61,9 +61,53 @@
]
)

ChatMLRole = Literal[
"user",
"assistant",
"system",
"function",
"function_response",
"function_call",
"auto",
]
assert BaseEntry.model_fields["role"].annotation == ChatMLRole

ChatMLSource = Literal[
"api_request", "api_response", "tool_response", "internal", "summarizer", "meta"
]
assert BaseEntry.model_fields["source"].annotation == ChatMLSource


ExecutionStatus = Literal[
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
"cancelled",
]
assert Execution.model_fields["status"].annotation == ExecutionStatus


TransitionType = Literal["finish", "wait", "resume", "error", "step", "cancelled"]
assert Transition.model_fields["type"].annotation == TransitionType


# Create models
# -------------

CreateTransitionRequest = create_partial_model(
Transition,
#
# The following fields are optional
"id",
"execution_id",
"created_at",
"updated_at",
"metadata",
)
CreateTransitionRequest.model_rebuild()
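
A sketch of constructing one of these partial requests, assuming create_partial_model makes only the listed fields optional, so type, output, current and next remain required; the values are illustrative:

transition_req = CreateTransitionRequest(
    type="step",  # one of TransitionType above
    output={"result": "ok"},
    current=TransitionTarget(workflow="main", step=0),
    next=TransitionTarget(workflow="main", step=1),
)
# id, execution_id, created_at, updated_at and metadata may be omitted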


class CreateEntryRequest(BaseEntry):
@@ -98,35 +142,8 @@ def from_model_input(
)


def make_session(
*,
agents: list[UUID],
users: list[UUID],
**data: dict,
) -> Session:
"""
Create a new session object.
"""
cls, participants = None, {}

match (len(agents), len(users)):
case (0, _):
raise ValueError("At least one agent must be provided.")
case (1, 0):
cls = SingleAgentNoUserSession
participants = {"agent": agents[0]}
case (1, 1):
cls = SingleAgentSingleUserSession
participants = {"agent": agents[0], "user": users[0]}
case (1, u) if u > 1:
cls = SingleAgentMultiUserSession
participants = {"agent": agents[0], "users": users}
case _:
cls = MultiAgentMultiUserSession
participants = {"agents": agents, "users": users}

return cls(**{**data, **participants})

# Task related models
# -------------------

WorkflowStep = (
PromptStep
@@ -157,7 +174,9 @@ class TaskSpec(_Task):
model_config = ConfigDict(extra="ignore")

workflows: list[Workflow]
main: list[WorkflowStep] | None = None

# Remove main field from the model
main: None = None


class TaskSpecDef(TaskSpec):
Expand Down Expand Up @@ -213,13 +232,3 @@ class UpdateTaskRequest(_UpdateTaskRequest):
"extra": "allow",
}
)


DataT = TypeVar("DataT", bound=BaseModel)


class ListResponse(BaseModel, Generic[DataT]):
items: list[DataT]


ChatResponse = ChunkChatResponse | MessageChatResponse
3 changes: 2 additions & 1 deletion agents-api/agents_api/clients/litellm.py
@@ -1,14 +1,15 @@
from functools import wraps

from litellm import acompletion as _acompletion
from litellm.utils import CustomStreamWrapper, ModelResponse

from ..env import litellm_master_key, litellm_url

__all__ = ["acompletion"]


@wraps(_acompletion)
async def acompletion(*, model: str, **kwargs):
async def acompletion(*, model: str, **kwargs) -> ModelResponse | CustomStreamWrapper:
return await _acompletion(
model=f"openai/{model}", # This is here because litellm proxy expects this format
**kwargs,
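A usage sketch for the wrapper, assuming a litellm proxy reachable through the litellm_url and litellm_master_key settings imported above; the model name and message are illustrative:

import asyncio

from agents_api.clients.litellm import acompletion

async def main() -> None:
    # The wrapper prefixes "openai/" so the proxy routes the request
    response = await acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response)

asyncio.run(main())
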
34 changes: 34 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
@@ -11,7 +11,11 @@
Agent,
ChatInput,
ChatSettings,
MultiAgentMultiUserSession,
Session,
SingleAgentMultiUserSession,
SingleAgentNoUserSession,
SingleAgentSingleUserSession,
Tool,
User,
)
@@ -107,3 +111,33 @@ def get_chat_environment(self) -> dict[str, dict | list[dict]]:
"settings": self.settings.model_dump(),
"tools": [tool.model_dump() for tool in tools],
}


def make_session(
*,
agents: list[UUID],
users: list[UUID],
**data: dict,
) -> Session:
"""
Create a new session object.
"""
cls, participants = None, {}

match (len(agents), len(users)):
case (0, _):
raise ValueError("At least one agent must be provided.")
case (1, 0):
cls = SingleAgentNoUserSession
participants = {"agent": agents[0]}
case (1, 1):
cls = SingleAgentSingleUserSession
participants = {"agent": agents[0], "user": users[0]}
case (1, u) if u > 1:
cls = SingleAgentMultiUserSession
participants = {"agent": agents[0], "users": users}
case _:
cls = MultiAgentMultiUserSession
participants = {"agents": agents, "users": users}

return cls(**{**data, **participants})
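
A usage sketch for make_session; the extra keyword field is illustrative and passes through **data to the selected session class:

from uuid import uuid4

agent_id, user_id = uuid4(), uuid4()

# One agent and one user selects SingleAgentSingleUserSession
session = make_session(
    agents=[agent_id],
    users=[user_id],
    situation="Customer support chat",
)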