diff --git a/audio.go b/audio.go
index f321f93d..03b96754 100644
--- a/audio.go
+++ b/audio.go
@@ -45,8 +45,8 @@ type AudioRequest struct {
 	Reader io.Reader
 
 	Prompt      string
-	Temperature float32
-	Language    string // Only for transcription.
+	Temperature float32 // Defaults to 0, so fine to not be a pointer.
+	Language    string  // Only for transcription.
 	Format      AudioResponseFormat
 	TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription.
 }
diff --git a/chat.go b/chat.go
index fcaf79cf..116e2ba2 100644
--- a/chat.go
+++ b/chat.go
@@ -222,7 +222,7 @@ type ChatCompletionRequest struct {
 	// MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
 	// including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
 	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
-	Temperature float32 `json:"temperature,omitempty"`
+	Temperature *float32 `json:"temperature,omitempty"`
 	TopP        float32 `json:"top_p,omitempty"`
 	N           int     `json:"n,omitempty"`
 	Stream      bool    `json:"stream,omitempty"`
diff --git a/chat_test.go b/chat_test.go
index 134026cd..6f1d44bc 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -153,7 +153,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 					Role: openai.ChatMessageRoleAssistant,
 				},
 			},
-			Temperature: float32(2),
+			Temperature: openai.NewFloat(2),
 		},
 		expectedError: openai.ErrO1BetaLimitationsOther,
 	},
@@ -170,7 +170,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 					Role: openai.ChatMessageRoleAssistant,
 				},
 			},
-			Temperature: float32(1),
+			Temperature: openai.NewFloat(1),
 			TopP:        float32(0.1),
 		},
 		expectedError: openai.ErrO1BetaLimitationsOther,
@@ -188,7 +188,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
 					Role: openai.ChatMessageRoleAssistant,
 				},
 			},
-			Temperature: float32(1),
+			Temperature: openai.NewFloat(1),
 			TopP:        float32(1),
 			N:           2,
 		},
@@ -259,6 +259,22 @@ func TestChatRequestOmitEmpty(t *testing.T) {
 	}
 }
 
+func TestChatRequestOmitEmptyWithZeroTemp(t *testing.T) {
+	data, err := json.Marshal(openai.ChatCompletionRequest{
+		// We set model b/c it's required, so omitempty doesn't make sense
+		Model:       "gpt-4",
+		Temperature: openai.NewFloat(0),
+	})
+	checks.NoError(t, err)
+
+	// messages is also required so isn't omitted
+	// but the zero-value for temp is not excluded, b/c that's a valid value to set the temp to!
+	const expected = `{"model":"gpt-4","messages":null,"temperature":0}`
+	if string(data) != expected {
+		t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data))
+	}
+}
+
 func TestChatCompletionsWithStream(t *testing.T) {
 	config := openai.DefaultConfig("whatever")
 	config.BaseURL = "http://localhost/v1"
diff --git a/common.go b/common.go
index 8cc7289c..477cb3b8 100644
--- a/common.go
+++ b/common.go
@@ -22,3 +22,8 @@ type PromptTokensDetails struct {
 	AudioTokens  int `json:"audio_tokens"`
 	CachedTokens int `json:"cached_tokens"`
 }
+
+// NewFloat returns a pointer to a float, useful for setting the temperature on some APIs.
+func NewFloat(v float32) *float32 {
+	return &v
+}
diff --git a/completion.go b/completion.go
index f1156608..892ad38b 100644
--- a/completion.go
+++ b/completion.go
@@ -220,7 +220,7 @@ func validateRequestForO1Models(request ChatCompletionRequest) error {
 	}
 
 	// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
-	if request.Temperature > 0 && request.Temperature != 1 {
+	if request.Temperature != nil && *request.Temperature != 1 {
 		return ErrO1BetaLimitationsOther
 	}
 	if request.TopP > 0 && request.TopP != 1 {
@@ -263,7 +263,7 @@ type CompletionRequest struct {
 	Stop        []string `json:"stop,omitempty"`
 	Stream      bool     `json:"stream,omitempty"`
 	Suffix      string   `json:"suffix,omitempty"`
-	Temperature float32 `json:"temperature,omitempty"`
+	Temperature *float32 `json:"temperature,omitempty"`
 	TopP        float32  `json:"top_p,omitempty"`
 	User        string   `json:"user,omitempty"`
 }
diff --git a/openai_test.go b/openai_test.go
index 729d8880..48a00b9f 100644
--- a/openai_test.go
+++ b/openai_test.go
@@ -31,7 +31,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea
 // This function approximates based on the rule of thumb stated by OpenAI:
 // https://beta.openai.com/tokenizer
 //
-// TODO: implement an actual tokenizer for GPT-3 and Codex (once available)
+// TODO: implement an actual tokenizer for GPT-3 and Codex (once available).
 func numTokens(s string) int {
 	return int(float32(len(s)) / 4)
 }
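For context, a minimal sketch of how a caller would opt into an explicit zero temperature after this change; the token, model choice, and prompt below are placeholder values, not part of the diff:

```go
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder token

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.GPT4,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Answer deterministically."},
		},
		// Before this change, a zero value here was dropped by omitempty and the
		// API default applied; with *float32 it is serialized as "temperature":0.
		Temperature: openai.NewFloat(0),
	})
	if err != nil {
		fmt.Printf("chat completion error: %v\n", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```

Leaving `Temperature` nil preserves the old omit-the-field behavior, so existing callers that never set a temperature are unaffected.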