From 083b0892f2844c4d4f52e839ae81d909f510c292 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Fri, 16 Aug 2024 01:54:42 -0400 Subject: [PATCH] wip Signed-off-by: Diwank Tomer --- .../agents_api/activities/embed_docs.py | 2 + agents-api/agents_api/activities/mem_mgmt.py | 2 + .../agents_api/activities/mem_rating.py | 2 + .../agents_api/activities/summarization.py | 2 + .../activities/task_steps/__init__.py | 154 +----------- .../activities/task_steps/evaluate_step.py | 22 ++ .../activities/task_steps/if_else_step.py | 25 ++ .../activities/task_steps/prompt_step.py | 60 +++++ .../task_steps/raise_complete_async.py | 8 + .../activities/task_steps/tool_call_step.py | 24 ++ .../activities/task_steps/transition_step.py | 36 +++ .../agents_api/activities/task_steps/utils.py | 11 + .../activities/task_steps/yield_step.py | 29 +++ .../agents_api/activities/truncation.py | 2 + agents-api/agents_api/clients/temporal.py | 13 +- .../agents_api/common/protocol/tasks.py | 26 +- .../agents_api/workflows/task_execution.py | 229 +++++++++++------- 17 files changed, 405 insertions(+), 242 deletions(-) create mode 100644 agents-api/agents_api/activities/task_steps/evaluate_step.py create mode 100644 agents-api/agents_api/activities/task_steps/if_else_step.py create mode 100644 agents-api/agents_api/activities/task_steps/prompt_step.py create mode 100644 agents-api/agents_api/activities/task_steps/raise_complete_async.py create mode 100644 agents-api/agents_api/activities/task_steps/tool_call_step.py create mode 100644 agents-api/agents_api/activities/task_steps/transition_step.py create mode 100644 agents-api/agents_api/activities/task_steps/utils.py create mode 100644 agents-api/agents_api/activities/task_steps/yield_step.py diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index 1000f456d..7198a7e54 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -1,5 +1,6 @@ from uuid import UUID +from beartype import beartype from temporalio import activity from ..clients import embed as embedder @@ -10,6 +11,7 @@ @activity.defn +@beartype async def embed_docs( developer_id: UUID, doc_id: UUID, diff --git a/agents-api/agents_api/activities/mem_mgmt.py b/agents-api/agents_api/activities/mem_mgmt.py index ea4bb84d2..7cd4a7d6b 100644 --- a/agents-api/agents_api/activities/mem_mgmt.py +++ b/agents-api/agents_api/activities/mem_mgmt.py @@ -2,6 +2,7 @@ from typing import Callable from uuid import UUID +from beartype import beartype from temporalio import activity from ..autogen.openapi_model import InputChatMLMessage @@ -155,6 +156,7 @@ async def run_prompt( @activity.defn +@beartype async def mem_mgmt( dialog: list[InputChatMLMessage], session_id: UUID, diff --git a/agents-api/agents_api/activities/mem_rating.py b/agents-api/agents_api/activities/mem_rating.py index 222148f4c..c681acbc3 100644 --- a/agents-api/agents_api/activities/mem_rating.py +++ b/agents-api/agents_api/activities/mem_rating.py @@ -1,6 +1,7 @@ from textwrap import dedent from typing import Callable +from beartype import beartype from temporalio import activity from ..clients import litellm @@ -67,6 +68,7 @@ async def run_prompt( @activity.defn +@beartype async def mem_rating(memory: str) -> None: # session_id = UUID(session_id) # entries = [ diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py index 8a45927ee..662554181 100644 --- a/agents-api/agents_api/activities/summarization.py +++ b/agents-api/agents_api/activities/summarization.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 +from beartype import beartype import pandas as pd from temporalio import activity @@ -20,6 +21,7 @@ def get_toplevel_entries_query(*args, **kwargs): @activity.defn +@beartype async def summarization(session_id: str) -> None: raise NotImplementedError() # session_id = UUID(session_id) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index f4cef8d3b..28932c02e 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -1,144 +1,10 @@ -import asyncio - -from simpleeval import simple_eval -from temporalio import activity - -from ...autogen.openapi_model import ( - CreateTransitionRequest, - EvaluateStep, - IfElseWorkflowStep, - InputChatMLMessage, - PromptStep, - ToolCallStep, - YieldStep, -) -from ...clients import ( - litellm, # We dont directly import `acompletion` so we can mock it -) -from ...common.protocol.tasks import ( - StepContext, - StepOutcome, -) -from ...common.utils.template import render_template -from ...models.execution.create_execution_transition import ( - create_execution_transition as create_execution_transition_query, -) - - -@activity.defn -async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome: - # Get context data - context_data: dict = context.model_dump() - - # Render template messages - prompt = ( - [InputChatMLMessage(content=context.definition.prompt)] - if isinstance(context.definition.prompt, str) - else context.definition.prompt - ) - - template_messages: list[InputChatMLMessage] = prompt - messages = await asyncio.gather( - *[ - render_template(msg.content, context_data, skip_vars=["developer_id"]) - for msg in template_messages - ] - ) - - messages = [ - InputChatMLMessage(role="user", content=m) - if isinstance(m, str) - else InputChatMLMessage(**m) - for m in messages - ] - - settings: dict = context.definition.settings.model_dump() - # Get settings and run llm - response = await litellm.acompletion( - messages=messages, - **settings, - ) - - return StepOutcome( - output=response.model_dump(), - next=None, - ) - - -@activity.defn -async def evaluate_step(context: StepContext) -> dict: - assert isinstance(context.definition, EvaluateStep) - - names = {} - for i in context.inputs: - names.update(i) - - return { - "result": { - k: simple_eval(v, names=names) - for k, v in context.definition.evaluate.items() - } - } - - -@activity.defn -async def yield_step(context: StepContext) -> dict: - if not isinstance(context.definition, YieldStep): - return {} - - # TODO: implement - - return {"test": "result"} - - -@activity.defn -async def tool_call_step(context: StepContext) -> dict: - assert isinstance(context.definition, ToolCallStep) - - context.definition.tool_id - context.definition.arguments - # get tool by id - # call tool - - return {} - - -@activity.defn -async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict: - context_data: dict = context.model_dump() - - next_workflow = ( - context.definition.then - if simple_eval(context.definition.if_, names=context_data) - else context.definition.else_ - ) - - return {"goto_workflow": next_workflow} - - -@activity.defn -async def transition_step( - context: StepContext[None], - transition_info: CreateTransitionRequest, -): - need_to_wait = transition_info.type == "wait" - - # Get task token if it's a waiting step - if need_to_wait: - task_token = activity.info().task_token - transition_info.task_token = task_token - - # Create transition - activity.heartbeat("Creating transition in db") - create_execution_transition_query( - developer_id=context.developer_id, - execution_id=context.execution.id, - task_id=context.task.id, - data=transition_info, - update_execution_status=True, - ) - - # Raise if it's a waiting step - if need_to_wait: - activity.heartbeat("Starting to wait") - activity.raise_complete_async() +# ruff: noqa: F401, F403, F405 + +from .evaluate_step import evaluate_step +from .if_else_step import if_else_step +from .prompt_step import prompt_step +from .raise_complete_async import raise_complete_async +from .tool_call_step import tool_call_step +from .transition_step import transition_step +from .wait_for_input_step import wait_for_input_step +from .yield_step import yield_step \ No newline at end of file diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py new file mode 100644 index 000000000..6f6630f4a --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py @@ -0,0 +1,22 @@ +from typing import Any + +from beartype import beartype +from temporalio import activity + +from ...activities.task_steps.utils import simple_eval_dict +from ...autogen.openapi_model import EvaluateStep +from ...common.protocol.tasks import ( + StepContext, + StepOutcome, +) + + +@activity.defn +@beartype +async def evaluate_step( + context: StepContext[EvaluateStep], +) -> StepOutcome[dict[str, Any]]: + exprs = context.definition.arguments + output = simple_eval_dict(exprs, values=context.model_dump()) + + return StepOutcome(output=output) diff --git a/agents-api/agents_api/activities/task_steps/if_else_step.py b/agents-api/agents_api/activities/task_steps/if_else_step.py new file mode 100644 index 000000000..e179b05a6 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/if_else_step.py @@ -0,0 +1,25 @@ +from beartype import beartype +from simpleeval import simple_eval +from temporalio import activity + +from ...autogen.openapi_model import ( + IfElseWorkflowStep, +) +from ...common.protocol.tasks import ( + StepContext, +) + + +@activity.defn +@beartype +async def if_else_step(context: StepContext[IfElseWorkflowStep]) -> dict: + raise NotImplementedError() + # context_data: dict = context.model_dump() + + # next_workflow = ( + # context.definition.then + # if simple_eval(context.definition.if_, names=context_data) + # else context.definition.else_ + # ) + + # return {"goto_workflow": next_workflow} \ No newline at end of file diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py new file mode 100644 index 000000000..90ab975d3 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -0,0 +1,60 @@ +import asyncio + +from beartype import beartype +from temporalio import activity + +from ...autogen.openapi_model import ( + InputChatMLMessage, + PromptStep, +) +from ...clients import ( + litellm, # We dont directly import `acompletion` so we can mock it +) +from ...common.protocol.tasks import ( + StepContext, + StepOutcome, +) +from ...common.utils.template import render_template + + +@activity.defn +@beartype +async def prompt_step(context: StepContext[PromptStep]) -> StepOutcome: + # Get context data + context_data: dict = context.model_dump() + + # Render template messages + prompt = ( + [InputChatMLMessage(content=context.definition.prompt)] + if isinstance(context.definition.prompt, str) + else context.definition.prompt + ) + + template_messages: list[InputChatMLMessage] = prompt + messages = await asyncio.gather( + *[ + render_template(msg.content, context_data, skip_vars=["developer_id"]) + for msg in template_messages + ] + ) + + messages = [ + ( + InputChatMLMessage(role="user", content=m) + if isinstance(m, str) + else InputChatMLMessage(**m) + ) + for m in messages + ] + + settings: dict = context.definition.settings.model_dump() + # Get settings and run llm + response = await litellm.acompletion( + messages=messages, + **settings, + ) + + return StepOutcome( + output=response.model_dump(), + next=None, + ) diff --git a/agents-api/agents_api/activities/task_steps/raise_complete_async.py b/agents-api/agents_api/activities/task_steps/raise_complete_async.py new file mode 100644 index 000000000..ca1200a87 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/raise_complete_async.py @@ -0,0 +1,8 @@ +from temporalio import activity + + +@activity.defn +async def raise_complete_async() -> None: + + activity.heartbeat("Starting to wait") + activity.raise_complete_async() diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py new file mode 100644 index 000000000..0b12cad97 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -0,0 +1,24 @@ +from beartype import beartype +from temporalio import activity + +from ...autogen.openapi_model import ( + ToolCallStep, +) +from ...common.protocol.tasks import ( + StepContext, +) + + + +@activity.defn +@beartype +async def tool_call_step(context: StepContext) -> dict: + raise NotImplementedError() + # assert isinstance(context.definition, ToolCallStep) + + # context.definition.tool_id + # context.definition.arguments + # # get tool by id + # # call tool + + # return {} \ No newline at end of file diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py new file mode 100644 index 000000000..b70428df8 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -0,0 +1,36 @@ +from beartype import beartype +from temporalio import activity + +from ...autogen.openapi_model import ( + CreateTransitionRequest, +) +from ...common.protocol.tasks import ( + StepContext, +) +from ...models.execution.create_execution_transition import ( + create_execution_transition as create_execution_transition_query, +) + + +@activity.defn +@beartype +async def transition_step( + context: StepContext[None], + transition_info: CreateTransitionRequest, +) -> None: + need_to_wait = transition_info.type == "wait" + + # Get task token if it's a waiting step + if need_to_wait: + task_token = activity.info().task_token + transition_info.task_token = task_token + + # Create transition + activity.heartbeat("Creating transition in db") + create_execution_transition_query( + developer_id=context.developer_id, + execution_id=context.execution.id, + task_id=context.task.id, + data=transition_info, + update_execution_status=True, + ) \ No newline at end of file diff --git a/agents-api/agents_api/activities/task_steps/utils.py b/agents-api/agents_api/activities/task_steps/utils.py new file mode 100644 index 000000000..e3d953a4a --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/utils.py @@ -0,0 +1,11 @@ +from typing import Any + +from beartype import beartype +from simpleeval import simple_eval + + +@beartype +def simple_eval_dict( + exprs: dict[str, str], *, values: dict[str, Any] +) -> dict[str, Any]: + return {k: simple_eval(v, names=values) for k, v in exprs.items()} diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py new file mode 100644 index 000000000..0710c99e6 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -0,0 +1,29 @@ +from typing import Any +from agents_api.autogen.Executions import TransitionTarget +from beartype import beartype +from temporalio import activity + +from ...autogen.openapi_model import ( + YieldStep, +) +from ...common.protocol.tasks import ( + StepContext, + StepOutcome, +) + +from .utils import simple_eval_dict + +@activity.defn +@beartype +async def yield_step(context: StepContext[YieldStep]) -> StepOutcome[dict[str, Any]]: + workflow = context.definition.workflow + exprs = context.definition.arguments + arguments = simple_eval_dict(exprs, values=context.model_dump()) + + transition_target = TransitionTarget( + workflow=workflow, + step=0, + ) + + return StepOutcome(output=arguments, transition_to=("step", transition_target)) + diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py index 7f381ac0f..d0ce919f0 100644 --- a/agents-api/agents_api/activities/truncation.py +++ b/agents-api/agents_api/activities/truncation.py @@ -1,5 +1,6 @@ from uuid import UUID +from beartype import beartype from temporalio import activity from agents_api.autogen.openapi_model import Entry @@ -26,6 +27,7 @@ def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list @activity.defn +@beartype async def truncation(session_id: str, token_count_threshold: int) -> None: session_id = UUID(session_id) diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index 29ceedded..2130c3947 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -2,14 +2,14 @@ from temporalio.client import Client, TLSConfig -from agents_api.env import ( +from ..autogen.openapi_model import TransitionTarget +from ..common.protocol.tasks import ExecutionInput +from ..env import ( temporal_client_cert, temporal_namespace, temporal_private_key, temporal_worker_url, ) - -from ..common.protocol.tasks import ExecutionInput from ..worker.codec import pydantic_data_converter @@ -35,16 +35,19 @@ async def get_client( async def run_task_execution_workflow( + *, execution_input: ExecutionInput, job_id: UUID, - start: tuple[str, int] = ("main", 0), + start: TransitionTarget = TransitionTarget(workflow="main", step=0), previous_inputs: list[dict] = [], client: Client | None = None, ): + from ..workflows.task_execution import TaskExecutionWorkflow + client = client or (await get_client()) return await client.start_workflow( - "TaskExecutionWorkflow", + TaskExecutionWorkflow.run, args=[execution_input, start, previous_inputs], task_queue="memory-task-queue", id=str(job_id), diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 88a5dd3fd..5e273d657 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -1,7 +1,7 @@ from typing import Any, Generic, TypeVar from uuid import UUID -from pydantic import BaseModel +from pydantic import BaseModel, computed_field from ...autogen.openapi_model import ( Agent, @@ -16,6 +16,7 @@ TaskToolDef, Tool, TransitionTarget, + TransitionType, UpdateTaskRequest, User, Workflow, @@ -73,18 +74,29 @@ class StepContext(ExecutionInput, Generic[WorkflowStepType]): inputs: list[dict[str, Any]] current: TransitionTarget + @computed_field + @property + def outputs(self) -> list[dict[str, Any]]: + return self.inputs[1:] + + @computed_field + @property + def current_input(self) -> dict[str, Any]: + return self.inputs[-1] + def model_dump(self, *args, **kwargs) -> dict[str, Any]: dump = super().model_dump(*args, **kwargs) - - dump["_"] = self.inputs[-1] - dump["outputs"] = self.inputs[1:] + dump["_"] = self.current_input return dump -class StepOutcome(BaseModel): - output: dict[str, Any] - next: TransitionTarget | None = None +OutcomeType = TypeVar("OutcomeType", bound=BaseModel) + + +class StepOutcome(BaseModel, Generic[OutcomeType]): + output: OutcomeType | None + transition_to: tuple[TransitionType, TransitionTarget] | None = None def task_to_spec( diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index 495b62f62..be639697f 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -3,6 +3,8 @@ from datetime import timedelta +from agents_api.autogen.Executions import TransitionTarget +from agents_api.autogen.openapi_model import CreateTransitionRequest from temporalio import workflow with workflow.unsafe.imports_passed_through(): @@ -12,7 +14,9 @@ prompt_step, tool_call_step, transition_step, + yield_step, ) + from ..autogen.openapi_model import ( ErrorWorkflowStep, EvaluateStep, @@ -22,28 +26,40 @@ WaitForInputStep, YieldStep, ) + from ..common.protocol.tasks import ( ExecutionInput, StepContext, + StepOutcome, ) +STEP_TO_ACTIVITY = { + PromptStep: prompt_step, + EvaluateStep: evaluate_step, + ToolCallStep: tool_call_step, + IfElseWorkflowStep: if_else_step, + YieldStep: yield_step, +} + + @workflow.defn class TaskExecutionWorkflow: @workflow.run async def run( self, execution_input: ExecutionInput, - current: tuple[str, int] = ("main", 0), + start: TransitionTarget = TransitionTarget(workflow="main", step=0), previous_inputs: list[dict] = [], ) -> None: - wf_name, step_idx = current - workflow_map = {wf.name: wf.steps for wf in execution_input.task.workflows} - current_workflow = workflow_map[wf_name] previous_inputs = previous_inputs or [execution_input.arguments] - step = current_workflow[step_idx] + workflow_map = {wf.name: wf.steps for wf in execution_input.task.workflows} + + current_workflow = workflow_map[start.workflow] + step = current_workflow[start.step] + step_type = type(step) - context = StepContext( + context = StepContext[step_type]( developer_id=execution_input.developer_id, execution=execution_input.execution, task=execution_input.task, @@ -56,91 +72,132 @@ async def run( inputs=previous_inputs, ) - should_wait, is_error = False, False - # Run the step - match step: - case PromptStep(): - outputs = await workflow.execute_activity( - prompt_step, - context, - schedule_to_close_timeout=timedelta(seconds=600), - ) - - # TODO: ChatCompletion does not have tool_calls - # if outputs.tool_calls is not None: - # should_wait = True + next = None + outcome = None + activity = STEP_TO_ACTIVITY.get(step_type) + final_output = None + is_last = False - case EvaluateStep(): - outputs = await workflow.execute_activity( - evaluate_step, - context, - schedule_to_close_timeout=timedelta(seconds=600), - ) - case YieldStep(): - outputs = await workflow.execute_child_workflow( - TaskExecutionWorkflow.run, - args=[execution_input, (step.workflow, 0), previous_inputs], - ) - case ToolCallStep(): - outputs = await workflow.execute_activity( - tool_call_step, - context, - schedule_to_close_timeout=timedelta(seconds=600), + if activity: + outcome = await workflow.execute_activity( + activity, + context, + schedule_to_close_timeout=timedelta(seconds=600), + ) + + match step, outcome: + case YieldStep(), StepOutcome(output=output, transition_to=(transition_type, next)): + transition_request = CreateTransitionRequest( + type=transition_type, + current=start, + next=next, + output=output, ) - case ErrorWorkflowStep(): - is_error = True - case IfElseWorkflowStep(): - outputs = await workflow.execute_activity( - if_else_step, - context, + + await workflow.execute_activity( + transition_step, + args=[context, transition_request], schedule_to_close_timeout=timedelta(seconds=600), ) - workflow_step = YieldStep(**outputs["goto_workflow"]) - outputs = await workflow.execute_child_workflow( + yield_outcome: StepOutcome = await workflow.execute_child_workflow( TaskExecutionWorkflow.run, - args=[ - execution_input, - (workflow_step.workflow, 0), - previous_inputs, - ], + args=[execution_input, next, [output]], ) - case WaitForInputStep(): - should_wait = True - - is_last = step_idx + 1 == len(current_workflow) - # Transition type - transition_type = ( - "awaiting_input" - if should_wait - else ("finish" if is_last else ("error" if is_error else "step")) - ) - - # Transition to the next step - transition_info = TransitionInfo( - from_=(wf_name, step_idx), - to=None if (is_last or should_wait) else (wf_name, step_idx + 1), - type=transition_type, - ) - await workflow.execute_activity( - transition_step, - args=[ - context, - transition_info, - "failed" if is_error else "awaiting_input", - ], - schedule_to_close_timeout=timedelta(seconds=600), - ) - - # FIXME: this is just a demo, we should handle the end of the workflow properly - # ----- - - # End if the last step - if is_last: - return outputs - - # Otherwise, recurse to the next step - workflow.continue_as_new( - execution_input, (wf_name, step_idx + 1), previous_inputs + [outputs] - ) + final_output = yield_outcome.output + + case _: + raise NotImplementedError() + + is_last = start.step + 1 == len(current_workflow) + + ################## + + # should_wait, is_error = False, False + # # Run the step + # match step: + # case PromptStep(): + # outputs = await workflow.execute_activity( + # prompt_step, + # context, + # schedule_to_close_timeout=timedelta(seconds=600), + # ) + + # # TODO: ChatCompletion does not have tool_calls + # # if outputs.tool_calls is not None: + # # should_wait = True + + # case EvaluateStep(): + # outputs = await workflow.execute_activity( + # evaluate_step, + # context, + # schedule_to_close_timeout=timedelta(seconds=600), + # ) + # case YieldStep(): + # outputs = await workflow.execute_child_workflow( + # TaskExecutionWorkflow.run, + # args=[execution_input, (step.workflow, 0), previous_inputs], + # ) + # case ToolCallStep(): + # outputs = await workflow.execute_activity( + # tool_call_step, + # context, + # schedule_to_close_timeout=timedelta(seconds=600), + # ) + # case ErrorWorkflowStep(): + # is_error = True + # case IfElseWorkflowStep(): + # outputs = await workflow.execute_activity( + # if_else_step, + # context, + # schedule_to_close_timeout=timedelta(seconds=600), + # ) + # workflow_step = YieldStep(**outputs["goto_workflow"]) + + # outputs = await workflow.execute_child_workflow( + # TaskExecutionWorkflow.run, + # args=[ + # execution_input, + # (workflow_step.workflow, 0), + # previous_inputs, + # ], + # ) + # case WaitForInputStep(): + # should_wait = True + + # # Transition type + # transition_type = ( + # "awaiting_input" + # if should_wait + # else ("finish" if is_last else ("error" if is_error else "step")) + # ) + + # # Transition to the next step + # transition_info = TransitionInfo( + # from_=(wf_name, step_idx), + # to=None if (is_last or should_wait) else (wf_name, step_idx + 1), + # type=transition_type, + # ) + + # await workflow.execute_activity( + # transition_step, + # args=[ + # context, + # transition_info, + # "failed" if is_error else "awaiting_input", + # ], + # schedule_to_close_timeout=timedelta(seconds=600), + # ) + + # # FIXME: this is just a demo, we should handle the end of the workflow properly + # # ----- + + # # End if the last step + # if is_last: + # return outputs + + # # Otherwise, recurse to the next step + # workflow.continue_as_new( + # execution_input, (wf_name, step_idx + 1), previous_inputs + [outputs] + # )