From 054c1d65c2cc74516e56f62b1223139709e1ecc3 Mon Sep 17 00:00:00 2001
From: Takashi Kawachi
Date: Sun, 19 Mar 2023 14:19:40 +0900
Subject: [PATCH] Add configurable model option to AI chat

- Added 'model' field to chatOptions struct
- Replaced hardcoded model with aiChat.options.model
- Added command line flag for model selection
---
 aichat.go | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/aichat.go b/aichat.go
index b7655c3..5316ae1 100644
--- a/aichat.go
+++ b/aichat.go
@@ -16,6 +16,7 @@ import (
 )
 
 type chatOptions struct {
+	model        string
 	temperature  float32
 	maxTokens    int
 	nonStreaming bool
@@ -87,7 +88,7 @@ func (aiChat *AIChat) stdChatLoop() error {
 		})
 		fmt.Print("assistant: ")
 		request := gogpt.ChatCompletionRequest{
-			Model:       gogpt.GPT3Dot5Turbo,
+			Model:       aiChat.options.model,
 			Messages:    messages,
 			Temperature: aiChat.options.temperature,
 			MaxTokens:   aiChat.options.maxTokens,
@@ -143,7 +144,7 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	firstInput := aiChat.encoder.Decode(firstEncoded)
 	temperature := firstNonZeroFloat32(aiChat.options.temperature, prompt.Temperature)
 	firstRequest := gogpt.ChatCompletionRequest{
-		Model:       gogpt.GPT3Dot5Turbo,
+		Model:       aiChat.options.model,
 		Messages:    prompt.CreateMessages(firstInput),
 		Temperature: temperature,
 	}
@@ -183,7 +184,7 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	}
 	input := aiChat.encoder.Decode(encoded[idx:nextIdx])
 	request := gogpt.ChatCompletionRequest{
-		Model:       gogpt.GPT3Dot5Turbo,
+		Model:       aiChat.options.model,
 		Messages:    prompt.CreateSubsequentMessages(output, input),
 		Temperature: temperature,
 	}
@@ -214,12 +215,14 @@ func main() {
 	var listPrompts = false
 	var nonStreaming = false
 	var split = false
+	var model = gogpt.GPT3Dot5Turbo
 	getopt.FlagLong(&temperature, "temperature", 't', "temperature")
 	getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens, 0 to use default")
 	getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
 	getopt.FlagLong(&listPrompts, "list-prompts", 'l', "list prompts")
 	getopt.FlagLong(&nonStreaming, "non-streaming", 0, "non streaming mode")
 	getopt.FlagLong(&split, "split", 0, "split input")
+	getopt.FlagLong(&model, "model", 0, "model")
 	getopt.Parse()
 
 	if listPrompts {
@@ -234,6 +237,7 @@ func main() {
 		log.Fatal(err)
 	}
 	options := chatOptions{
+		model:        model,
 		temperature:  temperature,
 		maxTokens:    maxTokens,
 		nonStreaming: nonStreaming,
@@ -303,7 +307,7 @@ func main() {
 
 	for _, messages := range messagesSlice {
 		request := gogpt.ChatCompletionRequest{
-			Model:       gogpt.GPT3Dot5Turbo,
+			Model:       model,
 			Messages:    messages,
 			Temperature: firstNonZeroFloat32(prompt.Temperature, aiChat.options.temperature),
 			MaxTokens:   maxTokens,
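
Usage note (not part of the commit itself): with this patch applied, the model can be selected at run time via the new --model flag, falling back to gogpt.GPT3Dot5Turbo when the flag is omitted. A minimal invocation sketch follows; the binary name "aichat" is assumed from the file name, and "gpt-4" is only an illustrative model identifier, not something this patch defines:

    # pick a non-default model for this run (model name is illustrative)
    ./aichat --model gpt-4 --temperature 0.7 --max-tokens 256

    # omit --model to keep the default, gogpt.GPT3Dot5Turbo;
    # -t is the short form of --temperature per the getopt.FlagLong call above
    ./aichat -t 0.7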