Add token limit for GPT models and update prompt token limit calculation
This commit adds a function to calculate the maximum token limit for a given
GPT model. It also updates the `AllowedInputTokens` and
`AllowedSubsequentInputTokens` functions to use that limit instead of a
hardcoded 4096. Additionally, the `CreateMessagesWithSplit` function is
updated to match the new `AllowedInputTokens` signature. Finally, a new test
function is added for `tokenLimitOfModel`.
tkawachi committed Mar 26, 2023
1 parent 054c1d6 commit ca4a528
Showing 3 changed files with 42 additions and 12 deletions.
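Before the diffs, here is a minimal, self-contained sketch of the budget arithmetic this commit introduces: the per-model limit replaces the hardcoded 4096, and the allowed input is whatever remains after the prompt tokens and the output reserve. The string literals below stand in for the gogpt model constants, and the numbers are illustrative only:

```go
package main

import "fmt"

// tokenLimitOfModel mirrors the switch added in this commit; plain string
// literals stand in for the gogpt constants so the sketch runs standalone.
func tokenLimitOfModel(model string) int {
	switch model {
	case "gpt-4", "gpt-4-0314":
		return 8 * 1024
	case "gpt-4-32k", "gpt-4-32k-0314":
		return 32 * 1024
	default:
		return 4 * 1024 // e.g. gpt-3.5-turbo
	}
}

func main() {
	// Illustrative values: a 1200-token prompt and the default
	// 500-token output reserve against gpt-4's 8192-token limit.
	promptTokens, maxTokens := 1200, 500
	limit := tokenLimitOfModel("gpt-4")
	fmt.Printf("input budget: %d tokens\n", limit-(promptTokens+maxTokens)) // 6492
}
```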
26 changes: 20 additions & 6 deletions aichat.go
@@ -131,7 +131,8 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 		return fmt.Errorf("encode: %w", err)
 	}
 
-	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, aiChat.options.maxTokens, aiChat.options.verbose)
+	tokenLimit := tokenLimitOfModel(aiChat.options.model)
+	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 	if err != nil {
 		return err
 	}
@@ -174,7 +175,7 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	}
 
 	allowedTokens, err := prompt.AllowedSubsequentInputTokens(
-		aiChat.encoder, len(outputTokens), aiChat.options.maxTokens, aiChat.options.verbose)
+		aiChat.encoder, len(outputTokens), tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 	if err != nil {
 		return fmt.Errorf("allowed subsequent input tokens: %w", err)
 	}
@@ -208,14 +209,26 @@ func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
 	return nil
 }
 
+// tokenLimitOfModel returns the maximum number of tokens allowed for a given model.
+func tokenLimitOfModel(model string) int {
+	switch model {
+	case gogpt.GPT4, gogpt.GPT40314:
+		return 8 * 1024
+	case gogpt.GPT432K, gogpt.GPT432K0314:
+		return 32 * 1024
+	default:
+		return 4 * 1024
+	}
+}
+
 func main() {
 	var temperature float32 = 0.5
 	var maxTokens = 0
 	var verbose = false
 	var listPrompts = false
 	var nonStreaming = false
 	var split = false
-	var model = gogpt.GPT3Dot5Turbo
+	var model = gogpt.GPT4
 	getopt.FlagLong(&temperature, "temperature", 't', "temperature")
 	getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens, 0 to use default")
 	getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
@@ -281,9 +294,10 @@ func main() {
 	}
 
 	var messagesSlice [][]gogpt.ChatCompletionMessage
+	tokenLimit := tokenLimitOfModel(aiChat.options.model)
 
 	if split {
-		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, aiChat.options.maxTokens, aiChat.options.verbose)
+		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, tokenLimit, aiChat.options.maxTokens, aiChat.options.verbose)
 		if err != nil {
 			log.Fatal(err)
 		}
@@ -320,8 +334,8 @@ func main() {
 	if verbose {
 		log.Printf("total tokens %d", cnt)
 	}
-	if cnt+maxTokens > 4096 {
-		log.Fatalf("total tokens %d exceeds 4096", cnt)
+	if cnt+maxTokens > tokenLimit {
+		log.Fatalf("total tokens %d exceeds %d", cnt, tokenLimit)
 	}
 
 	if aiChat.options.nonStreaming {
16 changes: 16 additions & 0 deletions aichat_test.go
@@ -17,3 +17,19 @@ func TestCountTokens(t *testing.T) {
 		t.Errorf("CountTokens() returned %d, expected 8", count)
 	}
 }
+
+func TestTokenLimitOfModel(t *testing.T) {
+	data := []struct {
+		modelName  string
+		tokenLimit int
+	}{
+		{"gpt-3.5-turbo", 4096},
+		{"gpt-4", 8192},
+	}
+	for _, d := range data {
+		tokenLimit := tokenLimitOfModel(d.modelName)
+		if tokenLimit != d.tokenLimit {
+			t.Errorf("tokenLimitOfModel(%q) returned %d, expected %d", d.modelName, tokenLimit, d.tokenLimit)
+		}
+	}
+}
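For reference, the new table-driven test can be exercised on its own with the standard Go tooling, e.g. `go test -run TestTokenLimitOfModel` from the repository root.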
12 changes: 6 additions & 6 deletions prompt.go
@@ -87,14 +87,14 @@ func countMessagesTokens(encoder *tokenizer.Encoder, messages []Message) (int, e
 }
 
 // AllowedInputTokens returns the number of tokens allowed for the input
-func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverride int, verbose bool) (int, error) {
+func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, tokenLimit, maxTokensOverride int, verbose bool) (int, error) {
 	promptTokens, err := p.CountTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
 	// reserve 500 tokens for output if maxTokens is not specified
 	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
-	result := 4096 - (promptTokens + maxTokens)
+	result := tokenLimit - (promptTokens + maxTokens)
 	if verbose {
 		log.Printf("allowed tokens for input is %d", result)
 	}
@@ -104,14 +104,14 @@ func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverrid
 	return result, nil
 }
 
-func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, maxTokensOverride int, verbose bool) (int, error) {
+func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, tokenLimit, maxTokensOverride int, verbose bool) (int, error) {
 	promptTokens, err := p.CountSubsequentTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
 	// reserve 500 tokens for output if maxTokens is not specified
 	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
-	result := 4096 - (promptTokens + maxTokens + outputLen)
+	result := tokenLimit - (promptTokens + maxTokens + outputLen)
 	if verbose {
 		log.Printf("allowed tokens for subsequent input is %d", result)
 	}
@@ -145,8 +145,8 @@ func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
 	return parts, nil
 }
 
-func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
-	allowedInputTokens, err := p.AllowedInputTokens(encoder, maxTokensOverride, verbose)
+func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, tokenLimit, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
+	allowedInputTokens, err := p.AllowedInputTokens(encoder, tokenLimit, maxTokensOverride, verbose)
 	if err != nil {
 		return nil, err
 	}
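The subsequent-input budget used by the folding path additionally subtracts the tokens already produced, so it shrinks as output accumulates. A quick worked example of that formula, with illustrative numbers only:

```go
package main

import "fmt"

func main() {
	// Assumed values: gpt-4's 8192-token limit, a 300-token subsequent
	// prompt, the default 500-token output reserve, and 1000 tokens of
	// output carried over from the previous folding step.
	tokenLimit, promptTokens, maxTokens, outputLen := 8192, 300, 500, 1000
	fmt.Println(tokenLimit - (promptTokens + maxTokens + outputLen)) // 6392
}
```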
