diff --git a/agents-api/agents_api/routers/sessions/exceptions.py b/agents-api/agents_api/routers/sessions/exceptions.py new file mode 100644 index 000000000..add4b79cb --- /dev/null +++ b/agents-api/agents_api/routers/sessions/exceptions.py @@ -0,0 +1,9 @@ +class BaseSessionException(Exception): + pass + + +class InputTooBigError(BaseSessionException): + def __init__(self, actual_tokens, required_tokens): + super().__init__( + f"Input is too big, {actual_tokens} tokens provided, but only {required_tokens} tokens are allowed." + ) diff --git a/agents-api/agents_api/routers/sessions/protocol.py b/agents-api/agents_api/routers/sessions/protocol.py index bb4ae9b49..d630762db 100644 --- a/agents-api/agents_api/routers/sessions/protocol.py +++ b/agents-api/agents_api/routers/sessions/protocol.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, validator, ConfigDict +from pydantic import BaseModel, Field, field_validator, ConfigDict from agents_api.autogen.openapi_model import ResponseFormat, Preset, Tool @@ -24,10 +24,10 @@ class Settings(BaseModel): preset: Preset | None = Field(default=None) tools: list[Tool] | None = Field(default=None) - @validator("max_tokens") + @field_validator("max_tokens") def set_max_tokens(cls, max_tokens): return max_tokens if max_tokens is not None else 200 - @validator("stream") + @field_validator("stream") def set_stream(cls, stream): return stream or False diff --git a/agents-api/agents_api/routers/sessions/session.py b/agents-api/agents_api/routers/sessions/session.py index 76cf70f54..421bcd6bf 100644 --- a/agents-api/agents_api/routers/sessions/session.py +++ b/agents-api/agents_api/routers/sessions/session.py @@ -1,4 +1,5 @@ import json +from functools import reduce from json import JSONDecodeError from typing import Callable from uuid import uuid4 @@ -22,6 +23,11 @@ ) from ...common.protocol.sessions import SessionData from .protocol import Settings +from .exceptions import InputTooBigError + + +THOUGHTS_STRIP_LEN = 2 +MESSAGES_STRIP_LEN = 4 tool_query_instruction = ( @@ -40,6 +46,72 @@ class BaseSession: session_id: UUID4 developer_id: UUID4 + def _remove_messages( + self, + messages: list[Entry], + start_idx: int | None, + end_idx: int | None, + token_count: int, + summarization_tokens_threshold: int, + predicate: Callable[[Entry], bool], + ) -> tuple[list[Entry], int]: + if len(messages) < abs((end_idx or len(messages)) - (start_idx or 0)): + return messages, token_count + + result: list[Entry] = messages[: start_idx or 0] + skip_check = False + for m in messages[start_idx:end_idx]: + if predicate(m) and not skip_check: + token_count -= m.token_count + if token_count <= summarization_tokens_threshold: + skip_check = True + + continue + + result.append(m) + + if end_idx is not None: + result += messages[end_idx:] + + return result, token_count + + def truncate( + self, messages: list[Entry], summarization_tokens_threshold: int + ) -> list[Entry]: + def rm_thoughts(m): + return m.role == "system" and m.name == "thought" + + def rm_user_assistant(m): + return m.role in ("user", "assistant") + + token_count = reduce(lambda c, e: e.token_count + c, messages, 0) + + if token_count <= summarization_tokens_threshold: + return messages + + for start_idx, end_idx, cond in [ + (THOUGHTS_STRIP_LEN, -THOUGHTS_STRIP_LEN, rm_thoughts), + (None, None, rm_thoughts), + (MESSAGES_STRIP_LEN, -MESSAGES_STRIP_LEN, rm_user_assistant), + ]: + messages, token_count = self._remove_messages( + messages, + start_idx, + end_idx, + token_count, + summarization_tokens_threshold, + cond, + ) + + if token_count <= summarization_tokens_threshold and messages: + return messages + + # TODO: + # Compress info sections using LLM Lingua + # - If more space is still needed, remove info sections iteratively + + raise InputTooBigError(token_count, summarization_tokens_threshold) + async def run( self, new_input, settings: Settings ) -> tuple[ChatCompletion, Entry, Callable]: @@ -53,7 +125,9 @@ async def run( session_data, new_input, settings ) # Generate response - response = await self.generate(init_context, final_settings) + response = await self.generate( + self.truncate(init_context, summarization_tokens_threshold), final_settings + ) # Save response to session # if final_settings.get("remember"): # await self.add_to_session(new_input, response) diff --git a/agents-api/poetry.lock b/agents-api/poetry.lock index 069ae72ad..87aa0d62b 100644 --- a/agents-api/poetry.lock +++ b/agents-api/poetry.lock @@ -369,7 +369,6 @@ python-versions = ">=3.7" files = [ {file = "cozo_embedded-0.7.6-cp37-abi3-macosx_10_14_x86_64.whl", hash = "sha256:d146e76736beb5e14e0cf73dc8babefadfbbc358b325c94c64a51b6d5b0031e9"}, {file = "cozo_embedded-0.7.6-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:7341fa266369181bbc19ad9e68820b51900b0fe1c947318a3d860b570dca6e09"}, - {file = "cozo_embedded-0.7.6-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80de79554138628967d4fd2636fc0a0a8dcca1c0c3bb527e638f1ee6cb763d7d"}, {file = "cozo_embedded-0.7.6-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7571f6521041c13b7e9ca8ab8809cf9c8eaad929726ed6190ffc25a5a3ab57a7"}, {file = "cozo_embedded-0.7.6-cp37-abi3-win_amd64.whl", hash = "sha256:c945ab7b350d0b79d3e643b68ebc8343fc02d223a02ab929eb0fb8e4e0df3542"}, ] @@ -1772,13 +1771,13 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest- [[package]] name = "pluggy" -version = "1.4.0" +version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" files = [ - {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, - {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] [package.extras] @@ -2222,7 +2221,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2875,4 +2873,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "93d7b08d0efde1bc338b304d83e749abafc88fdbaf78a2b5ae2570492fa3485b" +content-hash = "f183ab0cfd97f2a3ad799e6df54e9f85bc50604314b884bb2ef66f3399c8556c" diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index 30e7b9120..a2e06093d 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -19,7 +19,6 @@ openai = "^1.12.0" httpx = "^0.26.0" async-lru = "^2.0.4" sentry-sdk = {extras = ["fastapi"], version = "^1.38.0"} -ward = "^0.68.0b0" temporalio = "^1.4.0" pydantic = "^2.5.3" arrow = "^1.3.0" @@ -38,6 +37,7 @@ poethepoet = "^0.25.1" pytype = ">=2024.4.11" julep = "^0.2.4" pyjwt = "^2.8.0" +ward = "^0.68.0b0" [build-system] requires = ["poetry-core"] diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 0c847b3a5..d409e4e3e 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,9 +1,17 @@ +from uuid import uuid4 from ward import fixture from julep import AsyncClient, Client +from agents_api.routers.sessions.session import BaseSession + # TODO: make clients connect to real service +@fixture(scope="global") +def base_session(): + return BaseSession(uuid4(), uuid4()) + + @fixture(scope="global") def client(): # Mock server base url diff --git a/agents-api/tests/test_messages_truncation.py b/agents-api/tests/test_messages_truncation.py new file mode 100644 index 000000000..14a905d90 --- /dev/null +++ b/agents-api/tests/test_messages_truncation.py @@ -0,0 +1,314 @@ +from uuid import uuid4 +from ward import test, raises +from agents_api.common.protocol.entries import Entry +from agents_api.autogen.openapi_model import Role +from agents_api.routers.sessions.exceptions import InputTooBigError +from tests.fixtures import base_session + + +@test("truncate empty messages list", tags=["messages_truncate"]) +def _(session=base_session): + messages: list[Entry] = [] + result = session.truncate(messages, 10) + + assert messages == result + + +@test("do not truncate", tags=["messages_truncate"]) +def _(session=base_session): + contents = [ + "content1", + "content2", + "content3", + ] + threshold = sum([len(c) // 3.5 for c in contents]) + + messages: list[Entry] = [ + Entry(session_id=uuid4(), role=Role.user, content=contents[0][0]), + Entry(session_id=uuid4(), role=Role.assistant, content=contents[1][0]), + Entry(session_id=uuid4(), role=Role.user, content=contents[2][0]), + ] + result = session.truncate(messages, threshold) + + assert messages == result + + +@test("truncate thoughts partially", tags=["messages_truncate"]) +def _(session=base_session): + contents = [ + ("content1", True), + ("content2", True), + ("content3", False), + ("content4", True), + ("content5", True), + ("content6", True), + ] + session_ids = [uuid4()] * len(contents) + threshold = sum([len(c) // 3.5 for c, i in contents if i]) + + messages: list[Entry] = [ + Entry( + session_id=session_ids[0], + role=Role.system, + name="thought", + content=contents[0][0], + ), + Entry(session_id=session_ids[1], role=Role.assistant, content=contents[1][0]), + Entry( + session_id=session_ids[2], + role=Role.system, + name="thought", + content=contents[2][0], + ), + Entry( + session_id=session_ids[3], + role=Role.system, + name="thought", + content=contents[3][0], + ), + Entry(session_id=session_ids[4], role=Role.user, content=contents[4][0]), + Entry(session_id=session_ids[5], role=Role.assistant, content=contents[5][0]), + ] + result = session.truncate(messages, threshold) + [ + messages[0], + messages[1], + messages[3], + messages[4], + messages[5], + ] + + assert result == [ + messages[0], + messages[1], + messages[3], + messages[4], + messages[5], + ] + + +@test("truncate thoughts partially 2", tags=["messages_truncate"]) +def _(session=base_session): + contents = [ + ("content1", True), + ("content2", True), + ("content3", False), + ("content4", False), + ("content5", True), + ("content6", True), + ] + session_ids = [uuid4()] * len(contents) + threshold = sum([len(c) // 3.5 for c, i in contents if i]) + + messages: list[Entry] = [ + Entry( + session_id=session_ids[0], + role=Role.system, + name="thought", + content=contents[0][0], + ), + Entry(session_id=session_ids[1], role=Role.assistant, content=contents[1][0]), + Entry( + session_id=session_ids[2], + role=Role.system, + name="thought", + content=contents[2][0], + ), + Entry( + session_id=session_ids[3], + role=Role.system, + name="thought", + content=contents[3][0], + ), + Entry(session_id=session_ids[4], role=Role.user, content=contents[4][0]), + Entry(session_id=session_ids[5], role=Role.assistant, content=contents[5][0]), + ] + result = session.truncate(messages, threshold) + + assert result == [ + messages[0], + messages[1], + messages[4], + messages[5], + ] + + +@test("truncate all thoughts", tags=["messages_truncate"]) +def _(session=base_session): + contents = [ + ("content1", False), + ("content2", True), + ("content3", False), + ("content4", False), + ("content5", True), + ("content6", True), + ("content7", False), + ] + session_ids = [uuid4()] * len(contents) + threshold = sum([len(c) // 3.5 for c, i in contents if i]) + + messages: list[Entry] = [ + Entry( + session_id=session_ids[0], + role=Role.system, + name="thought", + content=contents[0][0], + ), + Entry(session_id=session_ids[1], role=Role.assistant, content=contents[1][0]), + Entry( + session_id=session_ids[2], + role=Role.system, + name="thought", + content=contents[2][0], + ), + Entry( + session_id=session_ids[3], + role=Role.system, + name="thought", + content=contents[3][0], + ), + Entry(session_id=session_ids[4], role=Role.user, content=contents[4][0]), + Entry(session_id=session_ids[5], role=Role.assistant, content=contents[5][0]), + Entry( + session_id=session_ids[6], + role=Role.system, + name="thought", + content=contents[6][0], + ), + ] + result = session.truncate(messages, threshold) + + assert result == [ + messages[1], + messages[4], + messages[5], + ] + + +@test("truncate user assistant pairs", tags=["messages_truncate"]) +def _(session=base_session): + contents = [ + ("content1", False), + ("content2", True), + ("content3", False), + ("content4", False), + ("content5", True), + ("content6", True), + ("content7", True), + ("content8", False), + ("content9", True), + ("content10", True), + ("content11", True), + ("content12", True), + ("content13", False), + ] + session_ids = [uuid4()] * len(contents) + threshold = sum([len(c) // 3.5 for c, i in contents if i]) + + messages: list[Entry] = [ + Entry( + session_id=session_ids[0], + role=Role.system, + name="thought", + content=contents[0][0], + ), + Entry(session_id=session_ids[1], role=Role.assistant, content=contents[1][0]), + Entry( + session_id=session_ids[2], + role=Role.system, + name="thought", + content=contents[2][0], + ), + Entry( + session_id=session_ids[3], + role=Role.system, + name="thought", + content=contents[3][0], + ), + Entry(session_id=session_ids[4], role=Role.user, content=contents[4][0]), + Entry(session_id=session_ids[5], role=Role.assistant, content=contents[5][0]), + Entry(session_id=session_ids[6], role=Role.user, content=contents[6][0]), + Entry(session_id=session_ids[7], role=Role.assistant, content=contents[7][0]), + Entry(session_id=session_ids[8], role=Role.user, content=contents[8][0]), + Entry(session_id=session_ids[9], role=Role.assistant, content=contents[9][0]), + Entry(session_id=session_ids[10], role=Role.user, content=contents[10][0]), + Entry(session_id=session_ids[11], role=Role.assistant, content=contents[11][0]), + Entry( + session_id=session_ids[12], + role=Role.system, + name="thought", + content=contents[12][0], + ), + ] + + result = session.truncate(messages, threshold) + + assert result == [ + messages[1], + messages[4], + messages[5], + messages[6], + messages[8], + messages[9], + messages[10], + messages[11], + ] + + +@test("unable to truncate", tags=["messages_truncate"]) +def _(session=base_session): + contents = [ + ("content1", False), + ("content2", True), + ("content3", False), + ("content4", False), + ("content5", False), + ("content6", False), + ("content7", True), + ("content8", False), + ("content9", True), + ("content10", False), + ] + session_ids = [uuid4()] * len(contents) + threshold = sum([len(c) // 3.5 for c, i in contents if i]) + all_tokens = sum([len(c) // 3.5 for c, _ in contents]) + + messages: list[Entry] = [ + Entry( + session_id=session_ids[0], + role=Role.system, + name="thought", + content=contents[0][0], + ), + Entry(session_id=session_ids[1], role=Role.assistant, content=contents[1][0]), + Entry( + session_id=session_ids[2], + role=Role.system, + name="thought", + content=contents[2][0], + ), + Entry( + session_id=session_ids[3], + role=Role.system, + name="thought", + content=contents[3][0], + ), + Entry(session_id=session_ids[4], role=Role.user, content=contents[4][0]), + Entry(session_id=session_ids[5], role=Role.assistant, content=contents[5][0]), + Entry(session_id=session_ids[6], role=Role.user, content=contents[6][0]), + Entry(session_id=session_ids[7], role=Role.assistant, content=contents[7][0]), + Entry(session_id=session_ids[8], role=Role.user, content=contents[8][0]), + Entry( + session_id=session_ids[9], + role=Role.system, + name="thought", + content=contents[9][0], + ), + ] + with raises(InputTooBigError) as ex: + session.truncate(messages, threshold) + + assert ( + str(ex.raised) + == f"input is too big, {threshold} tokens required, but you got {all_tokens} tokens" + )