Skip to content

Commit

Permalink
fix(agents-api): Minor fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <diwank@julep.ai>
  • Loading branch information
Diwank Tomer committed Aug 15, 2024
1 parent adb2d9f commit c6456bc
Show file tree
Hide file tree
Showing 20 changed files with 183 additions and 169 deletions.
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from temporalio import activity

from agents_api.clients import embed as embedder
from agents_api.clients.cozo import get_cozo_client
from agents_api.models.docs.embed_snippets import embed_snippets as embed_snippets_query
from ..clients import embed as embedder
from ..clients.cozo import get_cozo_client
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query

snippet_embed_instruction = "Encode this passage for retrieval: "

Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
testing: bool = env.bool("AGENTS_API_TESTING", default=False)
sentry_dsn: str = env.str("SENTRY_DSN", default=None)


Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/agent/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["agent_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/agent/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["agent_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/docs/embed_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["doc_id"], "updated_at": utcnow(), "jobs": []},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/execution/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["execution_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"jobs": [],
**d,
},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"jobs": [],
**d,
},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/task/patch_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"updated_at": d["updated_at_ms"][0] / 1000,
**d,
},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/tools/patch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/tools/update_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/user/patch_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["user_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/user/update_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["user_id"], "jobs": [], **d},
_kind="replaced",
_kind="inserted",
)
@cozo_query
@beartype
Expand Down
25 changes: 21 additions & 4 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Annotated
from uuid import UUID, uuid4

from fastapi import Depends
from fastapi import BackgroundTasks, Depends
from pydantic import UUID4
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient

from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...clients import temporal
from ...dependencies.developer_id import get_developer_id
from ...env import temporal_task_queue, testing
from ...models.docs.create_doc import create_doc as create_doc_query
from .router import router

Expand All @@ -18,23 +19,36 @@ async def run_embed_docs_task(
title: str,
content: list[str],
job_id: UUID,
background_tasks: BackgroundTasks,
client: TemporalClient | None = None,
):
from ...workflows.embed_docs import EmbedDocsWorkflow

client = client or (await temporal.get_client())

await client.execute_workflow(
"EmbedDocsWorkflow",
# TODO: Remove this conditional once we have a way to run workflows in
# a test environment.
if testing:
return None

handle = await client.start_workflow(
EmbedDocsWorkflow.run,
args=[str(doc_id), title, content],
task_queue="memory-task-queue",
task_queue=temporal_task_queue,
id=str(job_id),
)

background_tasks.add_task(handle.result)

return handle


@router.post("/users/{user_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"])
async def create_user_doc(
user_id: UUID4,
data: CreateDocRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc = create_doc_query(
developer_id=x_developer_id,
Expand All @@ -50,6 +64,7 @@ async def create_user_doc(
title=doc.title,
content=doc.content,
job_id=embed_job_id,
background_tasks=background_tasks,
)

return ResourceCreatedResponse(
Expand All @@ -62,6 +77,7 @@ async def create_agent_doc(
agent_id: UUID4,
data: CreateDocRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc = create_doc_query(
developer_id=x_developer_id,
Expand All @@ -77,6 +93,7 @@ async def create_agent_doc(
title=doc.title,
content=doc.content,
job_id=embed_job_id,
background_tasks=background_tasks,
)

return ResourceCreatedResponse(
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio

from ..clients import temporal
from .worker import create_worker


Expand All @@ -16,7 +17,8 @@ async def main():
then starts the worker to listen for tasks on the configured task queue.
"""

worker = await create_worker()
client = await temporal.get_client()
worker = create_worker(client)

# Start the worker to listen for and process tasks
await worker.run()
Expand Down
49 changes: 23 additions & 26 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,35 @@
from temporalio.client import Client
from temporalio.worker import Worker

from ..activities.embed_docs import embed_docs
from ..activities.mem_mgmt import mem_mgmt
from ..activities.mem_rating import mem_rating
from ..activities.summarization import summarization
from ..activities.task_steps import (
evaluate_step,
if_else_step,
prompt_step,
tool_call_step,
transition_step,
yield_step,
)
from ..activities.truncation import truncation
from ..clients.temporal import get_client
from ..env import (
temporal_task_queue,
)
from ..workflows.embed_docs import EmbedDocsWorkflow
from ..workflows.mem_mgmt import MemMgmtWorkflow
from ..workflows.mem_rating import MemRatingWorkflow
from ..workflows.summarization import SummarizationWorkflow
from ..workflows.task_execution import TaskExecutionWorkflow
from ..workflows.truncation import TruncationWorkflow


async def create_worker(client: Client | None = None):
def create_worker(client: Client):
"""
Initializes the Temporal client and worker with TLS configuration (if provided),
then create a worker to listen for tasks on the configured task queue.
"""

client = client or await get_client()
from ..activities.embed_docs import embed_docs
from ..activities.mem_mgmt import mem_mgmt
from ..activities.mem_rating import mem_rating
from ..activities.summarization import summarization
from ..activities.task_steps import (
evaluate_step,
if_else_step,
prompt_step,
tool_call_step,
transition_step,
yield_step,
)
from ..activities.truncation import truncation
from ..env import (
temporal_task_queue,
)
from ..workflows.embed_docs import EmbedDocsWorkflow
from ..workflows.mem_mgmt import MemMgmtWorkflow
from ..workflows.mem_rating import MemRatingWorkflow
from ..workflows.summarization import SummarizationWorkflow
from ..workflows.task_execution import TaskExecutionWorkflow
from ..workflows.truncation import TruncationWorkflow

task_activities = [
prompt_step,
Expand Down
31 changes: 15 additions & 16 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest.mock import patch
from uuid import uuid4

Expand Down Expand Up @@ -38,6 +39,7 @@
from agents_api.models.user.create_user import create_user
from agents_api.models.user.delete_user import delete_user
from agents_api.web import app
# from agents_api.worker.worker import create_worker
from agents_api.worker.worker import create_worker

EMBEDDING_SIZE: int = 1024
Expand All @@ -61,29 +63,26 @@ def activity_environment():


@fixture(scope="global")
async def workflow_environment():
wf_env = await WorkflowEnvironment.start_local()
yield wf_env
await wf_env.shutdown()
async def temporal_worker():
async with (await WorkflowEnvironment.start_local()) as env:
worker = create_worker(client=env.client)
worker_task = asyncio.create_task(worker.run())

yield worker

@fixture(scope="global")
async def temporal_worker(wf_env=workflow_environment):
worker = await create_worker(client=wf_env.client)

# FIXME: This does not stop the worker properly
c = worker.shutdown()
async with worker as running_worker:
yield running_worker
await c

kill_signal = worker.shutdown()
worker_task.cancel()
await asyncio.wait(
[kill_signal, worker_task],
return_when=asyncio.FIRST_COMPLETED,
)


@fixture(scope="test")
def patch_temporal_get_client(
wf_env=workflow_environment,
temporal_worker=temporal_worker,
):
mock_client = wf_env.client
mock_client = temporal_worker.client

with patch("agents_api.clients.temporal.get_client") as get_client:
get_client.return_value = mock_client
Expand Down
7 changes: 1 addition & 6 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
temporal_worker,
test_developer_id,
test_doc,
workflow_environment,
)

# from agents_api.activities.truncation import get_extra_entries
Expand All @@ -18,13 +17,9 @@

@test("activity: check that workflow environment and worker are started correctly")
async def _(
workflow_environment=workflow_environment,
worker=temporal_worker,
):
async with workflow_environment as wf_env:
assert wf_env is not None
assert worker is not None
assert worker.is_running
assert worker.is_running


@test("activity: call direct embed_docs")
Expand Down
Loading

0 comments on commit c6456bc

Please sign in to comment.