Skip to content

Commit

Permalink
feat: Test messages truncation
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Apr 24, 2024
1 parent 3affbdc commit c0a9076
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 19 deletions.
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/sessions/protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, validator, ConfigDict
from pydantic import BaseModel, Field, field_validator, ConfigDict
from agents_api.autogen.openapi_model import ResponseFormat, Preset, Tool


Expand All @@ -24,10 +24,10 @@ class Settings(BaseModel):
preset: Preset | None = Field(default=None)
tools: list[Tool] | None = Field(default=None)

@validator("max_tokens")
@field_validator("max_tokens")
def set_max_tokens(cls, max_tokens):
return max_tokens if max_tokens is not None else 200

@validator("stream")
@field_validator("stream")
def set_stream(cls, stream):
return stream or False
19 changes: 10 additions & 9 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,27 @@ def _remove_messages(
messages: list[Entry],
start_idx: int | None,
end_idx: int | None,
token_count: float,
token_count: int,
summarization_tokens_threshold: int,
predicate: Callable[[Entry], bool],
) -> tuple[list[Entry], float]:
) -> tuple[list[Entry], int]:
if len(messages) < abs((end_idx or len(messages)) - (start_idx or 0)):
return messages, token_count

result: list[Entry] = messages[:start_idx]
result: list[Entry] = messages[:start_idx or 0]
skip_check = False
for m in messages[start_idx:end_idx]:
if predicate(m) and not skip_check:
token_count -= len(m.content) / 3.5
if token_count < summarization_tokens_threshold:
token_count -= m.token_count
if token_count <= summarization_tokens_threshold:
skip_check = True

continue

result.append(m)

result += messages[end_idx:]
if end_idx is not None:
result += messages[end_idx:]

return result, token_count

Expand All @@ -83,9 +84,9 @@ def rm_thoughts(m):
def rm_user_assistant(m):
return m.role in ("user", "assistant")

token_count = reduce(lambda c, e: len(e.content) + c, messages, 0) / 3.5
token_count = reduce(lambda c, e: e.token_count + c, messages, 0)

if token_count < summarization_tokens_threshold and messages:
if token_count <= summarization_tokens_threshold:
return messages

for start_idx, end_idx, cond in [
Expand All @@ -102,7 +103,7 @@ def rm_user_assistant(m):
cond,
)

if token_count < summarization_tokens_threshold and messages:
if token_count <= summarization_tokens_threshold and messages:
return messages

# TODO:
Expand Down
10 changes: 4 additions & 6 deletions agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ openai = "^1.12.0"
httpx = "^0.26.0"
async-lru = "^2.0.4"
sentry-sdk = {extras = ["fastapi"], version = "^1.38.0"}
ward = "^0.68.0b0"
temporalio = "^1.4.0"
pydantic = "^2.5.3"
arrow = "^1.3.0"
Expand All @@ -38,6 +37,7 @@ poethepoet = "^0.25.1"
pytype = ">=2024.4.11"
julep = "^0.2.4"
pyjwt = "^2.8.0"
ward = "^0.68.0b0"

[build-system]
requires = ["poetry-core"]
Expand Down
8 changes: 8 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from uuid import uuid4
from ward import fixture
from julep import AsyncClient, Client
from agents_api.routers.sessions.session import BaseSession


# TODO: make clients connect to real service


@fixture(scope="global")
def session():
    """Provide one shared BaseSession built from two freshly generated UUIDs."""
    first_id, second_id = uuid4(), uuid4()
    return BaseSession(first_id, second_id)


@fixture(scope="global")
def client():
# Mock server base url
Expand Down
222 changes: 222 additions & 0 deletions agents-api/tests/test_messages_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from uuid import uuid4
from ward import test, raises
from agents_api.common.protocol.entries import Entry
from agents_api.autogen.openapi_model import Role
from agents_api.routers.sessions.exceptions import InputTooBigError
from tests.fixtures import session


@test("truncate empty messages list", tags=["messages_truncate"])
def _(session=session):
    """Truncating an empty message list should yield an equal (empty) list."""
    no_messages: list[Entry] = []
    truncated = session.truncate(no_messages, 10)

    assert no_messages == truncated


@test("do not truncate", tags=["messages_truncate"])
def _(session=session):
    """Messages whose estimated size equals the threshold come back unchanged."""
    contents = [
        "content1",
        "content2",
        "content3",
    ]
    # Same per-message estimate (len // 3.5) the other truncation tests use.
    threshold = sum(len(c) // 3.5 for c in contents)

    # `contents` holds plain strings here (not tuples), so each item is passed
    # whole; indexing it with [0] would send only its first character and make
    # the test pass without ever reaching the threshold.
    messages: list[Entry] = [
        Entry(session_id=uuid4(), role=Role.user, content=contents[0]),
        Entry(session_id=uuid4(), role=Role.assistant, content=contents[1]),
        Entry(session_id=uuid4(), role=Role.user, content=contents[2]),
    ]
    result = session.truncate(messages, threshold)

    assert messages == result


@test("truncate thoughts partially", tags=["messages_truncate"])
def _(session=session):
    """Check that only the thought left out of the threshold is dropped.

    Each ``contents`` item is ``(text, counted)``; ``counted`` marks the
    entries whose estimated size is included in the threshold.
    """
    contents = [
        ("content1", True),
        ("content2", True),
        ("content3", False),
        ("content4", True),
        ("content5", True),
        ("content6", True),
    ]
    # List multiplication evaluates uuid4() once, so every entry shares the
    # same session id.
    session_ids = [uuid4()] * len(contents)
    # Estimated size (len // 3.5 per entry) of the entries flagged True.
    threshold = sum(len(c) // 3.5 for c, i in contents if i)

    messages: list[Entry] = [
        Entry(session_id=session_ids[0], role=Role.system, name="thought", content=contents[0][0]),
        Entry(session_id=session_ids[1], role=Role.assistant, content=contents[1][0]),
        Entry(session_id=session_ids[2], role=Role.system, name="thought", content=contents[2][0]),
        Entry(session_id=session_ids[3], role=Role.system, name="thought", content=contents[3][0]),
        Entry(session_id=session_ids[4], role=Role.user, content=contents[4][0]),
        Entry(session_id=session_ids[5], role=Role.assistant, content=contents[5][0]),
    ]
    result = session.truncate(messages, threshold)

    # messages[2] is the only entry flagged False above.
    assert result == [
        messages[0],
        messages[1],
        messages[3],
        messages[4],
        messages[5],
    ]


@test("truncate thoughts partially 2", tags=["messages_truncate"])
def _(session=session):
    """Expect the two thoughts flagged False to be removed from the history."""
    # (text, counted): `counted` entries contribute to the threshold below.
    samples = [
        ("content1", True),
        ("content2", True),
        ("content3", False),
        ("content4", False),
        ("content5", True),
        ("content6", True),
    ]
    shared_session_id = uuid4()
    budget = sum([len(text) // 3.5 for text, counted in samples if counted])

    def thought(position: int) -> Entry:
        # System entry named "thought" carrying the text at `position`.
        return Entry(
            session_id=shared_session_id,
            role=Role.system,
            name="thought",
            content=samples[position][0],
        )

    def spoken(position: int, role: Role) -> Entry:
        # Plain user/assistant entry carrying the text at `position`.
        return Entry(
            session_id=shared_session_id,
            role=role,
            content=samples[position][0],
        )

    messages: list[Entry] = [
        thought(0),
        spoken(1, Role.assistant),
        thought(2),
        thought(3),
        spoken(4, Role.user),
        spoken(5, Role.assistant),
    ]
    kept = session.truncate(messages, budget)

    assert kept == [messages[i] for i in (0, 1, 4, 5)]


@test("truncate all thoughts", tags=["messages_truncate"])
def _(session=session):
    """Expect every thought entry to be removed, leaving only the dialogue."""
    # (text, counted): `counted` entries contribute to the threshold below.
    samples = [
        ("content1", False),
        ("content2", True),
        ("content3", False),
        ("content4", False),
        ("content5", True),
        ("content6", True),
        ("content7", False),
    ]
    shared_session_id = uuid4()
    budget = sum([len(text) // 3.5 for text, counted in samples if counted])

    thought_kwargs = {"role": Role.system, "name": "thought"}
    # Per-entry keyword arguments, in the same order as `samples`.
    entry_kwargs = [
        thought_kwargs,
        {"role": Role.assistant},
        thought_kwargs,
        thought_kwargs,
        {"role": Role.user},
        {"role": Role.assistant},
        thought_kwargs,
    ]
    messages: list[Entry] = [
        Entry(session_id=shared_session_id, content=text, **kwargs)
        for (text, _), kwargs in zip(samples, entry_kwargs)
    ]
    kept = session.truncate(messages, budget)

    expected = [messages[1], messages[4], messages[5]]
    assert kept == expected


@test("truncate user assistant pairs", tags=["messages_truncate"])
def _(session=session):
    """Expect all thoughts plus one assistant reply to be removed."""
    # (text, counted): `counted` entries contribute to the threshold below.
    samples = [
        ("content1", False),
        ("content2", True),
        ("content3", False),
        ("content4", False),
        ("content5", True),
        ("content6", True),
        ("content7", True),
        ("content8", False),
        ("content9", True),
        ("content10", True),
        ("content11", True),
        ("content12", True),
        ("content13", False),
    ]
    shared_session_id = uuid4()
    budget = sum([len(text) // 3.5 for text, counted in samples if counted])

    # Role of each entry in order; None stands for a system "thought" entry.
    layout = [
        None,
        Role.assistant,
        None,
        None,
        Role.user,
        Role.assistant,
        Role.user,
        Role.assistant,
        Role.user,
        Role.assistant,
        Role.user,
        Role.assistant,
        None,
    ]
    messages: list[Entry] = []
    for (text, _), role in zip(samples, layout):
        if role is None:
            entry = Entry(
                session_id=shared_session_id,
                role=Role.system,
                name="thought",
                content=text,
            )
        else:
            entry = Entry(session_id=shared_session_id, role=role, content=text)
        messages.append(entry)

    kept = session.truncate(messages, budget)

    surviving_positions = (1, 4, 5, 6, 8, 9, 10, 11)
    assert kept == [messages[i] for i in surviving_positions]


@test("unable to truncate", tags=["messages_truncate"])
def _(session=session):
    """Expect InputTooBigError, with its exact message, from truncate."""
    # (text, counted): `counted` entries contribute to the threshold below.
    samples = [
        ("content1", False),
        ("content2", True),
        ("content3", False),
        ("content4", False),
        ("content5", False),
        ("content6", False),
        ("content7", True),
        ("content8", False),
        ("content9", True),
        ("content10", False),
    ]
    shared_session_id = uuid4()
    sizes = [(len(text) // 3.5, counted) for text, counted in samples]
    threshold = sum([size for size, counted in sizes if counted])
    all_tokens = sum([size for size, _ in sizes])

    def thought(position: int) -> Entry:
        # System entry named "thought" carrying the text at `position`.
        return Entry(
            session_id=shared_session_id,
            role=Role.system,
            name="thought",
            content=samples[position][0],
        )

    def spoken(position: int, role: Role) -> Entry:
        # Plain user/assistant entry carrying the text at `position`.
        return Entry(
            session_id=shared_session_id,
            role=role,
            content=samples[position][0],
        )

    messages: list[Entry] = [
        thought(0),
        spoken(1, Role.assistant),
        thought(2),
        thought(3),
        spoken(4, Role.user),
        spoken(5, Role.assistant),
        spoken(6, Role.user),
        spoken(7, Role.assistant),
        spoken(8, Role.user),
        thought(9),
    ]
    with raises(InputTooBigError) as ex:
        session.truncate(messages, threshold)

    # Checked after the block exits, once the raised exception is recorded.
    expected_message = f"input is too big, {threshold} tokens required, but you got {all_tokens} tokens"
    assert str(ex.raised) == expected_message

0 comments on commit c0a9076

Please sign in to comment.