Skip to content

Commit

Permalink
feat(agents-api): Add parallelism option to map-reduce step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <diwank.singh@gmail.com>
  • Loading branch information
creatorrr committed Sep 4, 2024
1 parent 1860574 commit c96e0fb
Show file tree
Hide file tree
Showing 16 changed files with 243 additions and 70 deletions.
9 changes: 8 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
values: dict[str, Any] = {},
extra_lambda_strs: dict[str, str] | None = None,
) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"

extra_lambdas = {}
if extra_lambda_strs:
for k, v in extra_lambda_strs.items():
assert v.startswith("lambda "), "All extra lambdas must start with 'lambda'"
extra_lambdas[k] = eval(v)

# Turn the nested dict values from pydantic to dicts where possible
values = {
k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()
Expand All @@ -25,7 +32,7 @@ async def base_evaluate(
# frozen_box doesn't work coz we need some mutability in the values
values = Box(values, frozen_box=False, conversion_box=True)

evaluator = get_evaluator(names=values)
evaluator = get_evaluator(names=values, extra_functions=extra_lambdas)

try:
match exprs:
Expand Down
17 changes: 14 additions & 3 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
from typing import Any
from functools import reduce
from itertools import accumulate
from typing import Any, Callable

import re2
import yaml
Expand All @@ -10,6 +12,7 @@
# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"abs": abs,
"accumulate": accumulate,
"all": all,
"any": any,
"bool": bool,
Expand All @@ -22,9 +25,12 @@
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"map": map,
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"range": range,
"reduce": reduce,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"set": set,
Expand All @@ -36,8 +42,13 @@


@beartype
def get_evaluator(
    names: dict[str, Any], extra_functions: dict[str, Callable] | None = None
) -> SimpleEval:
    """Build a sandboxed expression evaluator for task expressions.

    Args:
        names: Variables exposed to evaluated expressions.
        extra_functions: Optional extra callables to expose alongside the
            allow-list. NOTE(review): entries whose keys collide with
            ``ALLOWED_FUNCTIONS`` override the allow-listed version, so
            callers must only pass trusted callables.

    Returns:
        A configured ``EvalWithCompoundTypes`` instance.
    """
    # PEP 584 dict union builds a fresh mapping; ALLOWED_FUNCTIONS itself
    # is never mutated by merging in the caller-supplied extras.
    functions = ALLOWED_FUNCTIONS | (extra_functions or {})
    return EvalWithCompoundTypes(names=names, functions=functions)


Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ class Main(BaseModel):
"""
The initial value of the reduce expression
"""
parallel: StrictBool = False
parallelism: int | None = None
"""
Whether to run the reduce expression in parallel
Optional batch size for running the reduce expression in parallel; `None` (or `1`) evaluates items sequentially
"""


Expand Down Expand Up @@ -391,9 +391,9 @@ class MainModel(BaseModel):
"""
The initial value of the reduce expression
"""
parallel: StrictBool = False
parallelism: int | None = None
"""
Whether to run the reduce expression in parallel
Optional batch size for running the reduce expression in parallel; `None` (or `1`) evaluates items sequentially
"""


Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
api_prefix: str = env.str("AGENTS_API_PREFIX", default="")


# Tasks
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)

# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
Expand Down
85 changes: 51 additions & 34 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@
StepOutcome,
)
from ...env import debug, testing
from .transition import transition
from .helpers import (
continue_as_child,
execute_switch_branch,
execute_if_else_branch,
execute_foreach_step,
execute_if_else_branch,
execute_map_reduce_step,
execute_map_reduce_step_parallel,
execute_switch_branch,
)
from .transition import transition

# Supported steps
# ---------------
Expand Down Expand Up @@ -247,12 +248,12 @@ async def run(

case SwitchStep(switch=switch), StepOutcome(output=index) if index >= 0:
result = await execute_switch_branch(
context,
execution_input,
switch,
index,
previous_inputs,
self.user_state,
context=context,
execution_input=execution_input,
switch=switch,
index=index,
previous_inputs=previous_inputs,
user_state=self.user_state,
)
state = PartialTransition(output=result)

Expand All @@ -264,40 +265,56 @@ async def run(
output=condition
):
result = await execute_if_else_branch(
context,
execution_input,
then_branch,
else_branch,
condition,
previous_inputs,
self.user_state,
context=context,
execution_input=execution_input,
then_branch=then_branch,
else_branch=else_branch,
condition=condition,
previous_inputs=previous_inputs,
user_state=self.user_state,
)

state = PartialTransition(output=result)

case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome(output=items):
result = await execute_foreach_step(
context,
execution_input,
do_step,
items,
previous_inputs,
self.user_state,
context=context,
execution_input=execution_input,
do_step=do_step,
items=items,
previous_inputs=previous_inputs,
user_state=self.user_state,
)
state = PartialTransition(output=result)

case MapReduceStep(
map=map_defn, reduce=reduce, initial=initial
), StepOutcome(output=items):
map=map_defn, reduce=reduce, initial=initial, parallelism=parallelism
), StepOutcome(output=items) if parallelism is None or parallelism == 1:
result = await execute_map_reduce_step(
context,
execution_input,
map_defn,
reduce,
initial,
items,
previous_inputs,
self.user_state,
context=context,
execution_input=execution_input,
map_defn=map_defn,
items=items,
reduce=reduce,
initial=initial,
previous_inputs=previous_inputs,
user_state=self.user_state,
)
state = PartialTransition(output=result)

case MapReduceStep(
map=map_defn, reduce=reduce, initial=initial, parallelism=parallelism
), StepOutcome(output=items):
result = await execute_map_reduce_step_parallel(
context=context,
execution_input=execution_input,
map_defn=map_defn,
items=items,
previous_inputs=previous_inputs,
user_state=self.user_state,
initial=initial,
reduce=reduce,
parallelism=parallelism,
)
state = PartialTransition(output=result)

Expand Down Expand Up @@ -351,7 +368,7 @@ async def run(
)

result = await continue_as_child(
execution_input=execution_input,
context,
start=yield_next_target,
previous_inputs=[output],
user_state=self.user_state,
Expand Down Expand Up @@ -459,7 +476,7 @@ async def run(

# Continue as a child workflow
return await continue_as_child(
execution_input=execution_input,
context,
start=final_state.next,
previous_inputs=previous_inputs + [final_state.output],
user_state=self.user_state,
Expand Down
Loading

0 comments on commit c96e0fb

Please sign in to comment.