Fix llm client retry #102

Merged (2 commits) on Sep 10, 2024
Changes from 1 commit
3 changes: 2 additions & 1 deletion graphiti_core/llm_client/__init__.py
@@ -1,5 +1,6 @@
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 from .openai_client import OpenAIClient

-__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
+__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
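
With RateLimitError re-exported from the package root, application code can catch a single error type no matter which provider client is configured. A minimal sketch, where generate_summary stands in for any coroutine that drives the LLM client and is not part of this PR:

from collections.abc import Awaitable, Callable

from graphiti_core.llm_client import RateLimitError


async def safe_generate(generate_summary: Callable[[], Awaitable[dict]]) -> dict | None:
    try:
        return await generate_summary()
    except RateLimitError as e:
        # The client has already retried with backoff; surface a clear signal instead of crashing.
        print(f'Giving up after client-side retries: {e.message}')
        return None
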
10 changes: 9 additions & 1 deletion graphiti_core/llm_client/anthropic_client.py
@@ -18,12 +18,14 @@
 import logging
 import typing

+import anthropic
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI

 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError

 logger = logging.getLogger(__name__)

@@ -35,7 +37,11 @@ def __init__(self, config: LLMConfig | None = None, cache: bool = False):
         if config is None:
             config = LLMConfig()
         super().__init__(config, cache)
-        self.client = AsyncAnthropic(api_key=config.api_key)
+        self.client = AsyncAnthropic(
+            api_key=config.api_key,
+            # we'll use tenacity to retry
+            max_retries=1,
+        )

     def get_embedder(self) -> typing.Any:
         openai_client = AsyncOpenAI()
@@ -58,6 +64,8 @@ async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
             )

             return json.loads('{' + result.content[0].text)  # type: ignore
+        except anthropic.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
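
The Anthropic client turns the SDK's own retry budget down (max_retries=1) so it does not stack with tenacity, and, like the OpenAI and Groq clients below, translates the SDK's rate-limit exception into the shared RateLimitError that the retry predicate in client.py recognizes. A rough standalone sketch of that translation step, with FakeProviderRateLimitError and call_provider invented to stand in for a real SDK:

import logging

from graphiti_core.llm_client.errors import RateLimitError

logger = logging.getLogger(__name__)


class FakeProviderRateLimitError(Exception):
    """Stand-in for anthropic.RateLimitError, openai.RateLimitError, or groq.RateLimitError."""


async def call_provider() -> str:
    # Hypothetical SDK call; always raises to exercise the translation below.
    raise FakeProviderRateLimitError('429 Too Many Requests')


async def generate() -> str:
    try:
        return await call_provider()
    except FakeProviderRateLimitError as e:
        # Normalize provider-specific 429s so callers and tenacity see a single error type.
        raise RateLimitError from e
    except Exception as e:
        logger.error(f'Error in generating LLM response: {e}')
        raise
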
27 changes: 17 additions & 10 deletions graphiti_core/llm_client/client.py
@@ -22,18 +22,22 @@

 import httpx
 from diskcache import Cache
-from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

 from ..prompts.models import Message
 from .config import LLMConfig
+from .errors import RateLimitError

 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'

 logger = logging.getLogger(__name__)


-def is_server_error(exception):
+def is_server_or_retry_error(exception):
+    if isinstance(exception, RateLimitError):
+        return True
+
     return (
         isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
     )
@@ -56,18 +60,21 @@
         pass

     @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception(is_server_error),
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=10, min=5, max=120),
+        retry=retry_if_exception(is_server_or_retry_error),
+        after=lambda retry_state: logger.warning(
+            f'Retrying {retry_state.fn.__name__} after {retry_state.attempt_number} attempts...'
+        )
+        if retry_state.attempt_number > 1
+        else None,
         reraise=True,
     )
     async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
         try:
             return await self._generate_response(messages)
-        except httpx.HTTPStatusError as e:
-            if not is_server_error(e):
-                raise Exception(f'LLM request error: {e}') from e
-            else:
-                raise
+        except (httpx.HTTPStatusError, RateLimitError) as e:
+            raise e

     @abstractmethod
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:

Check failure on line 67 in graphiti_core/llm_client/client.py (GitHub Actions / mypy, union-attr): Item "None" of "WrappedFn | None" has no attribute "__name__"
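
Taken on its own, the new decorator makes up to four attempts, waits a jittered exponential backoff of roughly 5 to 120 seconds between them, and retries only when the predicate sees a 5xx response or the shared RateLimitError. Below is a self-contained sketch of the same tenacity configuration applied to a dummy coroutine; flaky_call and its counter are invented for the demo, and the None-safe getattr in log_retry is just one way to quiet the mypy union-attr warning above, not necessarily how this PR resolves it:

import asyncio
import logging

import httpx
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

from graphiti_core.llm_client.errors import RateLimitError

logger = logging.getLogger(__name__)


def is_server_or_retry_error(exception: BaseException) -> bool:
    # Same predicate as the PR: retry on rate limits and 5xx responses only.
    if isinstance(exception, RateLimitError):
        return True
    return (
        isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
    )


def log_retry(retry_state) -> None:
    # retry_state.fn is Optional, so look the name up defensively.
    if retry_state.attempt_number > 1:
        fn_name = getattr(retry_state.fn, '__name__', '<unknown>')
        logger.warning(f'Retrying {fn_name} after {retry_state.attempt_number} attempts...')


attempts = 0  # demo-only counter so the fake call eventually succeeds


@retry(
    stop=stop_after_attempt(4),
    wait=wait_random_exponential(multiplier=10, min=5, max=120),
    retry=retry_if_exception(is_server_or_retry_error),
    after=log_retry,
    reraise=True,
)
async def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise RateLimitError  # matched by the predicate, so tenacity retries (real sleeps apply)
    return 'ok'


print(asyncio.run(flaky_call()))  # succeeds on the third attempt

wait_random_exponential jitters each wait, so concurrent workers that hit a rate limit at the same moment do not all retry in lockstep, which is the usual reason to prefer it over the plain wait_exponential it replaces.
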
6 changes: 6 additions & 0 deletions graphiti_core/llm_client/errors.py
@@ -0,0 +1,6 @@
+class RateLimitError(Exception):
+    """Exception raised when the rate limit is exceeded."""
+
+    def __init__(self, message='Rate limit exceeded. Please try again later.'):
+        self.message = message
+        super().__init__(self.message)
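
Because the message is stored on the instance and also passed to Exception.__init__, str(e) and e.message return the same text. A quick check (the custom message string is only an example):

from graphiti_core.llm_client.errors import RateLimitError

try:
    raise RateLimitError()
except RateLimitError as e:
    print(e.message)  # Rate limit exceeded. Please try again later.
    print(str(e) == e.message)  # True

try:
    raise RateLimitError('Provider returned 429; retry after 20 seconds')
except RateLimitError as e:
    print(e)  # Provider returned 429; retry after 20 seconds
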
4 changes: 4 additions & 0 deletions graphiti_core/llm_client/groq_client.py
@@ -18,13 +18,15 @@
 import logging
 import typing

+import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
 from openai import AsyncOpenAI

 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError

 logger = logging.getLogger(__name__)

@@ -59,6 +61,8 @@ async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except groq.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
4 changes: 4 additions & 0 deletions graphiti_core/llm_client/openai_client.py
Expand Up @@ -18,12 +18,14 @@
 import logging
 import typing

+import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam

 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError

 logger = logging.getLogger(__name__)

@@ -59,6 +61,8 @@ async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise