Skip to content

Commit

Permalink
feat: set default maxTokens for OpenAI client to 300
Browse files Browse the repository at this point in the history
- Add an `openai.max_tokens` flag to the `review` command with a default value of `300`
- Add a `WithMaxTokens` function to the OpenAI client options, which sets the `maxTokens` field in the client configuration
- Use the `openai.max_tokens` flag value in the `review` command when calling the OpenAI API
- Remove the `MaxTokens` field initialization with a hardcoded value in the `CreateChatCompletion` and `CreateCompletion` functions of the OpenAI client
- Change the `maxTokens` field of the OpenAI client configuration to be set by default to `300` in the `New` function of the OpenAI client
- Add a constant `defaultMaxTokens` to the OpenAI client options file

Signed-off-by: Bo-Yi.Wu <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed Mar 25, 2023
1 parent de1d07a commit cb3d75c
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 12 deletions.
1 change: 1 addition & 0 deletions cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ var commitCmd = &cobra.Command{
openai.WithSocksURL(viper.GetString("openai.socks")),
openai.WithBaseURL(viper.GetString("openai.base_url")),
openai.WithTimeout(viper.GetDuration("openai.timeout")),
openai.WithMaxTokens(viper.GetInt("openai.max_tokens")),
)
if err != nil {
return err
Expand Down
4 changes: 4 additions & 0 deletions cmd/hepler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func check() error {
viper.Set("openai.socks", socksProxy)
}

if maxTokens != 0 {
viper.Set("openai.max_tokens", maxTokens)
}

if templateFile != "" {
viper.Set("git.template_file", templateFile)
}
Expand Down
8 changes: 8 additions & 0 deletions cmd/review.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@ import (
"github.com/spf13/viper"
)

var (
	// maxTokens is the value bound to the --max_tokens flag of the review
	// command. It is the maximum number of tokens to generate in the chat
	// completion; the total length of input tokens plus generated tokens is
	// limited by the model's context length.
	// A non-zero value is copied into viper as "openai.max_tokens" (see check()).
	maxTokens int
)

func init() {
reviewCmd.Flags().IntVar(&diffUnified, "diff_unified", 3, "generate diffs with <n> lines of context, default is 3")
reviewCmd.Flags().IntVar(&maxTokens, "max_tokens", 300, "the maximum number of tokens to generate in the chat completion.")
reviewCmd.Flags().StringVar(&commitModel, "model", "gpt-3.5-turbo", "select openai model")
reviewCmd.Flags().StringVar(&commitLang, "lang", "en", "summarizing language uses English by default")
reviewCmd.Flags().StringSliceVar(&excludeList, "exclude_list", []string{}, "exclude file from git diff command")
Expand Down Expand Up @@ -49,6 +56,7 @@ var reviewCmd = &cobra.Command{
openai.WithSocksURL(viper.GetString("openai.socks")),
openai.WithBaseURL(viper.GetString("openai.base_url")),
openai.WithTimeout(viper.GetDuration("openai.timeout")),
openai.WithMaxTokens(viper.GetInt("openai.max_tokens")),
)
if err != nil {
return err
Expand Down
15 changes: 10 additions & 5 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ func GetModel(model string) string {

// Client is an OpenAI API client. It wraps the underlying go-openai client
// together with the request parameters (model, maxTokens) that are applied
// to every completion request it issues.
type Client struct {
	client    *openai.Client // underlying go-openai transport
	model     string         // model name sent in each request (e.g. gpt-3.5-turbo)
	maxTokens int            // MaxTokens applied to each completion request
}

type Response struct {
Expand All @@ -64,7 +65,7 @@ func (c *Client) CreateChatCompletion(
) (resp openai.ChatCompletionResponse, err error) {
req := openai.ChatCompletionRequest{
Model: c.model,
MaxTokens: 300,
MaxTokens: c.maxTokens,
Temperature: 0.7,
TopP: 1,
Messages: []openai.ChatCompletionMessage{
Expand All @@ -90,7 +91,7 @@ func (c *Client) CreateCompletion(
) (resp openai.CompletionResponse, err error) {
req := openai.CompletionRequest{
Model: c.model,
MaxTokens: 300,
MaxTokens: c.maxTokens,
Temperature: 0.7,
TopP: 1,
Prompt: content,
Expand Down Expand Up @@ -133,7 +134,10 @@ func (c *Client) Completion(
// New is a function that takes a variadic slice of Option types and
// returns a pointer to a Client and an error.
func New(opts ...Option) (*Client, error) {
cfg := &config{}
cfg := &config{
maxTokens: defaultMaxTokens,
model: defaultModel,
}

// Loop through each option
for _, o := range opts {
Expand All @@ -151,6 +155,7 @@ func New(opts ...Option) (*Client, error) {
return nil, errors.New("missing model")
}
instance.model = v
instance.maxTokens = cfg.maxTokens

c := openai.DefaultConfig(cfg.token)
if cfg.orgID != "" {
Expand Down
34 changes: 27 additions & 7 deletions openai/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ package openai

import (
"time"

"github.com/sashabaranov/go-openai"
)

const (
	// defaultMaxTokens is the fallback for the maximum number of tokens to
	// generate in a completion, used when no explicit value is configured
	// (see WithMaxTokens and New).
	defaultMaxTokens = 300
	// defaultModel is the fallback OpenAI model used when the caller does
	// not configure one.
	defaultModel = openai.GPT3Dot5Turbo
)

// Option is an interface that specifies instrumentation configuration options.
Expand Down Expand Up @@ -74,13 +81,26 @@ func WithTimeout(val time.Duration) Option {
})
}

// WithMaxTokens returns a new Option that sets the max tokens for the client
// configuration. This is the maximum number of tokens to generate in the chat
// completion; the total length of input tokens and generated tokens is limited
// by the model's context length.
//
// Any non-positive value (zero or negative) is replaced with defaultMaxTokens,
// since the OpenAI API rejects such values.
func WithMaxTokens(val int) Option {
	// Guard against both 0 (unset) and negative (invalid) values.
	if val <= 0 {
		val = defaultMaxTokens
	}
	return optionFunc(func(c *config) {
		c.maxTokens = val
	})
}

// config stores the configuration options collected from Option values and
// consumed by New when constructing a Client.
type config struct {
	baseURL   string        // custom API base URL; empty means the OpenAI default
	token     string        // API token (required)
	orgID     string        // optional organization ID
	model     string        // model name; defaults to defaultModel
	proxyURL  string        // optional HTTP proxy URL
	socksURL  string        // optional SOCKS proxy URL
	timeout   time.Duration // HTTP client timeout
	maxTokens int           // max tokens per completion; defaults to defaultMaxTokens
}

0 comments on commit cb3d75c

Please sign in to comment.