From bc44a302881f7c3269ca519eb7dd6c2915c9e23c Mon Sep 17 00:00:00 2001
From: Takashi Kawachi
Date: Mon, 13 Mar 2023 00:31:17 +0900
Subject: [PATCH] Add folding support via subsequent_messages

When a prompt file defines subsequent_messages, stdin that does not fit
into a single request is folded: the first chunk is sent with the regular
messages, and every later chunk is sent through subsequent_messages with
the previous completion substituted for $OUTPUT, so each request carries
the running result forward. The tokenizer encoder is now created once in
main and shared; CountTokens, AllowedInputTokens, and
CreateMessagesWithSplit take it as a parameter.
---
 aichat.go                | 99 ++++++++++++++++++++++++++++++++++++++++-
 prompt.go                | 85 +++++++++++++++++++++++++++--------
 prompt_test.go           | 35 ++++++++++++++-
 testdata/fold.yml        | 17 +++++++
 testdata/name-branch.yml |  2 +
 5 files changed, 217 insertions(+), 21 deletions(-)
 create mode 100644 testdata/fold.yml

diff --git a/aichat.go b/aichat.go
index 287bc0b..b17eeb4 100644
--- a/aichat.go
+++ b/aichat.go
@@ -24,6 +24,7 @@ type chatOptions struct {
 
 type AIChat struct {
 	client  *gogpt.Client
+	encoder *tokenizer.Encoder
 	options chatOptions
 }
 
@@ -123,6 +124,90 @@ func firstNonZeroFloat32(f ...float32) float32 {
 	return 0
 }
 
+func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
+	encoded, err := aiChat.encoder.Encode(input)
+	if err != nil {
+		return fmt.Errorf("encode: %w", err)
+	}
+
+	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, aiChat.options.maxTokens, aiChat.options.verbose)
+	if err != nil {
+		return err
+	}
+
+	// clamp so a short input does not slice past the end of encoded
+	if firstAllowedTokens > len(encoded) {
+		firstAllowedTokens = len(encoded)
+	}
+	firstEncoded := encoded[:firstAllowedTokens]
+	firstInput := aiChat.encoder.Decode(firstEncoded)
+	temperature := firstNonZeroFloat32(aiChat.options.temperature, prompt.Temperature)
+	firstRequest := gogpt.ChatCompletionRequest{
+		Model:       gogpt.GPT3Dot5Turbo,
+		Messages:    prompt.CreateMessages(firstInput),
+		Temperature: temperature,
+	}
+	if aiChat.options.verbose {
+		log.Printf("first request: %+v", firstRequest)
+	}
+	response, err := aiChat.client.CreateChatCompletion(context.Background(), firstRequest)
+	if err != nil {
+		return fmt.Errorf("create chat completion: %w", err)
+	}
+	if len(response.Choices) == 0 {
+		return fmt.Errorf("no choices returned")
+	}
+	output := response.Choices[0].Message.Content
+	if firstAllowedTokens >= len(encoded) {
+		fmt.Println(output)
+		return nil
+	}
+	if aiChat.options.verbose {
+		log.Printf("first output: %s", output)
+	}
+
+	idx := firstAllowedTokens
+	for idx < len(encoded) {
+		outputTokens, err := aiChat.encoder.Encode(output)
+		if err != nil {
+			return fmt.Errorf("encode: %w", err)
+		}
+
+		allowedTokens, err := prompt.AllowedSubsequentInputTokens(
+			aiChat.encoder, len(outputTokens), aiChat.options.maxTokens, aiChat.options.verbose)
+		if err != nil {
+			return fmt.Errorf("allowed subsequent input tokens: %w", err)
+		}
+		nextIdx := idx + allowedTokens
+		if nextIdx > len(encoded) {
+			nextIdx = len(encoded)
+		}
+		input := aiChat.encoder.Decode(encoded[idx:nextIdx])
+		request := gogpt.ChatCompletionRequest{
+			Model:       gogpt.GPT3Dot5Turbo,
+			Messages:    prompt.CreateSubsequentMessages(output, input),
+			Temperature: temperature,
+		}
+		if aiChat.options.verbose {
+			log.Printf("subsequent request: %+v", request)
+		}
+		response, err := aiChat.client.CreateChatCompletion(context.Background(), request)
+		if err != nil {
+			return fmt.Errorf("create chat completion: %w", err)
+		}
+		if len(response.Choices) == 0 {
+			return fmt.Errorf("no choices returned")
+		}
+		output = response.Choices[0].Message.Content
+		if aiChat.options.verbose {
+			log.Printf("subsequent output: %s", output)
+		}
+		idx = nextIdx
+	}
+	fmt.Println(output)
+	return nil
+}
+
 func main() {
 	var temperature float32 = 0.5
 	var maxTokens = 0
@@ -158,8 +243,13 @@
 	if verbose {
 		log.Printf("options: %+v", options)
 	}
+	encoder, err := tokenizer.NewEncoder()
+	if err != nil {
+		log.Fatal(err)
+	}
 	aiChat := AIChat{
 		client:  gogpt.NewClient(openaiAPIKey),
+		encoder: encoder,
 		options: options,
 	}
 
@@ -180,10 +270,17 @@
 	// read all from Stdin
 	input := scanAll(bufio.NewScanner(os.Stdin))
 
+	if prompt.isFoldEnabled() {
+		if err := aiChat.fold(prompt, input); err != nil {
+			log.Fatal(err)
+		}
+		return
+	}
+
 	var messagesSlice [][]gogpt.ChatCompletionMessage
 	if split {
-		messagesSlice, err = prompt.CreateMessagesWithSplit(input, aiChat.options.maxTokens, aiChat.options.verbose)
+		messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, aiChat.options.maxTokens, aiChat.options.verbose)
 		if err != nil {
 			log.Fatal(err)
 		}
diff --git a/prompt.go b/prompt.go
index 8c47d15..f53170f 100644
--- a/prompt.go
+++ b/prompt.go
@@ -13,16 +13,25 @@ import (
 )
 
 const DefaultInputMarker = "$INPUT"
+const DefaultOutputMarker = "$OUTPUT"
+
+type Message struct {
+	Role    string `yaml:"role"`
+	Content string `yaml:"content"`
+}
 
 type Prompt struct {
-	Description string `yaml:"description"`
-	InputMarker string `yaml:"input_marker"`
-	Messages    []struct {
-		Role    string `yaml:"role"`
-		Content string `yaml:"content"`
-	} `yaml:"messages"`
-	Temperature float32 `yaml:"temperature"`
-	MaxTokens   int     `yaml:"max_tokens"`
+	Description        string    `yaml:"description"`
+	InputMarker        string    `yaml:"input_marker"`
+	OutputMarker       string    `yaml:"output_marker"`
+	Messages           []Message `yaml:"messages"`
+	SubsequentMessages []Message `yaml:"subsequent_messages"`
+	Temperature        float32   `yaml:"temperature"`
+	MaxTokens          int       `yaml:"max_tokens"`
+}
+
+func (p *Prompt) isFoldEnabled() bool {
+	return len(p.SubsequentMessages) > 0
 }
 
 func (p *Prompt) CreateMessages(input string) []gogpt.ChatCompletionMessage {
@@ -39,14 +48,34 @@
 	return messages
 }
 
+func (p *Prompt) CreateSubsequentMessages(output, input string) []gogpt.ChatCompletionMessage {
+	messages := []gogpt.ChatCompletionMessage{}
+	for _, message := range p.SubsequentMessages {
+		// replace input marker with input
+		content := strings.ReplaceAll(message.Content, p.InputMarker, input)
+		// replace output marker with output
+		content = strings.ReplaceAll(content, p.OutputMarker, output)
+
+		messages = append(messages, gogpt.ChatCompletionMessage{
+			Role:    message.Role,
+			Content: content,
+		})
+	}
+	return messages
+}
+
 // CountTokens counts the number of tokens in the prompt
-func (p *Prompt) CountTokens() (int, error) {
+func (p *Prompt) CountTokens(encoder *tokenizer.Encoder) (int, error) {
+	return countMessagesTokens(encoder, p.Messages)
+}
+
+func (p *Prompt) CountSubsequentTokens(encoder *tokenizer.Encoder) (int, error) {
+	return countMessagesTokens(encoder, p.SubsequentMessages)
+}
+
+func countMessagesTokens(encoder *tokenizer.Encoder, messages []Message) (int, error) {
 	count := 0
-	encoder, err := tokenizer.NewEncoder()
-	if err != nil {
-		return 0, err
-	}
-	for _, message := range p.Messages {
+	for _, message := range messages {
 		// Encode string with GPT tokenizer
 		encoded, err := encoder.Encode(message.Content)
 		if err != nil {
@@ -58,8 +87,8 @@
 }
 
 // AllowedInputTokens returns the number of tokens allowed for the input
-func (p *Prompt) AllowedInputTokens(maxTokensOverride int, verbose bool) (int, error) {
-	promptTokens, err := p.CountTokens()
+func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverride int, verbose bool) (int, error) {
+	promptTokens, err := p.CountTokens(encoder)
 	if err != nil {
 		return 0, err
 	}
@@ -75,6 +104,23 @@ func (p *Prompt) AllowedInputTokens(maxTokensOverride int, verbose bool) (int, error) {
 	return result, nil
 }
 
+func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, maxTokensOverride int, verbose bool) (int, error) {
+	promptTokens, err := p.CountSubsequentTokens(encoder)
+	if err != nil {
+		return 0, err
+	}
+	// reserve 500 tokens for output if maxTokens is not specified
+	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
+	result := 4096 - (promptTokens + maxTokens + outputLen)
+	if verbose {
+		log.Printf("allowed tokens for subsequent input is %d", result)
+	}
+	if result <= 0 {
+		return 0, fmt.Errorf("allowed tokens for subsequent input is %d, but it should be greater than 0", result)
+	}
+	return result, nil
+}
+
 func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
 	encoder, err := tokenizer.NewEncoder()
 	if err != nil {
@@ -99,8 +145,8 @@
 	return parts, nil
 }
 
-func (p *Prompt) CreateMessagesWithSplit(input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
-	allowedInputTokens, err := p.AllowedInputTokens(maxTokensOverride, verbose)
+func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
+	allowedInputTokens, err := p.AllowedInputTokens(encoder, maxTokensOverride, verbose)
 	if err != nil {
 		return nil, err
 	}
@@ -123,6 +169,9 @@ func NewPromptFromFile(filename string) (*Prompt, error) {
 	if prompt.InputMarker == "" {
 		prompt.InputMarker = DefaultInputMarker
 	}
+	if prompt.OutputMarker == "" {
+		prompt.OutputMarker = DefaultOutputMarker
+	}
 	return prompt, nil
 }
 
diff --git a/prompt_test.go b/prompt_test.go
index aeb82e6..4fc9d40 100644
--- a/prompt_test.go
+++ b/prompt_test.go
@@ -10,8 +10,39 @@ func TestLoadPrompts(t *testing.T) {
 	if err != nil {
 		t.Fatalf("unexpected error: %v", err)
 	}
-	if prompt.Description == "" {
-		t.Errorf("expected description, got empty string")
+	if prompt.Description != "Name a Git branch" {
+		t.Errorf("expected description 'Name a Git branch', got %q", prompt.Description)
+	}
+	if len(prompt.Messages) != 2 {
+		t.Errorf("expected 2 messages, got %d", len(prompt.Messages))
+	}
+	if prompt.Temperature != 0.5 {
+		t.Errorf("expected temperature 0.5, got %f", prompt.Temperature)
+	}
+	if prompt.isFoldEnabled() {
+		t.Errorf("expected fold to be disabled")
+	}
+	if prompt.InputMarker != DefaultInputMarker {
+		t.Errorf("expected input marker to be %q, got %q", DefaultInputMarker, prompt.InputMarker)
+	}
+	if prompt.OutputMarker != DefaultOutputMarker {
+		t.Errorf("expected output marker to be %q, got %q", DefaultOutputMarker, prompt.OutputMarker)
+	}
+}
+
+func TestLoadPromptsFold(t *testing.T) {
+	prompt, err := NewPromptFromFile(filepath.Join("testdata", "fold.yml"))
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if prompt.Description != "Summarize" {
+		t.Errorf("expected description 'Summarize', got %q", prompt.Description)
+	}
+	if len(prompt.Messages) != 2 {
+		t.Errorf("expected 2 messages, got %d", len(prompt.Messages))
+	}
+	if !prompt.isFoldEnabled() {
+		t.Errorf("expected fold to be enabled")
 	}
 }
 
diff --git a/testdata/fold.yml b/testdata/fold.yml
new file mode 100644
index 0000000..f5ffa7a
--- /dev/null
+++ b/testdata/fold.yml
@@ -0,0 +1,17 @@
+description: Summarize
+messages:
+  - role: system
+    content: >-
+      Summarize the following text in one sentence.
+  - role: user
+    content: $INPUT
+
+subsequent_messages:
+  - role: system
+    content: >-
+      Summarize the following text in one sentence.
+
+      Here is the summary you wrote previously:
+      $OUTPUT
+  - role: user
+    content: $INPUT
diff --git a/testdata/name-branch.yml b/testdata/name-branch.yml
index 363ba08..9625981 100644
--- a/testdata/name-branch.yml
+++ b/testdata/name-branch.yml
@@ -6,3 +6,5 @@ messages:
       It should consist of one to three English words, and the shorter the better.
   - role: user
     content: $INPUT
+
+temperature: 0.5
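
A note on the new prompt-file contract, for reviewers: a prompt opts into
fold mode simply by declaring subsequent_messages. $INPUT receives the next
chunk of stdin and $OUTPUT receives the previous completion; both markers
can be renamed via input_marker and output_marker. Here is a minimal sketch
of another fold-style prompt; the action-item task itself is illustrative,
only the keys and markers come from this patch:

    description: Extract action items
    messages:
      - role: system
        content: >-
          List every action item in the following text as short bullet points.
      - role: user
        content: $INPUT

    subsequent_messages:
      - role: system
        content: >-
          List every action item in the following text as short bullet points.

          Merge them into the action items you listed previously:
          $OUTPUT
      - role: user
        content: $INPUT

Since each subsequent request sees only the running $OUTPUT plus one new
chunk, the subsequent system message should ask the model to merge with its
previous answer rather than start over; the fold.yml test data follows the
same pattern.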