From ca4a52884b0b217001dfaa217edd6fa3e173cfcd Mon Sep 17 00:00:00 2001
From: Takashi Kawachi
Date: Sun, 26 Mar 2023 17:40:42 +0900
Subject: [PATCH] Add token limit for GPT models and update prompt token limit calculation

This commit adds a function to calculate the maximum token limit for a
given GPT model. It also updates the `AllowedInputTokens` and
`AllowedSubsequentInputTokens` functions to use that token limit
instead of hardcoding the value 4096. Additionally, the
`CreateMessagesWithSplit` function is updated to use the new
`AllowedInputTokens` function signature, and the default model is
changed from `gpt-3.5-turbo` to `gpt-4`. Finally, a new test function
covers `tokenLimitOfModel`.
---
 aichat.go      | 26 ++++++++++++++++++++------
 aichat_test.go | 16 ++++++++++++++++
 prompt.go      | 12 ++++++------
 3 files changed, 42 insertions(+), 12 deletions(-)

(Reviewer's note: a worked example of the new token budgeting follows
the diff.)

diff --git a/aichat.go b/aichat.go
index 5316ae1..7d50139 100644
--- a/aichat.go
+++ b/aichat.go
@@ -131,7 +131,8 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 		return fmt.Errorf("encode: %w", err)
 	}
 
-	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, aiChat.options.maxTokens, aiChat.options.verbose)
+	tokenLimit := tokenLimitOfModel(aiChat.options.model)
+	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 	if err != nil {
 		return err
 	}
@@ -174,7 +175,7 @@
 		}
 
 		allowedTokens, err := prompt.AllowedSubsequentInputTokens(
-			aiChat.encoder, len(outputTokens), aiChat.options.maxTokens, aiChat.options.verbose)
+			aiChat.encoder, len(outputTokens), tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 		if err != nil {
 			return fmt.Errorf("allowed subsequent input tokens: %w", err)
 		}
@@ -208,6 +209,18 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	return nil
 }
 
+// tokenLimitOfModel returns the maximum number of tokens allowed for a given model.
+func tokenLimitOfModel(model string) int {
+	switch model {
+	case gogpt.GPT4, gogpt.GPT40314:
+		return 8 * 1024
+	case gogpt.GPT432K, gogpt.GPT432K0314:
+		return 32 * 1024
+	default:
+		return 4 * 1024
+	}
+}
+
 func main() {
 	var temperature float32 = 0.5
 	var maxTokens = 0
@@ -215,7 +228,7 @@
 	var listPrompts = false
 	var nonStreaming = false
 	var split = false
-	var model = gogpt.GPT3Dot5Turbo
+	var model = gogpt.GPT4
 	getopt.FlagLong(&temperature, "temperature", 't', "temperature")
 	getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens, 0 to use default")
 	getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
@@ -281,9 +294,10 @@
 	}
 
 	var messagesSlice [][]gogpt.ChatCompletionMessage
+	tokenLimit := tokenLimitOfModel(aiChat.options.model)
 
 	if split {
-		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, aiChat.options.maxTokens, aiChat.options.verbose)
+		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -320,8 +334,8 @@
 		if verbose {
 			log.Printf("total tokens %d", cnt)
 		}
-		if cnt+maxTokens > 4096 {
-			log.Fatalf("total tokens %d exceeds 4096", cnt)
+		if cnt+maxTokens > tokenLimit {
+			log.Fatalf("total tokens %d plus max tokens %d exceeds %d", cnt, maxTokens, tokenLimit)
 		}
 
 		if aiChat.options.nonStreaming {
diff --git a/aichat_test.go b/aichat_test.go
index faecbb7..8a556d7 100644
--- a/aichat_test.go
+++ b/aichat_test.go
@@ -17,3 +17,19 @@ func TestCountTokens(t *testing.T) {
 		t.Errorf("CountTokens() returned %d, expected 8", count)
 	}
 }
+
+func TestTokenLimitOfModel(t *testing.T) {
+	data := []struct {
+		modelName  string
+		tokenLimit int
+	}{
+		{"gpt-3.5-turbo", 4096},
+		{"gpt-4", 8192},
+	}
+	for _, d := range data {
+		tokenLimit := tokenLimitOfModel(d.modelName)
+		if tokenLimit != d.tokenLimit {
+			t.Errorf("tokenLimitOfModel(%q) returned %d, expected %d", d.modelName, tokenLimit, d.tokenLimit)
+		}
+	}
+}
diff --git a/prompt.go b/prompt.go
index 3828702..9605c71 100644
--- a/prompt.go
+++ b/prompt.go
@@ -87,14 +87,14 @@ func countMessagesTokens(encoder *tokenizer.Encoder, messages []Message) (int, e
 }
 
 // AllowedInputTokens returns the number of tokens allowed for the input
-func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverride int, verbose bool) (int, error) {
+func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, tokenLimit, maxTokensOverride int, verbose bool) (int, error) {
 	promptTokens, err := p.CountTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
 	// reserve 500 tokens for output if maxTokens is not specified
 	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
-	result := 4096 - (promptTokens + maxTokens)
+	result := tokenLimit - (promptTokens + maxTokens)
 	if verbose {
 		log.Printf("allowed tokens for input is %d", result)
 	}
@@ -104,14 +104,14 @@ func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverrid
 	return result, nil
 }
 
-func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, maxTokensOverride int, verbose bool) (int, error) {
+func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, tokenLimit, maxTokensOverride int, verbose bool) (int, error) {
 	promptTokens, err := p.CountSubsequentTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
 	// reserve 500 tokens for output if maxTokens is not specified
 	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
-	result := 4096 - (promptTokens + maxTokens + outputLen)
+	result := tokenLimit - (promptTokens + maxTokens + outputLen)
 	if verbose {
 		log.Printf("allowed tokens for subsequent input is %d", result)
 	}
@@ -145,8 +145,8 @@ func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
 	return parts, nil
 }
 
-func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
-	allowedInputTokens, err := p.AllowedInputTokens(encoder, maxTokensOverride, verbose)
+func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, tokenLimit, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
+	allowedInputTokens, err := p.AllowedInputTokens(encoder, tokenLimit, maxTokensOverride, verbose)
 	if err != nil {
 		return nil, err
 	}
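
Reviewer's note (not part of the commit): the token budgeting introduced
above can be sanity-checked with a small standalone sketch. The program
below mirrors tokenLimitOfModel and the AllowedInputTokens arithmetic
from the patch; the literal model names stand in for the gogpt
constants, the 1200-token prompt size is an illustrative assumption,
and 500 is the patch's default output reservation (via
firstNonZeroInt).

package main

import "fmt"

// tokenLimitOfModel mirrors the helper added in aichat.go, with the
// gogpt constants replaced by their literal model names so the sketch
// is self-contained.
func tokenLimitOfModel(model string) int {
	switch model {
	case "gpt-4", "gpt-4-0314":
		return 8 * 1024
	case "gpt-4-32k", "gpt-4-32k-0314":
		return 32 * 1024
	default:
		return 4 * 1024
	}
}

func main() {
	const promptTokens = 1200 // illustrative prompt size (assumption)
	const maxTokens = 500     // the patch's default output reservation
	for _, model := range []string{"gpt-3.5-turbo", "gpt-4", "gpt-4-32k"} {
		limit := tokenLimitOfModel(model)
		// Same arithmetic as AllowedInputTokens in prompt.go:
		// allowed input = tokenLimit - (promptTokens + maxTokens).
		allowed := limit - (promptTokens + maxTokens)
		fmt.Printf("%-13s limit=%5d allowed=%5d\n", model, limit, allowed)
	}
}

For gpt-4 this prints limit=8192 and allowed=6492, i.e. 8192 - (1200 +
500), which is the input budget that fold() and CreateMessagesWithSplit
then work within.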