Skip to content

Commit

Permalink
chore: GenAI - Part.function_call.args is now a proper dict
Browse files Browse the repository at this point in the history
Fixes #4079

PiperOrigin-RevId: 680833836
  • Loading branch information
Ark-kun authored and copybara-github committed Oct 1, 2024
1 parent f6e0a5a commit 427bd75
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,13 +2298,13 @@ def text(self) -> str:
) from e

@property
def function_calls(self) -> Sequence[gapic_tool_types.FunctionCall]:
def function_calls(self) -> Sequence["FunctionCall"]:
if not self.content or not self.content.parts:
return []
return [
part.function_call
for part in self.content.parts
if part and part.function_call
if part._raw_part._pb.WhichOneof("data") == "function_call"
]


Expand Down Expand Up @@ -2479,8 +2479,8 @@ def file_data(self) -> gapic_content_types.FileData:
return self._raw_part.file_data

@property
def function_call(self) -> gapic_tool_types.FunctionCall:
return self._raw_part.function_call
def function_call(self) -> "FunctionCall":
return FunctionCall._from_gapic(self._raw_part.function_call)

@property
def function_response(self) -> gapic_tool_types.FunctionResponse:
Expand All @@ -2491,6 +2491,35 @@ def _image(self) -> "Image":
return Image.from_bytes(data=self._raw_part.inline_data.data)


class FunctionCall:
"""Function call."""

def __init__(self):
self._raw_message = aiplatform_types.FunctionCall()

@classmethod
def _from_gapic(cls, raw_message: aiplatform_types.FunctionCall) -> "FunctionCall":
response = cls()
response._raw_message = raw_message
return response

def to_dict(self) -> Dict[str, Any]:
return _proto_to_dict(self._raw_message)

def __repr__(self) -> str:
return self._raw_message.__repr__()

@property
def name(self) -> str:
return self._raw_message.name

@property
def args(self) -> Dict[str, Any]:
# We cannot use `type(self.args).to_dict(self.args)`
# due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
return self.to_dict().get("args")


class SafetySetting:
"""Parameters for the generation."""

Expand Down Expand Up @@ -2949,10 +2978,9 @@ def respond_to_model_response(
)

try:
# We cannot use `function_args = type(function_call.args).to_dict(function_call.args)`
# due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
function_args = type(function_call).to_dict(function_call)["args"]
function_call_result = callable_function._function(**function_args)
function_call_result = callable_function._function(
**function_call.args
)
if not isinstance(function_call_result, Mapping):
# If the function returns a single value, wrap it in the
# format that Part.from_function_response can accept.
Expand Down

0 comments on commit 427bd75

Please sign in to comment.