Skip to content

Commit

Permalink
Merge pull request #21 from tkawachi/fold
Browse files Browse the repository at this point in the history
Added folding support via subsequent_message
  • Loading branch information
mergify[bot] authored Mar 12, 2023
2 parents 797ee1e + bc44a30 commit 176f1e1
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 21 deletions.
95 changes: 94 additions & 1 deletion aichat.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type chatOptions struct {

// AIChat bundles everything a chat session needs: the OpenAI API client,
// the tokenizer used to budget request sizes, and the resolved CLI options.
type AIChat struct {
	client  *gogpt.Client     // OpenAI API client
	encoder *tokenizer.Encoder // GPT tokenizer shared by all token-counting paths
	options chatOptions       // resolved command-line options (temperature, maxTokens, verbose, ...)
}

Expand Down Expand Up @@ -123,6 +124,86 @@ func firstNonZeroFloat32(f ...float32) float32 {
return 0
}

// fold processes an input that may exceed the model's context window by
// chunking it: the first chunk is sent with the prompt's primary messages,
// and every later chunk is sent via the prompt's subsequent messages
// together with the output accumulated so far. Intermediate outputs are
// logged when verbose; only the final output is printed to stdout.
// Returns an error if tokenization, budgeting, or any API call fails.
func (aiChat *AIChat) fold(prompt *Prompt, input string) error {
	encoded, err := aiChat.encoder.Encode(input)
	if err != nil {
		return fmt.Errorf("encode: %w", err)
	}

	firstAllowedTokens, err := prompt.AllowedInputTokens(aiChat.encoder, aiChat.options.maxTokens, aiChat.options.verbose)
	if err != nil {
		return err
	}
	// Clamp the budget to the input length: slicing encoded[:firstAllowedTokens]
	// would panic when the whole input already fits in a single request.
	if firstAllowedTokens > len(encoded) {
		firstAllowedTokens = len(encoded)
	}

	firstInput := aiChat.encoder.Decode(encoded[:firstAllowedTokens])
	temperature := firstNonZeroFloat32(aiChat.options.temperature, prompt.Temperature)
	firstRequest := gogpt.ChatCompletionRequest{
		Model:       gogpt.GPT3Dot5Turbo,
		Messages:    prompt.CreateMessages(firstInput),
		Temperature: temperature,
	}
	if aiChat.options.verbose {
		log.Printf("first request: %+v", firstRequest)
	}
	response, err := aiChat.client.CreateChatCompletion(context.Background(), firstRequest)
	if err != nil {
		return fmt.Errorf("create chat completion: %w", err)
	}
	if len(response.Choices) == 0 {
		return fmt.Errorf("no choices returned")
	}
	output := response.Choices[0].Message.Content
	if firstAllowedTokens >= len(encoded) {
		// The whole input fit into the first request; nothing to fold.
		fmt.Println(output)
		return nil
	}
	if aiChat.options.verbose {
		log.Printf("first output: %s", output)
	}

	idx := firstAllowedTokens
	for idx < len(encoded) {
		outputTokens, err := aiChat.encoder.Encode(output)
		if err != nil {
			return fmt.Errorf("encode: %w", err)
		}

		// The previous output is embedded in the next request, so the room
		// left for new input shrinks by its token count.
		allowedTokens, err := prompt.AllowedSubsequentInputTokens(
			aiChat.encoder, len(outputTokens), aiChat.options.maxTokens, aiChat.options.verbose)
		if err != nil {
			return fmt.Errorf("allowed subsequent input tokens: %w", err)
		}
		nextIdx := idx + allowedTokens
		if nextIdx > len(encoded) {
			nextIdx = len(encoded)
		}
		input := aiChat.encoder.Decode(encoded[idx:nextIdx])
		request := gogpt.ChatCompletionRequest{
			Model:       gogpt.GPT3Dot5Turbo,
			Messages:    prompt.CreateSubsequentMessages(output, input),
			Temperature: temperature,
		}
		if aiChat.options.verbose {
			log.Printf("subsequent request: %+v", request)
		}
		response, err := aiChat.client.CreateChatCompletion(context.Background(), request)
		if err != nil {
			return fmt.Errorf("create chat completion: %w", err)
		}
		if len(response.Choices) == 0 {
			return fmt.Errorf("no choices returned")
		}
		output = response.Choices[0].Message.Content
		if aiChat.options.verbose {
			log.Printf("subsequent output: %s", output)
		}
		idx = nextIdx
	}
	fmt.Println(output)
	return nil
}

func main() {
var temperature float32 = 0.5
var maxTokens = 0
Expand Down Expand Up @@ -158,8 +239,13 @@ func main() {
if verbose {
log.Printf("options: %+v", options)
}
encoder, err := tokenizer.NewEncoder()
if err != nil {
log.Fatal(err)
}
aiChat := AIChat{
client: gogpt.NewClient(openaiAPIKey),
encoder: encoder,
options: options,
}

Expand All @@ -180,10 +266,17 @@ func main() {
// read all from Stdin
input := scanAll(bufio.NewScanner(os.Stdin))

if prompt.isFoldEnabled() {
if err := aiChat.fold(prompt, input); err != nil {
log.Fatal(err)
}
return
}

var messagesSlice [][]gogpt.ChatCompletionMessage

if split {
messagesSlice, err = prompt.CreateMessagesWithSplit(input, aiChat.options.maxTokens, aiChat.options.verbose)
messagesSlice, err = prompt.CreateMessagesWithSplit(aiChat.encoder, input, aiChat.options.maxTokens, aiChat.options.verbose)
if err != nil {
log.Fatal(err)
}
Expand Down
85 changes: 67 additions & 18 deletions prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,25 @@ import (
)

// DefaultInputMarker is the placeholder substituted with the user input
// when a prompt file does not specify its own input_marker.
const DefaultInputMarker = "$INPUT"

// DefaultOutputMarker is the placeholder substituted with the previously
// generated output when a prompt file does not specify its own output_marker.
const DefaultOutputMarker = "$OUTPUT"

// Message is a single chat message template as declared in a prompt YAML file.
type Message struct {
	Role    string `yaml:"role"`    // chat role: e.g. "system" or "user"
	Content string `yaml:"content"` // message body; may contain input/output markers
}

type Prompt struct {
Description string `yaml:"description"`
InputMarker string `yaml:"input_marker"`
Messages []struct {
Role string `yaml:"role"`
Content string `yaml:"content"`
} `yaml:"messages"`
Temperature float32 `yaml:"temperature"`
MaxTokens int `yaml:"max_tokens"`
Description string `yaml:"description"`
InputMarker string `yaml:"input_marker"`
OutputMarker string `yaml:"output_marker"`
Messages []Message `yaml:"messages"`
SubsequentMessages []Message `yaml:"subsequent_messages"`
Temperature float32 `yaml:"temperature"`
MaxTokens int `yaml:"max_tokens"`
}

// isFoldEnabled reports whether this prompt declares follow-up messages,
// i.e. whether oversized input should be processed with the folding strategy.
func (p *Prompt) isFoldEnabled() bool {
	return len(p.SubsequentMessages) != 0
}

func (p *Prompt) CreateMessages(input string) []gogpt.ChatCompletionMessage {
Expand All @@ -39,14 +48,34 @@ func (p *Prompt) CreateMessages(input string) []gogpt.ChatCompletionMessage {
return messages
}

// CreateSubsequentMessages builds the chat messages for a follow-up
// (folding) request: in every template message the input marker is
// replaced by the next input chunk and the output marker by the output
// accumulated so far.
func (p *Prompt) CreateSubsequentMessages(output, input string) []gogpt.ChatCompletionMessage {
	result := []gogpt.ChatCompletionMessage{}
	for _, tmpl := range p.SubsequentMessages {
		// Substitute the input chunk first, then the running output.
		body := strings.ReplaceAll(tmpl.Content, p.InputMarker, input)
		body = strings.ReplaceAll(body, p.OutputMarker, output)
		result = append(result, gogpt.ChatCompletionMessage{
			Role:    tmpl.Role,
			Content: body,
		})
	}
	return result
}

// CountTokens counts the number of tokens in the prompt
func (p *Prompt) CountTokens() (int, error) {
func (p *Prompt) CountTokens(encoder *tokenizer.Encoder) (int, error) {
return countMessagesTokens(encoder, p.Messages)
}

func (p *Prompt) CountSubsequentTokens(encoder *tokenizer.Encoder) (int, error) {
return countMessagesTokens(encoder, p.SubsequentMessages)
}

func countMessagesTokens(encoder *tokenizer.Encoder, messages []Message) (int, error) {
count := 0
encoder, err := tokenizer.NewEncoder()
if err != nil {
return 0, err
}
for _, message := range p.Messages {
for _, message := range messages {
// Encode string with GPT tokenizer
encoded, err := encoder.Encode(message.Content)
if err != nil {
Expand All @@ -58,8 +87,8 @@ func (p *Prompt) CountTokens() (int, error) {
}

// AllowedInputTokens returns the number of tokens allowed for the input
func (p *Prompt) AllowedInputTokens(maxTokensOverride int, verbose bool) (int, error) {
promptTokens, err := p.CountTokens()
func (p *Prompt) AllowedInputTokens(encoder *tokenizer.Encoder, maxTokensOverride int, verbose bool) (int, error) {
promptTokens, err := p.CountTokens(encoder)
if err != nil {
return 0, err
}
Expand All @@ -75,6 +104,23 @@ func (p *Prompt) AllowedInputTokens(maxTokensOverride int, verbose bool) (int, e
return result, nil
}

// AllowedSubsequentInputTokens returns the token budget available for the
// input chunk of a follow-up (folding) request: the context window minus
// the subsequent-message template tokens, the reserved completion tokens,
// and the tokens of the output carried over from the previous request.
// An error is returned when no room is left for input.
func (p *Prompt) AllowedSubsequentInputTokens(encoder *tokenizer.Encoder, outputLen, maxTokensOverride int, verbose bool) (int, error) {
	templateTokens, err := p.CountSubsequentTokens(encoder)
	if err != nil {
		return 0, err
	}
	// reserve 500 tokens for output if maxTokens is not specified
	reserved := firstNonZeroInt(maxTokensOverride, p.MaxTokens, 500)
	budget := 4096 - (templateTokens + reserved + outputLen)
	if verbose {
		log.Printf("allowed tokens for subsequent input is %d", budget)
	}
	if budget > 0 {
		return budget, nil
	}
	return 0, fmt.Errorf("allowed tokens for subsequent input is %d, but it should be greater than 0", budget)
}

func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
encoder, err := tokenizer.NewEncoder()
if err != nil {
Expand All @@ -99,8 +145,8 @@ func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
return parts, nil
}

func (p *Prompt) CreateMessagesWithSplit(input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
allowedInputTokens, err := p.AllowedInputTokens(maxTokensOverride, verbose)
func (p *Prompt) CreateMessagesWithSplit(encoder *tokenizer.Encoder, input string, maxTokensOverride int, verbose bool) ([][]gogpt.ChatCompletionMessage, error) {
allowedInputTokens, err := p.AllowedInputTokens(encoder, maxTokensOverride, verbose)
if err != nil {
return nil, err
}
Expand All @@ -123,6 +169,9 @@ func NewPromptFromFile(filename string) (*Prompt, error) {
if prompt.InputMarker == "" {
prompt.InputMarker = DefaultInputMarker
}
if prompt.OutputMarker == "" {
prompt.OutputMarker = DefaultOutputMarker
}
return prompt, nil
}

Expand Down
35 changes: 33 additions & 2 deletions prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,39 @@ func TestLoadPrompts(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if prompt.Description == "" {
t.Errorf("expected description, got empty string")
if prompt.Description != "Name a Git branch" {
t.Errorf("expected description 'Name a Git branch', got %q", prompt.Description)
}
if len(prompt.Messages) != 2 {
t.Errorf("expected 2 message, got %d", len(prompt.Messages))
}
if prompt.Temperature != 0.5 {
t.Errorf("expected temperature 0.5, got %f", prompt.Temperature)
}
if prompt.isFoldEnabled() {
t.Errorf("expected fold to be disabled")
}
if prompt.InputMarker != DefaultInputMarker {
t.Errorf("expected input marker to be %q, got %q", DefaultInputMarker, prompt.InputMarker)
}
if prompt.OutputMarker != DefaultOutputMarker {
t.Errorf("expected output marker to be %q, got %q", DefaultOutputMarker, prompt.OutputMarker)
}
}

// TestLoadPromptsFold verifies that a prompt file declaring
// subsequent_messages loads correctly and reports folding as enabled.
func TestLoadPromptsFold(t *testing.T) {
	p, err := NewPromptFromFile(filepath.Join("testdata", "fold.yml"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := p.Description; got != "Summarize" {
		t.Errorf("expected description 'Summarize', got %q", got)
	}
	if n := len(p.Messages); n != 2 {
		t.Errorf("expected 2 message, got %d", n)
	}
	if !p.isFoldEnabled() {
		t.Errorf("expected fold to be enabled")
	}
}

Expand Down
17 changes: 17 additions & 0 deletions testdata/fold.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Test fixture: a prompt with folding enabled (subsequent_messages present).
description: Summarize
messages:
  - role: system
    content: >-
      Summarize the following text in one sentence.
  - role: user
    content: $INPUT

# Messages used for every request after the first; $OUTPUT is replaced with
# the summary produced so far, $INPUT with the next chunk of the input.
subsequent_messages:
  - role: system
    content: >-
      Summarize the following text in one sentence.
      Here is the summary you wrote previously:
      $OUTPUT
  - role: user
    content: $INPUT
2 changes: 2 additions & 0 deletions testdata/name-branch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ messages:
It should consist of one to three English words, and the shorter the better.
- role: user
content: $INPUT

temperature: 0.5

0 comments on commit 176f1e1

Please sign in to comment.