Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <diwank@julep.ai>
  • Loading branch information
Diwank Tomer committed Aug 16, 2024
1 parent b50be03 commit 083b089
Show file tree
Hide file tree
Showing 17 changed files with 405 additions and 242 deletions.
2 changes: 2 additions & 0 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from uuid import UUID

from beartype import beartype
from temporalio import activity

from ..clients import embed as embedder
Expand All @@ -10,6 +11,7 @@


@activity.defn
@beartype
async def embed_docs(
developer_id: UUID,
doc_id: UUID,
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/activities/mem_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable
from uuid import UUID

from beartype import beartype
from temporalio import activity

from ..autogen.openapi_model import InputChatMLMessage
Expand Down Expand Up @@ -155,6 +156,7 @@ async def run_prompt(


@activity.defn
@beartype
async def mem_mgmt(
dialog: list[InputChatMLMessage],
session_id: UUID,
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/activities/mem_rating.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from textwrap import dedent
from typing import Callable

from beartype import beartype
from temporalio import activity

from ..clients import litellm
Expand Down Expand Up @@ -67,6 +68,7 @@ async def run_prompt(


@activity.defn
@beartype
async def mem_rating(memory: str) -> None:
# session_id = UUID(session_id)
# entries = [
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3


from beartype import beartype
import pandas as pd
from temporalio import activity

Expand All @@ -20,6 +21,7 @@ def get_toplevel_entries_query(*args, **kwargs):


@activity.defn
@beartype
async def summarization(session_id: str) -> None:
raise NotImplementedError()
# session_id = UUID(session_id)
Expand Down
154 changes: 10 additions & 144 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,144 +1,10 @@
import asyncio

from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
CreateTransitionRequest,
EvaluateStep,
IfElseWorkflowStep,
InputChatMLMessage,
PromptStep,
ToolCallStep,
YieldStep,
)
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.utils.template import render_template
from ...models.execution.create_execution_transition import (
create_execution_transition as create_execution_transition_query,
)


@activity.defn
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
# Get context data
context_data: dict = context.model_dump()

# Render template messages
prompt = (
[InputChatMLMessage(content=context.definition.prompt)]
if isinstance(context.definition.prompt, str)
else context.definition.prompt
)

template_messages: list[InputChatMLMessage] = prompt
messages = await asyncio.gather(
*[
render_template(msg.content, context_data, skip_vars=["developer_id"])
for msg in template_messages
]
)

messages = [
InputChatMLMessage(role="user", content=m)
if isinstance(m, str)
else InputChatMLMessage(**m)
for m in messages
]

settings: dict = context.definition.settings.model_dump()
# Get settings and run llm
response = await litellm.acompletion(
messages=messages,
**settings,
)

return StepOutcome(
output=response.model_dump(),
next=None,
)


@activity.defn
async def evaluate_step(context: StepContext) -> dict:
assert isinstance(context.definition, EvaluateStep)

names = {}
for i in context.inputs:
names.update(i)

return {
"result": {
k: simple_eval(v, names=names)
for k, v in context.definition.evaluate.items()
}
}


@activity.defn
async def yield_step(context: StepContext) -> dict:
if not isinstance(context.definition, YieldStep):
return {}

# TODO: implement

return {"test": "result"}


@activity.defn
async def tool_call_step(context: StepContext) -> dict:
assert isinstance(context.definition, ToolCallStep)

context.definition.tool_id
context.definition.arguments
# get tool by id
# call tool

return {}


@activity.defn
async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict:
context_data: dict = context.model_dump()

next_workflow = (
context.definition.then
if simple_eval(context.definition.if_, names=context_data)
else context.definition.else_
)

return {"goto_workflow": next_workflow}


@activity.defn
async def transition_step(
context: StepContext[None],
transition_info: CreateTransitionRequest,
):
need_to_wait = transition_info.type == "wait"

# Get task token if it's a waiting step
if need_to_wait:
task_token = activity.info().task_token
transition_info.task_token = task_token

# Create transition
activity.heartbeat("Creating transition in db")
create_execution_transition_query(
developer_id=context.developer_id,
execution_id=context.execution.id,
task_id=context.task.id,
data=transition_info,
update_execution_status=True,
)

# Raise if it's a waiting step
if need_to_wait:
activity.heartbeat("Starting to wait")
activity.raise_complete_async()
# ruff: noqa: F401, F403, F405

from .evaluate_step import evaluate_step
from .if_else_step import if_else_step
from .prompt_step import prompt_step
from .raise_complete_async import raise_complete_async
from .tool_call_step import tool_call_step
from .transition_step import transition_step
from .wait_for_input_step import wait_for_input_step
from .yield_step import yield_step
22 changes: 22 additions & 0 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import EvaluateStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)


@activity.defn
@beartype
async def evaluate_step(
context: StepContext[EvaluateStep],
) -> StepOutcome[dict[str, Any]]:
exprs = context.definition.arguments
output = simple_eval_dict(exprs, values=context.model_dump())

return StepOutcome(output=output)
25 changes: 25 additions & 0 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
IfElseWorkflowStep,
)
from ...common.protocol.tasks import (
StepContext,
)


@activity.defn
@beartype
async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict:
raise NotImplementedError()
# context_data: dict = context.model_dump()

# next_workflow = (
# context.definition.then
# if simple_eval(context.definition.if_, names=context_data)
# else context.definition.else_
# )

# return {"goto_workflow": next_workflow}
60 changes: 60 additions & 0 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import asyncio

from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
InputChatMLMessage,
PromptStep,
)
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...common.utils.template import render_template


@activity.defn
@beartype
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome:
# Get context data
context_data: dict = context.model_dump()

# Render template messages
prompt = (
[InputChatMLMessage(content=context.definition.prompt)]
if isinstance(context.definition.prompt, str)
else context.definition.prompt
)

template_messages: list[InputChatMLMessage] = prompt
messages = await asyncio.gather(
*[
render_template(msg.content, context_data, skip_vars=["developer_id"])
for msg in template_messages
]
)

messages = [
(
InputChatMLMessage(role="user", content=m)
if isinstance(m, str)
else InputChatMLMessage(**m)
)
for m in messages
]

settings: dict = context.definition.settings.model_dump()
# Get settings and run llm
response = await litellm.acompletion(
messages=messages,
**settings,
)

return StepOutcome(
output=response.model_dump(),
next=None,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from temporalio import activity


@activity.defn
async def raise_complete_async() -> None:

activity.heartbeat("Starting to wait")
activity.raise_complete_async()
24 changes: 24 additions & 0 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import (
ToolCallStep,
)
from ...common.protocol.tasks import (
StepContext,
)



@activity.defn
@beartype
async def tool_call_step(context: StepContext) -> dict:
raise NotImplementedError()
# assert isinstance(context.definition, ToolCallStep)

# context.definition.tool_id
# context.definition.arguments
# # get tool by id
# # call tool

# return {}
Loading

0 comments on commit 083b089

Please sign in to comment.