From 08b082d76904fbd6f48bcdd4a85f565a89c2837b Mon Sep 17 00:00:00 2001 From: Takashi Kawachi Date: Sun, 12 Mar 2023 17:01:30 +0900 Subject: [PATCH] Fixed calculation of allowed tokens --- aichat.go | 9 ++++++--- prompt.go | 13 +++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/aichat.go b/aichat.go index c699545..b0ccf8b 100644 --- a/aichat.go +++ b/aichat.go @@ -194,9 +194,9 @@ func main() { messagesSlice = [][]gogpt.ChatCompletionMessage{messages} } - for _, messages := range messagesSlice { + maxTokens := firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens) - maxTokens := firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens) + for _, messages := range messagesSlice { request := gogpt.ChatCompletionRequest{ Model: gogpt.GPT3Dot5Turbo, @@ -209,7 +209,10 @@ func main() { if err != nil { log.Fatal(err) } - if cnt > 4096 { + if verbose { + log.Printf("total tokens %d", cnt) + } + if cnt+maxTokens > 4096 { log.Fatalf("total tokens %d exceeds 4096", cnt) } diff --git a/prompt.go b/prompt.go index a1bdf87..4668d5c 100644 --- a/prompt.go +++ b/prompt.go @@ -57,13 +57,16 @@ func (p *Prompt) CountTokens() (int, error) { } // AllowedInputTokens returns the number of tokens allowed for the input -func (p *Prompt) AllowedInputTokens(maxTokensOverride int) (int, error) { - maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens) +func (p *Prompt) AllowedInputTokens() (int, error) { promptTokens, err := p.CountTokens() if err != nil { return 0, err } - return maxTokens - promptTokens, nil + result := 4096 - (promptTokens + p.MaxTokens) + if result <= 0 { + return 0, fmt.Errorf("allowed tokens for input is %d, but it should be greater than 0", result) + } + return result, nil } func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) { @@ -91,12 +94,10 @@ func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) { } func (p *Prompt) CreateMessagesWithSplit(input string, maxTokensOverride int) ([][]gogpt.ChatCompletionMessage, error) { - maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens) - promptTokens, err := p.CountTokens() + allowedInputTokens, err := p.AllowedInputTokens() if err != nil { return nil, err } - allowedInputTokens := maxTokens - promptTokens inputParts, err := splitStringWithTokensLimit(input, allowedInputTokens) if err != nil { return nil, err