Merge pull request #24 from tkawachi/tokenlimit
Add token limit for GPT models and update prompt token limit calculation
mergify[bot] authored Mar 26, 2023
2 parents 4f0804e + ca4a528 commit 9cb6a28
Showing 3 changed files with 42 additions and 12 deletions.
26 changes: 20 additions & 6 deletions aichat.go
@@ -131,7 +131,8 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 		return fmt.Errorf("encode: %w", err)
 	}
 
-	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, aiChat.options.maxTokens, aiChat.options.verbose)
+	tokenLimit := tokenLimitOfModel(aiChat.options.model)
+	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 	if err != nil {
 		return err
 	}
@@ -174,7 +175,7 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	}
 
 	allowedTokens, err := prompt.AllowedSubsequentInputTokens(
-		aiChat.encoder, len(outputTokens), aiChat.options.maxTokens, aiChat.options.verbose)
+		aiChat.encoder, len(outputTokens), tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 	if err != nil {
 		return fmt.Errorf("allowed subsequent input tokens: %w", err)
 	}
@@ -208,14 +209,26 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	return nil
 }
 
+// tokenLimitOfModel returns the maximum number of tokens allowed for a given model.
+func tokenLimitOfModel(model string) int {
+	switch model {
+	case gogpt.GPT4, gogpt.GPT40314:
+		return 8 * 1024
+	case gogpt.GPT432K, gogpt.GPT432K0314:
+		return 32 * 1024
+	default:
+		return 4 * 1024
+	}
+}
+
 func main() {
 	var temperature float32 = 0.5
 	var maxTokens = 0
 	var verbose = false
 	var listPrompts = false
 	var nonStreaming = false
 	var split = false
-	var model = gogpt.GPT3Dot5Turbo
+	var model = gogpt.GPT4
 	getopt.FlagLong(&temperature, "temperature", 't', "temperature")
 	getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens, 0 to use default")
 	getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
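
The new helper hard-codes a context-window size per model and falls back to 4k for anything unlisted. Below is a minimal standalone sketch of how it combines with the budget arithmetic this commit introduces elsewhere; the string literals are the values behind the gogpt constants, while the prompt and reservation sizes are illustrative, not taken from the repo:

package main

import "fmt"

// Mirrors tokenLimitOfModel above, with the literal model names that
// the gogpt constants expand to.
func tokenLimitOfModel(model string) int {
	switch model {
	case "gpt-4", "gpt-4-0314":
		return 8 * 1024
	case "gpt-4-32k", "gpt-4-32k-0314":
		return 32 * 1024
	default:
		return 4 * 1024 // e.g. "gpt-3.5-turbo"
	}
}

func main() {
	// Illustrative numbers: a 120-token prompt template plus the
	// 500-token output reservation that prompt.go uses by default.
	promptTokens, reserved := 120, 500
	limit := tokenLimitOfModel("gpt-4")
	fmt.Printf("input budget: %d tokens\n", limit-(promptTokens+reserved)) // 8192-620 = 7572
}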
@@ -281,9 +294,10 @@ func main() {
 	}
 
 	var messagesSlice [][]gogpt.ChatCompletionMessage
+	tokenLimit := tokenLimitOfModel(aiChat.options.model)
 
 	if split {
-		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, aiChat.options.maxTokens, aiChat.options.verbose)
+		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -320,8 +334,8 @@ func main() {
 	if verbose {
 		log.Printf("total tokens %d", cnt)
 	}
-	if cnt+maxTokens > 4096 {
-		log.Fatalf("total tokens %d exceeds 4096", cnt)
+	if cnt+maxTokens > tokenLimit {
+		log.Fatalf("total tokens %d exceeds %d", cnt, tokenLimit)
 	}
 
 	if aiChat.options.nonStreaming {
16 changes: 16 additions & 0 deletions aichat_test.go
@@ -17,3 +17,19 @@ func TestCountTokens(t *testing.T) {
 		t.Errorf("CountTokens() returned %d, expected 8", count)
 	}
 }
+
+func TestTokenLimitOfModel(t *testing.T) {
+	data := []struct {
+		modelName  string
+		tokenLimit int
+	}{
+		{"gpt-3.5-turbo", 4096},
+		{"gpt-4", 8192},
+	}
+	for _, d := range data {
+		tokenLimit := tokenLimitOfModel(d.modelName)
+		if tokenLimit != d.tokenLimit {
+			t.Errorf("tokenLimitOfModel(%q) returned %d, expected %d", d.modelName, tokenLimit, d.tokenLimit)
+		}
+	}
+}
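
The table exercises the two most common models. A hypothetical companion test for the default branch, not part of this commit, could live in the same file:

// Hypothetical extra case for aichat_test.go, covering the default branch:
func TestTokenLimitOfModelDefault(t *testing.T) {
	if got := tokenLimitOfModel("some-unknown-model"); got != 4096 {
		t.Errorf("tokenLimitOfModel(%q) returned %d, expected 4096", "some-unknown-model", got)
	}
}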
12 changes: 6 additions & 6 deletions prompt.go
@@ -87,14 +87,14 @@ func countMessagesTokens(encoder *tokenizer.Encoder, messages []Message) (int, error) {
 }
 
 // AllowedInputTokens returns the number of tokens allowed for the input
-func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverride int, verbose bool) (int, error) {
+func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, tokenLimit, maxTokensOverride int, verbose bool) (int, error) {
 	promptTokens, err := p.CountTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
 	// reserve 500 tokens for output if maxTokens is not specified
 	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
-	result := 4096 - (promptTokens + maxTokens)
+	result := tokenLimit - (promptTokens + maxTokens)
 	if verbose {
 		log.Printf("allowed tokens for input is %d", result)
 	}
@@ -104,14 +104,14 @@ func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverride int, verbose bool) (int, error) {
 	return result, nil
 }
 
-func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, maxTokensOverride int, verbose bool) (int, error) {
+func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, tokenLimit, maxTokensOverride int, verbose bool) (int, error) {
 	promptTokens, err := p.CountSubsequentTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
 	// reserve 500 tokens for output if maxTokens is not specified
 	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
-	result := 4096 - (promptTokens + maxTokens + outputLen)
+	result := tokenLimit - (promptTokens + maxTokens + outputLen)
 	if verbose {
 		log.Printf("allowed tokens for subsequent input is %d", result)
 	}
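
Replacing the hard-coded 4096 with tokenLimit makes the budget scale with the model. A worked example of the subsequent-input calculation; firstNonZeroInt is reimplemented here on the assumption that, as its call sites suggest, it returns the first non-zero argument, and the token counts are illustrative:

package main

import "fmt"

// Assumed behavior of the repo's firstNonZeroInt helper: first non-zero wins.
func firstNonZeroInt(values ...int) int {
	for _, v := range values {
		if v != 0 {
			return v
		}
	}
	return 0
}

func main() {
	tokenLimit := 8192                      // gpt-4
	promptTokens, outputLen := 150, 300     // illustrative counts
	maxTokens := firstNonZeroInt(0, 0, 500) // no override: reserve 500 for output
	fmt.Println(tokenLimit - (promptTokens + maxTokens + outputLen)) // 7242
}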
@@ -145,8 +145,8 @@ func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
 	return parts, nil
 }
 
-func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
-	allowedInputTokens, err := p.AllowedInputTokens(encoder, maxTokensOverride, verbose)
+func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, tokenLimit, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
+	allowedInputTokens, err := p.AllowedInputTokens(encoder, tokenLimit, maxTokensOverride, verbose)
 	if err != nil {
 		return nil, err
 	}
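
splitStringWithTokensLimit itself is outside this diff; conceptually it has to chop the encoded input into windows that each fit the allowed budget. A simplified sketch of that idea, using a hypothetical splitTokens helper (the real function works on the tokenizer's encoding and decodes each window back to a string):

package main

import "fmt"

// splitTokens chops a token sequence into windows of at most limit tokens.
// Simplified illustration only; not the repo's splitStringWithTokensLimit.
func splitTokens(tokens []int, limit int) [][]int {
	var parts [][]int
	for len(tokens) > limit {
		parts = append(parts, tokens[:limit])
		tokens = tokens[limit:]
	}
	if len(tokens) > 0 {
		parts = append(parts, tokens)
	}
	return parts
}

func main() {
	tokens := make([]int, 10)                // stand-in for an encoded input
	fmt.Println(len(splitTokens(tokens, 4))) // 3 windows: 4 + 4 + 2 tokens
}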