Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backport: e2e: add agent test using tool (#426) #442

Merged
merged 2 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr_style_check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
examples
blog
site
backport
subjectPattern: ^(?![A-Z]).+$
subjectPatternError: |
The subject "{subject}" found in the pull request title "{title}"
Expand Down
4 changes: 2 additions & 2 deletions internal/apischema/awsbedrock/awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ type ToolResultContentBlock struct {
Image *ImageBlock `json:"image,omitempty"`

// A tool result that is text.
Text *string `json:"text" type:"string,omitempty"`
Text *string `json:"text,omitempty"`

// A tool result that is JSON format data.
JSON *string `json:"json" type:"string,omitempty"`
JSON *string `json:"json,omitempty"`
}

// ToolResultBlock A tool result block that contains the results for a tool request that the
Expand Down
39 changes: 29 additions & 10 deletions internal/extproc/translator/openai_awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMes
openAiMessage *openai.ChatCompletionAssistantMessageParam, role string,
) (*awsbedrock.Message, error) {
var bedrockMessage *awsbedrock.Message
contentBlocks := make([]*awsbedrock.ContentBlock, 1)
contentBlocks := make([]*awsbedrock.ContentBlock, 0)
if openAiMessage.Content.Type == openai.ChatCompletionAssistantMessageParamContentTypeRefusal {
contentBlocks[0] = &awsbedrock.ContentBlock{Text: openAiMessage.Content.Refusal}
} else {
contentBlocks[0] = &awsbedrock.ContentBlock{Text: openAiMessage.Content.Text}
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: openAiMessage.Content.Refusal})
} else if openAiMessage.Content.Text != nil {
// TODO: we are sometimes missing the content (should fix)
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: openAiMessage.Content.Text})
}
bedrockMessage = &awsbedrock.Message{
Role: role,
Expand Down Expand Up @@ -311,16 +312,34 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMes
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMessageRoleTool(
openAiMessage *openai.ChatCompletionToolMessageParam, role string,
) (*awsbedrock.Message, error) {
// Validate and cast the openai content value into bedrock content block
content := make([]*awsbedrock.ToolResultContentBlock, 0)

switch v := openAiMessage.Content.Value.(type) {
case string:
content = []*awsbedrock.ToolResultContentBlock{
{
Text: &v,
},
}
case []openai.ChatCompletionContentPartTextParam:
for _, part := range v {
content = append(content, &awsbedrock.ToolResultContentBlock{
Text: &part.Text,
})
}

default:
return nil, fmt.Errorf("unexpected content type for tool message: %T", openAiMessage.Content.Value)
}

return &awsbedrock.Message{
Role: role,
Content: []*awsbedrock.ContentBlock{
{
ToolResult: &awsbedrock.ToolResultBlock{
Content: []*awsbedrock.ToolResultContentBlock{
{
Text: openAiMessage.Content.Value.(*string),
},
},
Content: content,
ToolUseID: &openAiMessage.ToolCallID,
},
},
},
Expand Down Expand Up @@ -559,7 +578,6 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
if err = json.NewDecoder(body).Decode(&bedrockResp); err != nil {
return nil, nil, tokenUsage, fmt.Errorf("failed to unmarshal body: %w", err)
}

openAIResp := openai.ChatCompletionResponse{
Object: "chat.completion",
Choices: make([]openai.ChatCompletionResponseChoice, 0),
Expand All @@ -577,6 +595,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
CompletionTokens: bedrockResp.Usage.OutputTokens,
}
}

// AWS Bedrock does not support N(multiple choices) > 0, so there could be only one choice.
choice := openai.ChatCompletionResponseChoice{
Index: (int64)(0),
Expand Down
80 changes: 77 additions & 3 deletions internal/extproc/translator/openai_awsbedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
},
}, Type: openai.ChatMessageRoleUser,
},
{
Value: openai.ChatCompletionToolMessageParam{
Content: openai.StringOrArray{
Value: "Weather in Queens, NY is 70F and clear skies.",
},
ToolCallID: "call_6g7a",
}, Type: openai.ChatMessageRoleTool,
},
{
Value: openai.ChatCompletionAssistantMessageParam{
Content: openai.ChatCompletionAssistantMessageParamContent{
Expand Down Expand Up @@ -132,6 +140,21 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
},
},
},
{
Role: openai.ChatMessageRoleUser,
Content: []*awsbedrock.ContentBlock{
{
ToolResult: &awsbedrock.ToolResultBlock{
ToolUseID: ptr.To("call_6g7a"),
Content: []*awsbedrock.ToolResultContentBlock{
{
Text: ptr.To("Weather in Queens, NY is 70F and clear skies."),
},
},
},
},
},
},
{
Role: openai.ChatMessageRoleAssistant,
Content: []*awsbedrock.ContentBlock{
Expand Down Expand Up @@ -907,7 +930,7 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
Index: 0,
Message: openai.ChatCompletionResponseChoiceMessage{
Content: ptr.To("response"),
Role: "assistant",
Role: awsbedrock.ConversationRoleAssistant,
},
FinishReason: openai.ChatCompletionChoicesFinishReasonStop,
},
Expand Down Expand Up @@ -998,6 +1021,57 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
},
},
},
{
name: "merge content",
input: awsbedrock.ConverseResponse{
Usage: &awsbedrock.TokenUsage{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
Output: &awsbedrock.ConverseOutput{
Message: awsbedrock.Message{
Role: awsbedrock.ConversationRoleAssistant,
Content: []*awsbedrock.ContentBlock{
{Text: ptr.To("response")},
{ToolUse: &awsbedrock.ToolUseBlock{
Name: "exec_python_code",
ToolUseID: "call_6g7a",
Input: map[string]interface{}{"code_block": "from playwright.sync_api import sync_playwright\n"},
}},
},
},
},
},
output: openai.ChatCompletionResponse{
Object: "chat.completion",
Usage: openai.ChatCompletionResponseUsage{
TotalTokens: 30,
PromptTokens: 10,
CompletionTokens: 20,
},
Choices: []openai.ChatCompletionResponseChoice{
{
Index: 0,
Message: openai.ChatCompletionResponseChoiceMessage{
Content: ptr.To("response"),
Role: awsbedrock.ConversationRoleAssistant,
ToolCalls: []openai.ChatCompletionMessageToolCallParam{
{
ID: "call_6g7a",
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: "exec_python_code",
Arguments: "{\"code_block\":\"from playwright.sync_api import sync_playwright\\n\"}",
},
Type: openai.ChatCompletionMessageToolCallTypeFunction,
},
},
},
FinishReason: openai.ChatCompletionChoicesFinishReasonStop,
},
},
},
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -1178,14 +1252,14 @@ func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) {
{
name: "role",
in: awsbedrock.ConverseStreamEvent{
Role: ptrOf("assistant"),
Role: ptrOf(awsbedrock.ConversationRoleAssistant),
},
out: &openai.ChatCompletionResponseChunk{
Object: "chat.completion.chunk",
Choices: []openai.ChatCompletionResponseChunkChoice{
{
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
Role: "assistant",
Role: awsbedrock.ConversationRoleAssistant,
Content: &emptyString,
},
},
Expand Down
124 changes: 84 additions & 40 deletions tests/extproc/real_providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
"bufio"
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"os"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -187,49 +189,91 @@ func TestWithRealProviders(t *testing.T) {
}
})

t.Run("Bedrock calls tool get_weather function", func(t *testing.T) {
cc.maybeSkip(t, requiredCredentialAWS)
t.Run("Bedrock uses tool in response", func(t *testing.T) {
client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/"))
require.Eventually(t, func() bool {
chatCompletion, err := client.Chat.Completions.New(t.Context(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage("What is the weather like in Paris today?"),
}),
Tools: openai.F([]openai.ChatCompletionToolParam{
{
Type: openai.F(openai.ChatCompletionToolTypeFunction),
Function: openai.F(openai.FunctionDefinitionParam{
Name: openai.String("get_weather"),
Description: openai.String("Get weather at the given location"),
Parameters: openai.F(openai.FunctionParameters{
"type": "object",
"properties": map[string]interface{}{
"location": map[string]string{
"type": "string",
},
},
"required": []string{"location"},
}),
for _, tc := range []realProvidersTestCase{
{name: "aws-bedrock", modelName: "us.anthropic.claude-3-5-sonnet-20240620-v1:0", required: requiredCredentialAWS}, // This will go to "aws-bedrock" using credentials file.
} {
t.Run(tc.modelName, func(t *testing.T) {
cc.maybeSkip(t, tc.required)
require.Eventually(t, func() bool {
// Step 1: Initial tool call request
question := "What is the weather in New York City?"
params := openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage(question),
}),
Tools: openai.F([]openai.ChatCompletionToolParam{
{
Type: openai.F(openai.ChatCompletionToolTypeFunction),
Function: openai.F(openai.FunctionDefinitionParam{
Name: openai.String("get_weather"),
Description: openai.String("Get weather at the given location"),
Parameters: openai.F(openai.FunctionParameters{
"type": "object",
"properties": map[string]interface{}{
"location": map[string]string{
"type": "string",
},
},
"required": []string{"location"},
}),
}),
},
}),
},
}),
Model: openai.F("us.anthropic.claude-3-5-sonnet-20240620-v1:0"),
Seed: openai.Int(0),
Model: openai.F(tc.modelName),
}
completion, err := client.Chat.Completions.New(context.Background(), params)
if err != nil {
t.Logf("error: %v", err)
return false
}
// Step 2: Verify tool call
toolCalls := completion.Choices[0].Message.ToolCalls
if len(toolCalls) == 0 {
t.Logf("Expected tool call from completion result but got none")
return false
}
// Step 3: Simulate the tool returning a response, add the tool response to the params, and check the second response
params.Messages.Value = append(params.Messages.Value, completion.Choices[0].Message)
getWeatherCalled := false
for _, toolCall := range toolCalls {
if toolCall.Function.Name == "get_weather" {
getWeatherCalled = true
// Extract the location from the function call arguments
var args map[string]interface{}
if argErr := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); argErr != nil {
t.Logf("Error unmarshalling the function arguments: %v", argErr)
}
location := args["location"].(string)
if location != "New York City" {
t.Logf("Expected location to be New York City but got %s", location)
}
// Simulate getting weather data
weatherData := "Sunny, 25°C"
params.Messages.Value = append(params.Messages.Value, openai.ToolMessage(toolCall.ID, weatherData))
t.Logf("Appended tool message: %v", openai.ToolMessage(toolCall.ID, weatherData)) // Debug log
}
}
if getWeatherCalled == false {
t.Logf("get_weather tool not specified in chat completion response")
return false
}

secondChatCompletion, err := client.Chat.Completions.New(context.Background(), params)
if err != nil {
t.Logf("error during second response: %v", err)
return false
}

// Step 4: Verify that the second response is correct
completionResult := secondChatCompletion.Choices[0].Message.Content
t.Logf("content of completion response using tool: %s", secondChatCompletion.Choices[0].Message.Content)
return strings.Contains(completionResult, "New York City") && strings.Contains(completionResult, "sunny") && strings.Contains(completionResult, "25°C")
}, 60*time.Second, 4*time.Second)
})
if err != nil {
t.Logf("error: %v", err)
return false
}
returnsToolCall := false
for _, choice := range chatCompletion.Choices {
t.Logf("choice content: %s", choice.Message.Content)
t.Logf("finish reason: %s", choice.FinishReason)
t.Logf("choice toolcall: %v", choice.Message.ToolCalls)
if choice.FinishReason == openai.ChatCompletionChoicesFinishReasonToolCalls {
returnsToolCall = true
}
}
return returnsToolCall
}, 30*time.Second, 2*time.Second)
}
})

// Models are served by the extproc filter as a direct response so this can run even if the
Expand Down
Loading