feat: add temperature control to OpenAI package
- Set the default sampling temperature to 0.7
- Add a flag to set the sampling temperature in the config command
- Move the initialization of the `openai` package from the `commit` command to a separate file
- Add a `WithTemperature` function to the `openai` package to set the sampling temperature
- Change the `maxTokens` variable in the `review` command from package level to local scope
- Set the `temperature` field in the `Client` struct of the `openai` package to the value set by `WithTemperature`

Signed-off-by: Bo-Yi.Wu <appleboy.tw@gmail.com>
appleboy committed Apr 5, 2023
1 parent 5079244 commit c98e375
Showing 5 changed files with 44 additions and 23 deletions.
1 change: 1 addition & 0 deletions cmd/commit.go
@@ -73,6 +73,7 @@ var commitCmd = &cobra.Command{
 			openai.WithBaseURL(viper.GetString("openai.base_url")),
 			openai.WithTimeout(viper.GetDuration("openai.timeout")),
 			openai.WithMaxTokens(viper.GetInt("openai.max_tokens")),
+			openai.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
 		)
 		if err != nil {
 			return err
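The `float32(...)` cast in the added line is needed because viper exposes `GetFloat64` but no float32 getter, while the client option takes a `float32`. A minimal standalone sketch of that conversion (the config key mirrors the one used above):

package main

import (
	"fmt"

	"github.com/spf13/viper"
)

func main() {
	// viper reads floats as float64; the value is narrowed before
	// being handed to the float32 option.
	viper.SetDefault("openai.temperature", 0.7)
	temp := float32(viper.GetFloat64("openai.temperature"))
	fmt.Println(temp) // 0.7
}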
3 changes: 3 additions & 0 deletions cmd/config.go
@@ -24,6 +24,7 @@ var availableKeys = []string{
 	"openai.base_url",
 	"openai.timeout",
 	"openai.max_tokens",
+	"openai.temperature",
 }

 func init() {
@@ -39,6 +40,7 @@ func init() {
 	configCmd.PersistentFlags().StringP("template_string", "", "", "git commit message string")
 	configCmd.PersistentFlags().IntP("diff_unified", "", 3, "generate diffs with <n> lines of context, default is 3")
 	configCmd.PersistentFlags().IntP("max_tokens", "", 300, "the maximum number of tokens to generate in the chat completion.")
+	configCmd.PersistentFlags().Float32P("temperature", "", 0.7, "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.")
 	configCmd.PersistentFlags().StringSliceP("exclude_list", "", []string{}, "exclude file from `git diff` command")

 	_ = viper.BindPFlag("openai.base_url", configCmd.PersistentFlags().Lookup("base_url"))
@@ -49,6 +51,7 @@ func init() {
 	_ = viper.BindPFlag("openai.socks", configCmd.PersistentFlags().Lookup("socks"))
 	_ = viper.BindPFlag("openai.timeout", configCmd.PersistentFlags().Lookup("timeout"))
 	_ = viper.BindPFlag("openai.max_tokens", configCmd.PersistentFlags().Lookup("max_tokens"))
+	_ = viper.BindPFlag("openai.temperature", configCmd.PersistentFlags().Lookup("temperature"))
 	_ = viper.BindPFlag("output.lang", configCmd.PersistentFlags().Lookup("lang"))
 	_ = viper.BindPFlag("git.diff_unified", configCmd.PersistentFlags().Lookup("diff_unified"))
 	_ = viper.BindPFlag("git.exclude_list", configCmd.PersistentFlags().Lookup("exclude_list"))
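The `BindPFlag` call added above is what lets a `--temperature` flag value surface later through `viper.GetFloat64("openai.temperature")`. A minimal sketch of that binding outside cobra, using the plain pflag package in place of the command's flag set:

package main

import (
	"fmt"

	"github.com/spf13/pflag"
	"github.com/spf13/viper"
)

func main() {
	// Declare the flag, then bind the viper key to it; after binding,
	// viper resolves the key from the flag value (or its default).
	pflag.Float32P("temperature", "", 0.7, "sampling temperature, between 0 and 2")
	pflag.Parse()
	_ = viper.BindPFlag("openai.temperature", pflag.Lookup("temperature"))

	fmt.Println(viper.GetFloat64("openai.temperature")) // 0.7 unless overridden
}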
9 changes: 4 additions & 5 deletions cmd/review.go
@@ -14,11 +14,9 @@ import (
 	"github.com/spf13/viper"
 )

-var (
-	// The maximum number of tokens to generate in the chat completion.
-	// The total length of input tokens and generated tokens is limited by the model's context length.
-	maxTokens int
-)
+// The maximum number of tokens to generate in the chat completion.
+// The total length of input tokens and generated tokens is limited by the model's context length.
+var maxTokens int

 func init() {
 	reviewCmd.Flags().IntVar(&diffUnified, "diff_unified", 3, "generate diffs with <n> lines of context, default is 3")
@@ -57,6 +55,7 @@ var reviewCmd = &cobra.Command{
 			openai.WithBaseURL(viper.GetString("openai.base_url")),
 			openai.WithTimeout(viper.GetDuration("openai.timeout")),
 			openai.WithMaxTokens(viper.GetInt("openai.max_tokens")),
+			openai.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
 		)
 		if err != nil {
 			return err
17 changes: 10 additions & 7 deletions openai/openai.go
@@ -48,9 +48,10 @@ func GetModel(model string) string {

 // Client is a struct that represents an OpenAI client.
 type Client struct {
-	client    *openai.Client
-	model     string
-	maxTokens int
+	client      *openai.Client
+	model       string
+	maxTokens   int
+	temperature float32
 }

 type Response struct {
@@ -66,7 +67,7 @@ func (c *Client) CreateChatCompletion(
 	req := openai.ChatCompletionRequest{
 		Model:       c.model,
 		MaxTokens:   c.maxTokens,
-		Temperature: 0.7,
+		Temperature: c.temperature,
 		TopP:        1,
 		Messages: []openai.ChatCompletionMessage{
 			{
@@ -92,7 +93,7 @@ func (c *Client) CreateCompletion(
 	req := openai.CompletionRequest{
 		Model:       c.model,
 		MaxTokens:   c.maxTokens,
-		Temperature: 0.7,
+		Temperature: c.temperature,
 		TopP:        1,
 		Prompt:      content,
 	}
@@ -135,8 +136,9 @@ func (c *Client) Completion(
 // returns a pointer to a Client and an error.
 func New(opts ...Option) (*Client, error) {
 	cfg := &config{
-		maxTokens: defaultMaxTokens,
-		model:     defaultModel,
+		maxTokens:   defaultMaxTokens,
+		model:       defaultModel,
+		temperature: defaultTemperature,
 	}

 	// Loop through each option
@@ -156,6 +158,7 @@ func New(opts ...Option) (*Client, error) {
 	}
 	instance.model = v
 	instance.maxTokens = cfg.maxTokens
+	instance.temperature = cfg.temperature

 	c := openai.DefaultConfig(cfg.token)
 	if cfg.orgID != "" {
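One detail the `val <= 0` guard in `WithTemperature` (next file) quietly sidesteps: assuming the imported client is sashabaranov/go-openai, as the `openai.GPT3Dot5Turbo` constant suggests, the request's Temperature field is tagged `omitempty`, so a stored zero would be dropped from the JSON payload rather than sent as 0. A small runnable illustration of that serialization behavior, using a stand-in struct:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in mirroring the assumed field shape of the upstream request type.
type request struct {
	Temperature float32 `json:"temperature,omitempty"`
}

func main() {
	zero, _ := json.Marshal(request{Temperature: 0})
	set, _ := json.Marshal(request{Temperature: 0.7})
	fmt.Println(string(zero)) // {} — zero is omitted, so the API's own default applies
	fmt.Println(string(set))  // {"temperature":0.7}
}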
37 changes: 26 additions & 11 deletions openai/options.go
@@ -7,8 +7,9 @@ import (
 )

 const (
-	defaultMaxTokens = 300
-	defaultModel     = openai.GPT3Dot5Turbo
+	defaultMaxTokens   = 300
+	defaultModel       = openai.GPT3Dot5Turbo
+	defaultTemperature = 0.7
 )

 // Option is an interface that specifies instrumentation configuration options.
@@ -85,22 +86,36 @@ func WithTimeout(val time.Duration) Option {
 // The maximum number of tokens to generate in the chat completion.
 // The total length of input tokens and generated tokens is limited by the model's context length.
 func WithMaxTokens(val int) Option {
-	if val == 0 {
+	if val <= 0 {
 		val = defaultMaxTokens
 	}
 	return optionFunc(func(c *config) {
 		c.maxTokens = val
 	})
 }

+// WithTemperature returns a new Option that sets the temperature for the client configuration.
+// What sampling temperature to use, between 0 and 2.
+// Higher values like 0.8 will make the output more random,
+// while lower values like 0.2 will make it more focused and deterministic.
+func WithTemperature(val float32) Option {
+	if val <= 0 {
+		val = defaultTemperature
+	}
+	return optionFunc(func(c *config) {
+		c.temperature = val
+	})
+}
+
 // config is a struct that stores configuration options for the instrumentation.
 type config struct {
-	baseURL   string
-	token     string
-	orgID     string
-	model     string
-	proxyURL  string
-	socksURL  string
-	timeout   time.Duration
-	maxTokens int
+	baseURL     string
+	token       string
+	orgID       string
+	model       string
+	proxyURL    string
+	socksURL    string
+	timeout     time.Duration
+	maxTokens   int
+	temperature float32
 }
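Taken together, a caller configures the client entirely through these options. A minimal usage sketch — the import path is assumed from the repository layout, and `New` may still fail at runtime without an API token, which is configured through options not shown in this diff:

package main

import (
	"fmt"
	"log"

	// Assumed import path, based on the repository layout.
	"github.com/appleboy/CodeGPT/openai"
)

func main() {
	client, err := openai.New(
		openai.WithMaxTokens(300),
		openai.WithTemperature(0.7), // values <= 0 fall back to defaultTemperature
	)
	if err != nil {
		log.Fatal(err) // likely reached here if no token is configured
	}
	fmt.Printf("client ready: %T\n", client)
}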
