From 7794bcb680300249cd9be48562ce190eed8b9cff Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Wed, 22 May 2024 14:40:03 +0100 Subject: [PATCH] chore(messages): add back-compat for isinstance() checks --- src/anthropic/_streaming.py | 48 ++++++++++++++++++++++++++-- tests/lib/streaming/test_messages.py | 8 ++++- tests/test_streaming.py | 8 +++++ 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/anthropic/_streaming.py b/src/anthropic/_streaming.py index 73474c33..d43e2e6a 100644 --- a/src/anthropic/_streaming.py +++ b/src/anthropic/_streaming.py @@ -1,8 +1,10 @@ # Note: initially copied from /~https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py from __future__ import annotations +import abc import json import inspect +import warnings from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable @@ -18,7 +20,28 @@ _T = TypeVar("_T") -class Stream(Generic[_T]): +class _SyncStreamMeta(abc.ABCMeta): + @override + def __instancecheck__(self, instance: Any) -> bool: + # we override the `isinstance()` check for `Stream` + # as a previous version of the `MessageStream` class + # inherited from `Stream` & without this workaround, + # changing it to not inherit would be a breaking change. + + from .lib.streaming import MessageStream + + if isinstance(instance, MessageStream): + warnings.warn( + "Using `isinstance()` to check if a `MessageStream` object is an instance of `Stream` is deprecated & will be removed in the next major version", + DeprecationWarning, + stacklevel=2, + ) + return True + + return False + + +class Stream(Generic[_T], metaclass=_SyncStreamMeta): """Provides the core interface to iterate over a synchronous stream response.""" response: httpx.Response @@ -114,7 +137,28 @@ def close(self) -> None: self.response.close() -class AsyncStream(Generic[_T]): +class _AsyncStreamMeta(abc.ABCMeta): + @override + def __instancecheck__(self, instance: Any) -> bool: + # we override the `isinstance()` check for `AsyncStream` + # as a previous version of the `AsyncMessageStream` class + # inherited from `AsyncStream` & without this workaround, + # changing it to not inherit would be a breaking change. + + from .lib.streaming import AsyncMessageStream + + if isinstance(instance, AsyncMessageStream): + warnings.warn( + "Using `isinstance()` to check if a `AsyncMessageStream` object is an instance of `AsyncStream` is deprecated & will be removed in the next major version", + DeprecationWarning, + stacklevel=2, + ) + return True + + return False + + +class AsyncStream(Generic[_T], metaclass=_AsyncStreamMeta): """Provides the core interface to iterate over an asynchronous stream response.""" response: httpx.Response diff --git a/tests/lib/streaming/test_messages.py b/tests/lib/streaming/test_messages.py index 4d247238..5c872af5 100644 --- a/tests/lib/streaming/test_messages.py +++ b/tests/lib/streaming/test_messages.py @@ -8,7 +8,7 @@ import pytest from respx import MockRouter -from anthropic import Anthropic, AsyncAnthropic +from anthropic import Stream, Anthropic, AsyncStream, AsyncAnthropic from anthropic.lib.streaming import MessageStream, AsyncMessageStream from anthropic.types.message import Message from anthropic.types.message_stream_event import MessageStreamEvent @@ -124,6 +124,9 @@ def test_basic_response(self, respx_mock: MockRouter) -> None: model="claude-3-opus-20240229", event_handler=SyncEventTracker, ) as stream: + with pytest.warns(DeprecationWarning): + assert isinstance(stream, Stream) + assert_basic_response(stream, stream.get_final_message()) @pytest.mark.respx(base_url=base_url) @@ -163,6 +166,9 @@ async def test_basic_response(self, respx_mock: MockRouter) -> None: model="claude-3-opus-20240229", event_handler=AsyncEventTracker, ) as stream: + with pytest.warns(DeprecationWarning): + assert isinstance(stream, AsyncStream) + assert_basic_response(stream, await stream.get_final_message()) @pytest.mark.asyncio diff --git a/tests/test_streaming.py b/tests/test_streaming.py index f9ad7182..9e8908ae 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -216,6 +216,14 @@ def body() -> Iterator[bytes]: assert sse.json() == {"content": "известни"} +def test_isinstance_check(client: Anthropic, async_client: AsyncAnthropic) -> None: + async_stream = AsyncStream(cast_to=object, client=async_client, response=httpx.Response(200, content=b"foo")) + assert isinstance(async_stream, AsyncStream) + + sync_stream = Stream(cast_to=object, client=client, response=httpx.Response(200, content=b"foo")) + assert isinstance(sync_stream, Stream) + + async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: for chunk in iter: yield chunk