-
Notifications
You must be signed in to change notification settings - Fork 928
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Diwank Tomer <diwank@julep.ai>
- Loading branch information
Diwank Tomer
committed
Aug 16, 2024
1 parent
b50be03
commit 083b089
Showing
17 changed files
with
405 additions
and
242 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 10 additions & 144 deletions
154
agents-api/agents_api/activities/task_steps/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,144 +1,10 @@ | ||
import asyncio | ||
|
||
from simpleeval import simple_eval | ||
from temporalio import activity | ||
|
||
from ...autogen.openapi_model import ( | ||
CreateTransitionRequest, | ||
EvaluateStep, | ||
IfElseWorkflowStep, | ||
InputChatMLMessage, | ||
PromptStep, | ||
ToolCallStep, | ||
YieldStep, | ||
) | ||
from ...clients import ( | ||
litellm, # We dont directly import `acompletion` so we can mock it | ||
) | ||
from ...common.protocol.tasks import ( | ||
StepContext, | ||
StepOutcome, | ||
) | ||
from ...common.utils.template import render_template | ||
from ...models.execution.create_execution_transition import ( | ||
create_execution_transition as create_execution_transition_query, | ||
) | ||
|
||
|
||
@activity.defn | ||
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome: | ||
# Get context data | ||
context_data: dict = context.model_dump() | ||
|
||
# Render template messages | ||
prompt = ( | ||
[InputChatMLMessage(content=context.definition.prompt)] | ||
if isinstance(context.definition.prompt, str) | ||
else context.definition.prompt | ||
) | ||
|
||
template_messages: list[InputChatMLMessage] = prompt | ||
messages = await asyncio.gather( | ||
*[ | ||
render_template(msg.content, context_data, skip_vars=["developer_id"]) | ||
for msg in template_messages | ||
] | ||
) | ||
|
||
messages = [ | ||
InputChatMLMessage(role="user", content=m) | ||
if isinstance(m, str) | ||
else InputChatMLMessage(**m) | ||
for m in messages | ||
] | ||
|
||
settings: dict = context.definition.settings.model_dump() | ||
# Get settings and run llm | ||
response = await litellm.acompletion( | ||
messages=messages, | ||
**settings, | ||
) | ||
|
||
return StepOutcome( | ||
output=response.model_dump(), | ||
next=None, | ||
) | ||
|
||
|
||
@activity.defn | ||
async def evaluate_step(context: StepContext) -> dict: | ||
assert isinstance(context.definition, EvaluateStep) | ||
|
||
names = {} | ||
for i in context.inputs: | ||
names.update(i) | ||
|
||
return { | ||
"result": { | ||
k: simple_eval(v, names=names) | ||
for k, v in context.definition.evaluate.items() | ||
} | ||
} | ||
|
||
|
||
@activity.defn | ||
async def yield_step(context: StepContext) -> dict: | ||
if not isinstance(context.definition, YieldStep): | ||
return {} | ||
|
||
# TODO: implement | ||
|
||
return {"test": "result"} | ||
|
||
|
||
@activity.defn | ||
async def tool_call_step(context: StepContext) -> dict: | ||
assert isinstance(context.definition, ToolCallStep) | ||
|
||
context.definition.tool_id | ||
context.definition.arguments | ||
# get tool by id | ||
# call tool | ||
|
||
return {} | ||
|
||
|
||
@activity.defn | ||
async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict: | ||
context_data: dict = context.model_dump() | ||
|
||
next_workflow = ( | ||
context.definition.then | ||
if simple_eval(context.definition.if_, names=context_data) | ||
else context.definition.else_ | ||
) | ||
|
||
return {"goto_workflow": next_workflow} | ||
|
||
|
||
@activity.defn | ||
async def transition_step( | ||
context: StepContext[None], | ||
transition_info: CreateTransitionRequest, | ||
): | ||
need_to_wait = transition_info.type == "wait" | ||
|
||
# Get task token if it's a waiting step | ||
if need_to_wait: | ||
task_token = activity.info().task_token | ||
transition_info.task_token = task_token | ||
|
||
# Create transition | ||
activity.heartbeat("Creating transition in db") | ||
create_execution_transition_query( | ||
developer_id=context.developer_id, | ||
execution_id=context.execution.id, | ||
task_id=context.task.id, | ||
data=transition_info, | ||
update_execution_status=True, | ||
) | ||
|
||
# Raise if it's a waiting step | ||
if need_to_wait: | ||
activity.heartbeat("Starting to wait") | ||
activity.raise_complete_async() | ||
# ruff: noqa: F401, F403, F405 | ||
|
||
from .evaluate_step import evaluate_step | ||
from .if_else_step import if_else_step | ||
from .prompt_step import prompt_step | ||
from .raise_complete_async import raise_complete_async | ||
from .tool_call_step import tool_call_step | ||
from .transition_step import transition_step | ||
from .wait_for_input_step import wait_for_input_step | ||
from .yield_step import yield_step |
22 changes: 22 additions & 0 deletions
22
agents-api/agents_api/activities/task_steps/evaluate_step.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Any | ||
|
||
from beartype import beartype | ||
from temporalio import activity | ||
|
||
from ...activities.task_steps.utils import simple_eval_dict | ||
from ...autogen.openapi_model import EvaluateStep | ||
from ...common.protocol.tasks import ( | ||
StepContext, | ||
StepOutcome, | ||
) | ||
|
||
|
||
@activity.defn | ||
@beartype | ||
async def evaluate_step( | ||
context: StepContext[EvaluateStep], | ||
) -> StepOutcome[dict[str, Any]]: | ||
exprs = context.definition.arguments | ||
output = simple_eval_dict(exprs, values=context.model_dump()) | ||
|
||
return StepOutcome(output=output) |
25 changes: 25 additions & 0 deletions
25
agents-api/agents_api/activities/task_steps/if_else_step.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from beartype import beartype | ||
from simpleeval import simple_eval | ||
from temporalio import activity | ||
|
||
from ...autogen.openapi_model import ( | ||
IfElseWorkflowStep, | ||
) | ||
from ...common.protocol.tasks import ( | ||
StepContext, | ||
) | ||
|
||
|
||
@activity.defn | ||
@beartype | ||
async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict: | ||
raise NotImplementedError() | ||
# context_data: dict = context.model_dump() | ||
|
||
# next_workflow = ( | ||
# context.definition.then | ||
# if simple_eval(context.definition.if_, names=context_data) | ||
# else context.definition.else_ | ||
# ) | ||
|
||
# return {"goto_workflow": next_workflow} |
60 changes: 60 additions & 0 deletions
60
agents-api/agents_api/activities/task_steps/prompt_step.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import asyncio | ||
|
||
from beartype import beartype | ||
from temporalio import activity | ||
|
||
from ...autogen.openapi_model import ( | ||
InputChatMLMessage, | ||
PromptStep, | ||
) | ||
from ...clients import ( | ||
litellm, # We dont directly import `acompletion` so we can mock it | ||
) | ||
from ...common.protocol.tasks import ( | ||
StepContext, | ||
StepOutcome, | ||
) | ||
from ...common.utils.template import render_template | ||
|
||
|
||
@activity.defn | ||
@beartype | ||
async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome: | ||
# Get context data | ||
context_data: dict = context.model_dump() | ||
|
||
# Render template messages | ||
prompt = ( | ||
[InputChatMLMessage(content=context.definition.prompt)] | ||
if isinstance(context.definition.prompt, str) | ||
else context.definition.prompt | ||
) | ||
|
||
template_messages: list[InputChatMLMessage] = prompt | ||
messages = await asyncio.gather( | ||
*[ | ||
render_template(msg.content, context_data, skip_vars=["developer_id"]) | ||
for msg in template_messages | ||
] | ||
) | ||
|
||
messages = [ | ||
( | ||
InputChatMLMessage(role="user", content=m) | ||
if isinstance(m, str) | ||
else InputChatMLMessage(**m) | ||
) | ||
for m in messages | ||
] | ||
|
||
settings: dict = context.definition.settings.model_dump() | ||
# Get settings and run llm | ||
response = await litellm.acompletion( | ||
messages=messages, | ||
**settings, | ||
) | ||
|
||
return StepOutcome( | ||
output=response.model_dump(), | ||
next=None, | ||
) |
8 changes: 8 additions & 0 deletions
8
agents-api/agents_api/activities/task_steps/raise_complete_async.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from temporalio import activity | ||
|
||
|
||
@activity.defn | ||
async def raise_complete_async() -> None: | ||
|
||
activity.heartbeat("Starting to wait") | ||
activity.raise_complete_async() |
24 changes: 24 additions & 0 deletions
24
agents-api/agents_api/activities/task_steps/tool_call_step.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from beartype import beartype | ||
from temporalio import activity | ||
|
||
from ...autogen.openapi_model import ( | ||
ToolCallStep, | ||
) | ||
from ...common.protocol.tasks import ( | ||
StepContext, | ||
) | ||
|
||
|
||
|
||
@activity.defn | ||
@beartype | ||
async def tool_call_step(context: StepContext) -> dict: | ||
raise NotImplementedError() | ||
# assert isinstance(context.definition, ToolCallStep) | ||
|
||
# context.definition.tool_id | ||
# context.definition.arguments | ||
# # get tool by id | ||
# # call tool | ||
|
||
# return {} |
Oops, something went wrong.