From 03851d20327b7df5358ff9fb0ac96f476be1875a Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Sun, 30 Jun 2024 17:20:10 +0200 Subject: [PATCH] allow custom voice and speech models (#691) --- speech.go | 31 ------------------------------- speech_test.go | 17 ----------------- 2 files changed, 48 deletions(-) diff --git a/speech.go b/speech.go index 7e22e755..19b21bdf 100644 --- a/speech.go +++ b/speech.go @@ -2,7 +2,6 @@ package openai import ( "context" - "errors" "net/http" ) @@ -36,11 +35,6 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` @@ -49,32 +43,7 @@ type CreateSpeechRequest struct { Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json"), diff --git a/speech_test.go b/speech_test.go index d9ba58b1..f1e405c3 100644 --- a/speech_test.go +++ b/speech_test.go @@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) }