diff --git a/completion.go b/completion.go index 6617e5a7..44057616 100644 --- a/completion.go +++ b/completion.go @@ -7,8 +7,9 @@ import ( ) var ( - ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll - ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll + ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Promp only supports string and []string") //nolint:lll ) // GPT3 Defines the models provided by OpenAI to use when generating @@ -77,10 +78,16 @@ func checkEndpointSupportsModel(endpoint, model string) bool { return !disabledModelsForEndpoints[endpoint][model] } +func checkPromptType(prompt any) bool { + _, isString := prompt.(string) + _, isStringSlice := prompt.([]string) + return isString || isStringSlice +} + // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` + Prompt any `json:"prompt,omitempty"` Suffix string `json:"suffix,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` @@ -143,6 +150,11 @@ func (c *Client) CreateCompletion( return } + if !checkPromptType(request.Prompt) { + err = ErrCompletionRequestPromptTypeNotSupported + return + } + req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) if err != nil { return diff --git a/completion_test.go b/completion_test.go index ce95faf4..2e302591 100644 --- a/completion_test.go +++ b/completion_test.go @@ -98,14 +98,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) if completionReq.Echo { - completionStr = completionReq.Prompt + completionStr + completionStr = completionReq.Prompt.(string) + completionStr } res.Choices = append(res.Choices, CompletionChoice{ Text: completionStr, Index: i, }) } - inputTokens := numTokens(completionReq.Prompt) * completionReq.N + inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N completionTokens := completionReq.MaxTokens * completionReq.N res.Usage = Usage{ PromptTokens: inputTokens, diff --git a/request_builder_test.go b/request_builder_test.go index 0f14f93f..e5b65df0 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -51,7 +51,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { ctx := context.Background() - _, err = client.CreateCompletion(ctx, CompletionRequest{}) + _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"}) if !errors.Is(err, errTestRequestBuilderFailed) { t.Fatalf("Did not return error when request builder failed: %v", err) } @@ -146,3 +146,27 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestReturnsRequestBuilderErrorsAddtion(t *testing.T) { + var err error + ts := test.NewTestServer().OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + ctx := context.Background() + + _, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } + + _, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: 1}) + if !errors.Is(err, ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("Did not return error when request builder failed: %v", err) + } +} diff --git a/stream.go b/stream.go index 944546a6..64688cdc 100644 --- a/stream.go +++ b/stream.go @@ -28,6 +28,11 @@ func (c *Client) CreateCompletionStream( return } + if !checkPromptType(request.Prompt) { + err = ErrCompletionRequestPromptTypeNotSupported + return + } + request.Stream = true req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request) if err != nil {