Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DocumentUrl and support document via BinaryContent #987

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field, replace
from datetime import datetime
from mimetypes import guess_type
from typing import Annotated, Any, Literal, Union, cast, overload

import pydantic
Expand Down Expand Up @@ -80,8 +81,28 @@ def media_type(self) -> ImageMediaType:
raise ValueError(f'Unknown image file extension: {self.url}')


@dataclass
class DocumentUrl:
    """A document referenced by its URL."""

    url: str
    """Location of the document."""

    kind: Literal['document-url'] = 'document-url'
    """Discriminator tag for this part type; defaults to `'document-url'`."""

    @property
    def media_type(self) -> str:
        """Media type of the document, guessed from `url` via `mimetypes.guess_type`.

        Raises:
            RuntimeError: If no media type can be guessed for the URL.
        """
        guessed = guess_type(self.url)[0]
        if guessed is not None:
            return guessed
        raise RuntimeError(f'Unknown document file extension: {self.url}')


AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
DocumentMediaType: TypeAlias = Literal['application/pdf', 'text/plain']


@dataclass
Expand All @@ -91,7 +112,7 @@ class BinaryContent:
data: bytes
"""The binary data."""

media_type: AudioMediaType | ImageMediaType | str
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
"""The media type of the binary data."""

kind: Literal['binary'] = 'binary'
Expand All @@ -107,6 +128,11 @@ def is_image(self) -> bool:
"""Return `True` if the media type is an image type."""
return self.media_type.startswith('image/')

@property
def is_document(self) -> bool:
    """Whether `media_type` is one of the document types (`application/pdf` or `text/plain`)."""
    document_types = frozenset(('application/pdf', 'text/plain'))
    return self.media_type in document_types

@property
def audio_format(self) -> Literal['mp3', 'wav']:
"""Return the audio format given the media type."""
Expand All @@ -118,7 +144,7 @@ def audio_format(self) -> Literal['mp3', 'wav']:
raise ValueError(f'Unknown audio media type: {self.media_type}')


UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | BinaryContent'


@dataclass
Expand Down
31 changes: 29 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload

from anthropic.types import DocumentBlockParam
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelRequest,
Expand All @@ -41,10 +43,12 @@
try:
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
from anthropic.types import (
Base64PDFSourceParam,
ImageBlockParam,
Message as AnthropicMessage,
MessageParam,
MetadataParam,
PlainTextSourceParam,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
Expand Down Expand Up @@ -281,7 +285,9 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
anthropic_messages: list[MessageParam] = []
for m in messages:
if isinstance(m, ModelRequest):
user_content_params: list[ToolResultBlockParam | TextBlockParam | ImageBlockParam] = []
user_content_params: list[
ToolResultBlockParam | TextBlockParam | ImageBlockParam | DocumentBlockParam
] = []
for request_part in m.parts:
if isinstance(request_part, SystemPromptPart):
system_prompt += request_part.content
Expand Down Expand Up @@ -327,7 +333,9 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
return system_prompt, anthropic_messages

@staticmethod
async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
async def _map_user_prompt(
part: UserPromptPart,
) -> AsyncGenerator[ImageBlockParam | TextBlockParam | DocumentBlockParam]:
if isinstance(part.content, str):
yield TextBlockParam(text=part.content, type='text')
else:
Expand All @@ -349,6 +357,25 @@ async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockPar
source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
type='image',
)
elif isinstance(item, DocumentUrl):
response = await cached_async_http_client().get(item.url)
response.raise_for_status()
if item.media_type == 'application/pdf':
yield DocumentBlockParam(
source=Base64PDFSourceParam(
data=io.BytesIO(response.content),
media_type=item.media_type,
type='base64',
),
type='document',
)
elif item.media_type == 'text/plain':
yield DocumentBlockParam(
source=PlainTextSourceParam(data=response.text, media_type=item.media_type, type='text'),
type='document',
)
else:
raise RuntimeError(f'Unsupported media type: {item.media_type}')
else:
raise RuntimeError(f'Unsupported content type: {type(item)}')

Expand Down
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..messages import (
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelRequest,
Expand Down Expand Up @@ -323,8 +324,9 @@ async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
content.append(
_GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
)
elif isinstance(item, (AudioUrl, ImageUrl)):
elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)):
try:
print(item.url, item.media_type)
content.append(
_GeminiFileDataPart(file_data={'file_uri': item.url, 'mime_type': item.media_type})
)
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..messages import (
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelRequest,
Expand Down Expand Up @@ -384,6 +385,8 @@ async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessa
base64_encoded = base64.b64encode(response.content).decode('utf-8')
audio = InputAudio(data=base64_encoded, format=response.headers.get('content-type'))
content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
elif isinstance(item, DocumentUrl):
raise RuntimeError('DocumentUrl is not supported by OpenAI')
else:
assert_never(item)
return chat.ChatCompletionUserMessageParam(role='user', content=content)
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ logfire = ["logfire>=2.3"]
openai = ["openai>=1.61.0"]
cohere = ["cohere>=5.13.11"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.40.0"]
anthropic = ["anthropic>=0.41.0"]
groq = ["groq>=0.12.0"]
mistral = ["mistralai>=1.2.5"]

Expand Down
Binary file added tests/assets/dummy.pdf
Binary file not shown.
54 changes: 54 additions & 0 deletions tests/models/cassettes/test_gemini/test_document_url_input.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
interactions:
- request:
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '207'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
method: POST
parsed_body:
contents:
- parts:
- text: What is the main content on this document?
- fileData:
fileUri: gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf
mimeType: application/pdf
role: user
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.0-pro:generateContent
response:
body:
string: |
{
"error": {
"code": 404,
"message": "models/gemini-1.0-pro is not found for API version v1beta, or is not supported for generateContent. Call ListModels to see the list of available models and their supported methods.",
"status": "NOT_FOUND"
}
}
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
content-length:
- '263'
content-type:
- application/json; charset=UTF-8
server-timing:
- gfet4t7; dur=658
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
status:
code: 404
message: Not Found
version: 1
12 changes: 12 additions & 0 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.messages import (
BinaryContent,
DocumentUrl,
ImageUrl,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -1010,3 +1011,14 @@ async def test_image_url_input(allow_model_requests: None, gemini_api_key: str)

result = await agent.run(['What is the name of this fruit?', image_url])
assert result.data == snapshot('This is not a fruit, it is an organ console.')


@pytest.mark.vcr()
async def test_document_url_input(allow_model_requests: None, gemini_api_key: str) -> None:
    """Run a Gemini-backed agent on a text prompt plus a `DocumentUrl` and snapshot the reply."""
    # NOTE(review): the recorded cassette for this test appears to hold a 404
    # ("models/gemini-1.0-pro is not found") and the snapshot below is still
    # empty — confirm and re-record before relying on this test.
    agent = Agent(GeminiModel('gemini-1.0-pro', api_key=gemini_api_key))
    prompt = [
        'What is the main content on this document?',
        DocumentUrl(url='gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf'),
    ]

    result = await agent.run(prompt)
    assert result.data == snapshot()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading