Skip to content

Commit

Permalink
feat: Truncate messages (#294)
Browse files Browse the repository at this point in the history
* fix: Truncate initial messages list before response generation if amount of tokens is too big

* feat: Test messages truncation

* Update agents-api/agents_api/routers/sessions/exceptions.py

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* refactor: Lint agents-api (CI)

---------

Co-authored-by: Diwank Singh Tomer <diwank.singh@gmail.com>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: creatorrr <creatorrr@users.noreply.github.com>
  • Loading branch information
4 people authored Apr 25, 2024
1 parent 9a7bb85 commit 677da6f
Show file tree
Hide file tree
Showing 7 changed files with 414 additions and 11 deletions.
9 changes: 9 additions & 0 deletions agents-api/agents_api/routers/sessions/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class BaseSessionException(Exception):
    """Base class for all session-related errors.

    Callers can catch this single type to handle any session failure.
    """

    pass


class InputTooBigError(BaseSessionException):
    """Raised when the input exceeds the allowed token budget.

    Args:
        actual_tokens: Number of tokens the input actually contains.
        required_tokens: Maximum number of tokens permitted.
    """

    def __init__(self, actual_tokens: int, required_tokens: int) -> None:
        # Keep the counts on the instance so handlers can log or react to
        # them programmatically instead of parsing the message string.
        self.actual_tokens = actual_tokens
        self.required_tokens = required_tokens
        super().__init__(
            f"Input is too big, {actual_tokens} tokens provided, but only {required_tokens} tokens are allowed."
        )
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/sessions/protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, validator, ConfigDict
from pydantic import BaseModel, Field, field_validator, ConfigDict
from agents_api.autogen.openapi_model import ResponseFormat, Preset, Tool


Expand All @@ -24,10 +24,10 @@ class Settings(BaseModel):
preset: Preset | None = Field(default=None)
tools: list[Tool] | None = Field(default=None)

@validator("max_tokens")
@field_validator("max_tokens")
def set_max_tokens(cls, max_tokens):
    """Default `max_tokens` to 200 when the client did not provide one."""
    if max_tokens is None:
        return 200
    return max_tokens

@validator("stream")
@field_validator("stream")
def set_stream(cls, stream):
    """Coerce a missing or falsy `stream` flag to an explicit False."""
    return False if not stream else stream
76 changes: 75 additions & 1 deletion agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from functools import reduce
from json import JSONDecodeError
from typing import Callable
from uuid import uuid4
Expand All @@ -22,6 +23,11 @@
)
from ...common.protocol.sessions import SessionData
from .protocol import Settings
from .exceptions import InputTooBigError


THOUGHTS_STRIP_LEN = 2
MESSAGES_STRIP_LEN = 4


tool_query_instruction = (
Expand All @@ -40,6 +46,72 @@ class BaseSession:
session_id: UUID4
developer_id: UUID4

def _remove_messages(
    self,
    messages: list[Entry],
    start_idx: int | None,
    end_idx: int | None,
    token_count: int,
    summarization_tokens_threshold: int,
    predicate: Callable[[Entry], bool],
) -> tuple[list[Entry], int]:
    """Drop entries matching `predicate` from messages[start_idx:end_idx]
    until `token_count` falls to `summarization_tokens_threshold`.

    Entries before `start_idx` and at/after `end_idx` are always kept, so
    the head and tail of the conversation survive intact.

    Returns:
        The (possibly shortened) message list and the updated token count.
    """
    # Guard: if the list is shorter than the scan window implied by the
    # protected head/tail, there is nothing safe to remove.
    # NOTE(review): `end_idx or len(messages)` treats end_idx == 0 the same
    # as None — current callers never pass 0, but worth confirming.
    if len(messages) < abs((end_idx or len(messages)) - (start_idx or 0)):
        return messages, token_count

    # Start the result with the protected prefix, untouched.
    result: list[Entry] = messages[: start_idx or 0]
    # Once the threshold is reached, stop removing and keep everything else.
    skip_check = False
    for m in messages[start_idx:end_idx]:
        if predicate(m) and not skip_check:
            # Remove this entry and subtract its tokens; the entry that
            # brings us under the threshold is itself still removed.
            token_count -= m.token_count
            if token_count <= summarization_tokens_threshold:
                skip_check = True

            continue

        result.append(m)

    # Re-attach the protected suffix. Only when a real end index was given:
    # messages[None:] would append the entire list again.
    if end_idx is not None:
        result += messages[end_idx:]

    return result, token_count

def truncate(
    self, messages: list[Entry], summarization_tokens_threshold: int
) -> list[Entry]:
    """Trim `messages` until their total token count fits the threshold.

    Removal passes run in order of increasing aggressiveness, stopping as
    soon as the remaining messages fit:
      1. "thought" system entries, preserving the first/last
         THOUGHTS_STRIP_LEN entries;
      2. any remaining "thought" entries, anywhere in the list;
      3. user/assistant turns, preserving the first/last
         MESSAGES_STRIP_LEN entries.

    Returns:
        The original list untouched if it already fits, otherwise a
        truncated copy.

    Raises:
        InputTooBigError: if the threshold cannot be met by any pass.
    """

    def rm_thoughts(m):
        # Internal chain-of-thought entries are the cheapest to drop.
        return m.role == "system" and m.name == "thought"

    def rm_user_assistant(m):
        return m.role in ("user", "assistant")

    # Idiomatic total; clearer than the previous functools.reduce.
    token_count = sum(e.token_count for e in messages)

    if token_count <= summarization_tokens_threshold:
        return messages

    for start_idx, end_idx, cond in [
        (THOUGHTS_STRIP_LEN, -THOUGHTS_STRIP_LEN, rm_thoughts),
        (None, None, rm_thoughts),
        (MESSAGES_STRIP_LEN, -MESSAGES_STRIP_LEN, rm_user_assistant),
    ]:
        messages, token_count = self._remove_messages(
            messages,
            start_idx,
            end_idx,
            token_count,
            summarization_tokens_threshold,
            cond,
        )

        if token_count <= summarization_tokens_threshold and messages:
            return messages

    # TODO:
    # Compress info sections using LLM Lingua
    # - If more space is still needed, remove info sections iteratively

    raise InputTooBigError(token_count, summarization_tokens_threshold)

async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable]:
Expand All @@ -53,7 +125,9 @@ async def run(
session_data, new_input, settings
)
# Generate response
response = await self.generate(init_context, final_settings)
response = await self.generate(
self.truncate(init_context, summarization_tokens_threshold), final_settings
)
# Save response to session
# if final_settings.get("remember"):
# await self.add_to_session(new_input, response)
Expand Down
10 changes: 4 additions & 6 deletions agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ openai = "^1.12.0"
httpx = "^0.26.0"
async-lru = "^2.0.4"
sentry-sdk = {extras = ["fastapi"], version = "^1.38.0"}
ward = "^0.68.0b0"
temporalio = "^1.4.0"
pydantic = "^2.5.3"
arrow = "^1.3.0"
Expand All @@ -38,6 +37,7 @@ poethepoet = "^0.25.1"
pytype = ">=2024.4.11"
julep = "^0.2.4"
pyjwt = "^2.8.0"
ward = "^0.68.0b0"

[build-system]
requires = ["poetry-core"]
Expand Down
8 changes: 8 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from uuid import uuid4
from ward import fixture
from julep import AsyncClient, Client
from agents_api.routers.sessions.session import BaseSession


# TODO: make clients connect to real service


@fixture(scope="global")
def base_session():
    """Provide a BaseSession wired to freshly generated random ids."""
    session_id, developer_id = uuid4(), uuid4()
    return BaseSession(session_id, developer_id)


@fixture(scope="global")
def client():
# Mock server base url
Expand Down
Loading

0 comments on commit 677da6f

Please sign in to comment.