Skip to content

Commit

Permalink
chore(messages): add back-compat for isinstance() checks
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed May 24, 2024
1 parent b1a1c03 commit 7794bcb
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
48 changes: 46 additions & 2 deletions src/anthropic/_streaming.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Note: initially copied from /~https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
from __future__ import annotations

import abc
import json
import inspect
import warnings
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
Expand All @@ -18,7 +20,28 @@
_T = TypeVar("_T")


class Stream(Generic[_T]):
class _SyncStreamMeta(abc.ABCMeta):
@override
def __instancecheck__(self, instance: Any) -> bool:
# we override the `isinstance()` check for `Stream`
# as a previous version of the `MessageStream` class
# inherited from `Stream` & without this workaround,
# changing it to not inherit would be a breaking change.

from .lib.streaming import MessageStream

if isinstance(instance, MessageStream):
warnings.warn(
"Using `isinstance()` to check if a `MessageStream` object is an instance of `Stream` is deprecated & will be removed in the next major version",
DeprecationWarning,
stacklevel=2,
)
return True

return False


class Stream(Generic[_T], metaclass=_SyncStreamMeta):
"""Provides the core interface to iterate over a synchronous stream response."""

response: httpx.Response
Expand Down Expand Up @@ -114,7 +137,28 @@ def close(self) -> None:
self.response.close()


class AsyncStream(Generic[_T]):
class _AsyncStreamMeta(abc.ABCMeta):
@override
def __instancecheck__(self, instance: Any) -> bool:
# we override the `isinstance()` check for `AsyncStream`
# as a previous version of the `AsyncMessageStream` class
# inherited from `AsyncStream` & without this workaround,
# changing it to not inherit would be a breaking change.

from .lib.streaming import AsyncMessageStream

if isinstance(instance, AsyncMessageStream):
warnings.warn(
"Using `isinstance()` to check if a `AsyncMessageStream` object is an instance of `AsyncStream` is deprecated & will be removed in the next major version",
DeprecationWarning,
stacklevel=2,
)
return True

return False


class AsyncStream(Generic[_T], metaclass=_AsyncStreamMeta):
"""Provides the core interface to iterate over an asynchronous stream response."""

response: httpx.Response
Expand Down
8 changes: 7 additions & 1 deletion tests/lib/streaming/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from respx import MockRouter

from anthropic import Anthropic, AsyncAnthropic
from anthropic import Stream, Anthropic, AsyncStream, AsyncAnthropic
from anthropic.lib.streaming import MessageStream, AsyncMessageStream
from anthropic.types.message import Message
from anthropic.types.message_stream_event import MessageStreamEvent
Expand Down Expand Up @@ -124,6 +124,9 @@ def test_basic_response(self, respx_mock: MockRouter) -> None:
model="claude-3-opus-20240229",
event_handler=SyncEventTracker,
) as stream:
with pytest.warns(DeprecationWarning):
assert isinstance(stream, Stream)

assert_basic_response(stream, stream.get_final_message())

@pytest.mark.respx(base_url=base_url)
Expand Down Expand Up @@ -163,6 +166,9 @@ async def test_basic_response(self, respx_mock: MockRouter) -> None:
model="claude-3-opus-20240229",
event_handler=AsyncEventTracker,
) as stream:
with pytest.warns(DeprecationWarning):
assert isinstance(stream, AsyncStream)

assert_basic_response(stream, await stream.get_final_message())

@pytest.mark.asyncio
Expand Down
8 changes: 8 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ def body() -> Iterator[bytes]:
assert sse.json() == {"content": "известни"}


def test_isinstance_check(client: Anthropic, async_client: AsyncAnthropic) -> None:
async_stream = AsyncStream(cast_to=object, client=async_client, response=httpx.Response(200, content=b"foo"))
assert isinstance(async_stream, AsyncStream)

sync_stream = Stream(cast_to=object, client=client, response=httpx.Response(200, content=b"foo"))
assert isinstance(sync_stream, Stream)


async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]:
for chunk in iter:
yield chunk
Expand Down

0 comments on commit 7794bcb

Please sign in to comment.