Fix llm client retry
danielchalef committed Sep 10, 2024
1 parent 3f12254 commit a053889
Showing 6 changed files with 42 additions and 12 deletions.
graphiti_core/llm_client/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,5 +1,6 @@
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 from .openai_client import OpenAIClient

-__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
+__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
graphiti_core/llm_client/anthropic_client.py (10 changes: 9 additions & 1 deletion)
@@ -18,12 +18,14 @@
 import logging
 import typing

+import anthropic
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI

 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError

 logger = logging.getLogger(__name__)

@@ -35,7 +37,11 @@ def __init__(self, config: LLMConfig | None = None, cache: bool = False):
         if config is None:
             config = LLMConfig()
         super().__init__(config, cache)
-        self.client = AsyncAnthropic(api_key=config.api_key)
+        self.client = AsyncAnthropic(
+            api_key=config.api_key,
+            # we'll use tenacity to retry
+            max_retries=1,
+        )

     def get_embedder(self) -> typing.Any:
         openai_client = AsyncOpenAI()
@@ -58,6 +64,8 @@ async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
             )

             return json.loads('{' + result.content[0].text)  # type: ignore
+        except anthropic.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
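
The `raise RateLimitError from e` translation above (the same pattern is applied to the Groq and OpenAI clients further down) chains the provider's exception as `__cause__`, so tracebacks and logs still show the original SDK error. A minimal sketch, not part of this commit, using stand-in exception classes rather than the real SDKs:

# Minimal sketch (not part of this commit): `raise RateLimitError from e` keeps
# the provider's original exception reachable via __cause__, so the SDK-specific
# error is not lost when it is translated into the shared RateLimitError.
# Both exception classes here are stand-ins, not the real SDK/library types.

class RateLimitError(Exception):
    """Stand-in for graphiti_core.llm_client.errors.RateLimitError."""


class ProviderRateLimitError(Exception):
    """Stand-in for e.g. anthropic.RateLimitError or groq.RateLimitError."""


try:
    try:
        raise ProviderRateLimitError('429 Too Many Requests')
    except ProviderRateLimitError as e:
        raise RateLimitError from e
except RateLimitError as err:
    print(type(err.__cause__).__name__)  # -> ProviderRateLimitError
    print(err.__cause__)                 # -> 429 Too Many Requests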
graphiti_core/llm_client/client.py (27 changes: 17 additions & 10 deletions)
@@ -22,18 +22,22 @@

 import httpx
 from diskcache import Cache
-from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential

 from ..prompts.models import Message
 from .config import LLMConfig
+from .errors import RateLimitError

 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'

 logger = logging.getLogger(__name__)


-def is_server_error(exception):
+def is_server_or_retry_error(exception):
+    if isinstance(exception, RateLimitError):
+        return True
+
     return (
         isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
     )
@@ -56,18 +60,21 @@ def get_embedder(self) -> typing.Any:
         pass

     @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception(is_server_error),
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=10, min=5, max=120),
+        retry=retry_if_exception(is_server_or_retry_error),
+        after=lambda retry_state: logger.warning(
+            f'Retrying {retry_state.fn.__name__} after {retry_state.attempt_number} attempts...'
[GitHub Actions / mypy: check failure on line 67 of graphiti_core/llm_client/client.py (union-attr): Item "None" of "WrappedFn | None" has no attribute "__name__"]
+        )
+        if retry_state.attempt_number > 1
+        else None,
         reraise=True,
     )
     async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
         try:
             return await self._generate_response(messages)
-        except httpx.HTTPStatusError as e:
-            if not is_server_error(e):
-                raise Exception(f'LLM request error: {e}') from e
-            else:
-                raise
+        except (httpx.HTTPStatusError, RateLimitError) as e:
+            raise e

     @abstractmethod
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
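
For context, a rough standalone sketch (not part of this commit) of how the tenacity policy configured above behaves: retry on `RateLimitError` (the real predicate also retries httpx 5xx responses), up to 4 attempts with randomized exponential backoff between 5 and 120 seconds, re-raising the last error rather than tenacity's `RetryError`. The `flaky()` function and attempt counter are illustrative only, and running it will actually sleep between attempts.

# Rough standalone sketch (not part of this commit) of the retry policy that
# the @retry decorator above configures, assuming tenacity is installed.
# RateLimitError, flaky() and the attempt counter are stand-ins; the real
# predicate in client.py also retries httpx responses with 5xx status codes.
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential


class RateLimitError(Exception):
    """Stand-in for graphiti_core.llm_client.errors.RateLimitError."""


def is_server_or_retry_error(exception):
    return isinstance(exception, RateLimitError)


attempts = 0


@retry(
    stop=stop_after_attempt(4),                                   # give up after the 4th try
    wait=wait_random_exponential(multiplier=10, min=5, max=120),  # jittered backoff, 5-120 s
    retry=retry_if_exception(is_server_or_retry_error),
    reraise=True,  # surface the last RateLimitError instead of tenacity's RetryError
)
def flaky():
    # Synchronous for simplicity; tenacity applies the same policy to async callables.
    global attempts
    attempts += 1
    if attempts < 3:
        raise RateLimitError('simulated 429')
    return 'ok'


print(flaky(), 'after', attempts, 'attempts')  # note: really sleeps between attempts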
graphiti_core/llm_client/errors.py (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
+class RateLimitError(Exception):
+    """Exception raised when the rate limit is exceeded."""
+
+    def __init__(self, message='Rate limit exceeded. Please try again later.'):
+        self.message = message
+        super().__init__(self.message)
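
A hedged call-site sketch (not part of this commit) of how application code might handle the now-exported `RateLimitError`: with `reraise=True`, the error surfaces once the retry budget is exhausted. The `summarize` wrapper and its arguments are illustrative placeholders; it calls the retry-wrapped method shown in the client.py diff above.

# Hedged call-site sketch (not part of this commit). With reraise=True the
# RateLimitError surfaces to the caller after the retry budget is exhausted,
# so applications can catch the package-level export and back off themselves.
# `summarize` and its arguments are illustrative placeholders.
from graphiti_core.llm_client import LLMClient, RateLimitError


async def summarize(client: LLMClient, messages) -> dict | None:
    # `messages` is a list[Message] from graphiti_core.prompts.models.
    try:
        return await client._generate_response_with_retry(messages)
    except RateLimitError:
        # Every attempt hit the provider's rate limit; handle it at a higher
        # level (queue the work, shed load, alert, ...) instead of retrying here.
        return None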
graphiti_core/llm_client/groq_client.py (4 changes: 4 additions & 0 deletions)
@@ -18,13 +18,15 @@
 import logging
 import typing

+import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
 from openai import AsyncOpenAI

 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError

 logger = logging.getLogger(__name__)

@@ -59,6 +61,8 @@ async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except groq.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
graphiti_core/llm_client/openai_client.py (4 changes: 4 additions & 0 deletions)
@@ -18,12 +18,14 @@
 import logging
 import typing

+import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam

 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError

 logger = logging.getLogger(__name__)

@@ -59,6 +61,8 @@ async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
