Skip to content

Commit

Permalink
Fixed calculation of allowed tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
tkawachi committed Mar 12, 2023
1 parent 30a1ebb commit 08b082d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
9 changes: 6 additions & 3 deletions aichat.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,9 @@ func main() {
messagesSlice = [][]gogpt.ChatCompletionMessage{messages}
}

for _, messages := range messagesSlice {
maxTokens := firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens)

maxTokens := firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens)
for _, messages := range messagesSlice {

request := gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
Expand All @@ -209,7 +209,10 @@ func main() {
if err != nil {
log.Fatal(err)
}
if cnt > 4096 {
if verbose {
log.Printf("total tokens %d", cnt)
}
if cnt+maxTokens > 4096 {
log.Fatalf("total tokens %d exceeds 4096", cnt)
}

Expand Down
13 changes: 7 additions & 6 deletions prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ func (p *Prompt) CountTokens() (int, error) {
}

// AllowedInputTokens returns the number of tokens allowed for the input
func (p *Prompt) AllowedInputTokens(maxTokensOverride int) (int, error) {
maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens)
func (p *Prompt) AllowedInputTokens() (int, error) {
promptTokens, err := p.CountTokens()
if err != nil {
return 0, err
}
return maxTokens - promptTokens, nil
result := 4096 - (promptTokens + p.MaxTokens)
if result <= 0 {
return 0, fmt.Errorf("allowed tokens for input is %d, but it should be greater than 0", result)
}
return result, nil
}

func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
Expand Down Expand Up @@ -91,12 +94,10 @@ func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
}

func (p *Prompt) CreateMessagesWithSplit(input string, maxTokensOverride int) ([][]gogpt.ChatCompletionMessage, error) {
maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens)
promptTokens, err := p.CountTokens()
allowedInputTokens, err := p.AllowedInputTokens()
if err != nil {
return nil, err
}
allowedInputTokens := maxTokens - promptTokens
inputParts, err := splitStringWithTokensLimit(input, allowedInputTokens)
if err != nil {
return nil, err
Expand Down

0 comments on commit 08b082d

Please sign in to comment.