Skip to content

Commit

Permalink
feat: Test messages truncation
Browse files (browse the repository at this point in the history)
  • Loading branch information
whiterabbit1983 committed Apr 24, 2024
1 parent 3affbdc commit 59471d4
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 27 deletions.
6 changes: 2 additions & 4 deletions agents-api/agents_api/activities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ class ChatML(BaseModel):
token_count: Optional[int] = None


class BaseTask(BaseModel):
...
class BaseTask(BaseModel): ...


class BaseTaskArgs(BaseModel):
...
class BaseTaskArgs(BaseModel): ...


class AddPrinciplesTaskArgs(BaseTaskArgs):
Expand Down
6 changes: 2 additions & 4 deletions agents-api/agents_api/clients/worker/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ class ChatML(BaseModel):
token_count: Optional[int] = None


class BaseTask(BaseModel):
...
class BaseTask(BaseModel): ...


class BaseTaskArgs(BaseModel):
...
class BaseTaskArgs(BaseModel): ...


class MemoryManagementTaskArgs(BaseTaskArgs):
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/sessions/protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, validator, ConfigDict
from pydantic import BaseModel, Field, field_validator, ConfigDict
from agents_api.autogen.openapi_model import ResponseFormat, Preset, Tool


Expand All @@ -24,10 +24,10 @@ class Settings(BaseModel):
preset: Preset | None = Field(default=None)
tools: list[Tool] | None = Field(default=None)

@validator("max_tokens")
@field_validator("max_tokens")
def set_max_tokens(cls, max_tokens):
    """Return ``max_tokens`` unchanged, or 200 when it is ``None``.

    Only ``None`` triggers the fallback; other falsy values such as ``0``
    are passed through as given.
    """
    if max_tokens is None:
        return 200
    return max_tokens

@validator("stream")
@field_validator("stream")
def set_stream(cls, stream):
    """Return ``stream`` when it is truthy, otherwise ``False``.

    Any falsy input (``None``, ``False``, ``0``, ...) becomes ``False``.
    """
    return stream if stream else False
19 changes: 10 additions & 9 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,27 @@ def _remove_messages(
messages: list[Entry],
start_idx: int | None,
end_idx: int | None,
token_count: float,
token_count: int,
summarization_tokens_threshold: int,
predicate: Callable[[Entry], bool],
) -> tuple[list[Entry], float]:
) -> tuple[list[Entry], int]:
if len(messages) < abs((end_idx or len(messages)) - (start_idx or 0)):
return messages, token_count

result: list[Entry] = messages[:start_idx]
result: list[Entry] = messages[: start_idx or 0]
skip_check = False
for m in messages[start_idx:end_idx]:
if predicate(m) and not skip_check:
token_count -= len(m.content) / 3.5
if token_count < summarization_tokens_threshold:
token_count -= m.token_count
if token_count <= summarization_tokens_threshold:
skip_check = True

continue

result.append(m)

result += messages[end_idx:]
if end_idx is not None:
result += messages[end_idx:]

return result, token_count

Expand All @@ -83,9 +84,9 @@ def rm_thoughts(m):
def rm_user_assistant(m):
return m.role in ("user", "assistant")

token_count = reduce(lambda c, e: len(e.content) + c, messages, 0) / 3.5
token_count = reduce(lambda c, e: e.token_count + c, messages, 0)

if token_count < summarization_tokens_threshold and messages:
if token_count <= summarization_tokens_threshold:
return messages

for start_idx, end_idx, cond in [
Expand All @@ -102,7 +103,7 @@ def rm_user_assistant(m):
cond,
)

if token_count < summarization_tokens_threshold and messages:
if token_count <= summarization_tokens_threshold and messages:
return messages

# TODO:
Expand Down
10 changes: 4 additions & 6 deletions agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ openai = "^1.12.0"
httpx = "^0.26.0"
async-lru = "^2.0.4"
sentry-sdk = {extras = ["fastapi"], version = "^1.38.0"}
ward = "^0.68.0b0"
temporalio = "^1.4.0"
pydantic = "^2.5.3"
arrow = "^1.3.0"
Expand All @@ -38,6 +37,7 @@ poethepoet = "^0.25.1"
pytype = ">=2024.4.11"
julep = "^0.2.4"
pyjwt = "^2.8.0"
ward = "^0.68.0b0"

[build-system]
requires = ["poetry-core"]
Expand Down
8 changes: 8 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from uuid import uuid4
from ward import fixture
from julep import AsyncClient, Client
from agents_api.routers.sessions.session import BaseSession


# TODO: make clients connect to real service


@fixture(scope="global")
def base_session():
    """Build a ``BaseSession`` from two freshly generated random UUIDs.

    The fixture is registered with ``scope="global"``.
    """
    # Generate the ids in the same order they are passed positionally.
    first_id = uuid4()
    second_id = uuid4()
    return BaseSession(first_id, second_id)


@fixture(scope="global")
def client():
# Mock server base url
Expand Down
Loading

0 comments on commit 59471d4

Please sign in to comment.