Skip to content

Commit

Permalink
feat: add SOCKS proxy support to OpenAI API client
Browse files Browse the repository at this point in the history
- Add `openai.WithSocksURL` option to set the `socksURL` field of the config struct
- Add `availableKeys` entry for `openai.socks`
- Modify `configCmd.PersistentFlags()` to include `socks` flag
- Modify `init()` to bind `socks` flag to `viper`
- Modify `New()` to check for `socksURL` and set up `httpClient.Transport` accordingly.
  • Loading branch information
appleboy committed Mar 13, 2023
1 parent a663a00 commit 278219d
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 5 deletions.
1 change: 1 addition & 0 deletions cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ var commitCmd = &cobra.Command{
openai.WithModel(viper.GetString("openai.model")),
openai.WithOrgID(viper.GetString("openai.org_id")),
openai.WithProxyURL(viper.GetString("openai.proxy")),
openai.WithSocksURL(viper.GetString("openai.socks")),
)
if err != nil {
return err
Expand Down
4 changes: 3 additions & 1 deletion cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@ import (
"github.com/spf13/viper"
)

var availableKeys = []string{"openai.api_key", "openai.model", "openai.org_id", "openai.proxy", "output.lang"}
var availableKeys = []string{"openai.socks", "openai.api_key", "openai.model", "openai.org_id", "openai.proxy", "output.lang"}

func init() {
configCmd.PersistentFlags().StringP("api_key", "k", "", "openai api key")
configCmd.PersistentFlags().StringP("model", "m", "gpt-3.5-turbo", "openai model")
configCmd.PersistentFlags().StringP("lang", "l", "en", "summarizing language uses English by default")
configCmd.PersistentFlags().StringP("org_id", "o", "", "openai requesting organization")
configCmd.PersistentFlags().StringP("proxy", "", "", "http proxy")
configCmd.PersistentFlags().StringP("socks", "", "", "socks proxy")
configCmd.PersistentFlags().IntP("diff_unified", "", 3, "generate diffs with <n> lines of context, default is 3")
configCmd.PersistentFlags().StringSliceP("exclude_list", "", []string{}, "exclude file from `git diff` command")

_ = viper.BindPFlag("openai.org_id", configCmd.PersistentFlags().Lookup("org_id"))
_ = viper.BindPFlag("openai.api_key", configCmd.PersistentFlags().Lookup("api_key"))
_ = viper.BindPFlag("openai.model", configCmd.PersistentFlags().Lookup("model"))
_ = viper.BindPFlag("openai.proxy", configCmd.PersistentFlags().Lookup("proxy"))
_ = viper.BindPFlag("openai.socks", configCmd.PersistentFlags().Lookup("socks"))
_ = viper.BindPFlag("output.lang", configCmd.PersistentFlags().Lookup("lang"))
_ = viper.BindPFlag("git.diff_unified", configCmd.PersistentFlags().Lookup("diff_unified"))
_ = viper.BindPFlag("git.exclude_list", configCmd.PersistentFlags().Lookup("exclude_list"))
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/sashabaranov/go-openai v1.4.1
github.com/spf13/cobra v1.6.1
github.com/spf13/viper v1.15.0
golang.org/x/net v0.4.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU=
golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand Down
18 changes: 14 additions & 4 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package openai
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"

openai "github.com/sashabaranov/go-openai"
"golang.org/x/net/proxy"
)

// DefaultModel is the default OpenAI model to use if one is not provided.
Expand Down Expand Up @@ -140,17 +142,25 @@ func New(opts ...Option) (*Client, error) {
c.OrgID = cfg.orgID
}

httpClient := &http.Client{
Timeout: time.Second * 10,
}
if cfg.proxyURL != "" {
httpClient := &http.Client{
Timeout: time.Second * 10,
}
proxy, _ := url.Parse(cfg.proxyURL)
httpClient.Transport = &http.Transport{
Proxy: http.ProxyURL(proxy),
}
c.HTTPClient = httpClient
} else if cfg.socksURL != "" {
dialer, err := proxy.SOCKS5("tcp", cfg.socksURL, nil, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("can't connect to the proxy: %s", err)
}
httpClient.Transport = &http.Transport{
Dial: dialer.Dial,
}
}

c.HTTPClient = httpClient
instance.client = openai.NewClientWithConfig(c)

return instance, nil
Expand Down
8 changes: 8 additions & 0 deletions openai/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,18 @@ func WithProxyURL(val string) Option {
})
}

// WithSocksURL is a function that returns an Option, which sets the socksURL field of the config struct.
func WithSocksURL(val string) Option {
return optionFunc(func(c *config) {
c.socksURL = val
})
}

// config is a struct that stores configuration options for the instrumentation.
type config struct {
token string
orgID string
model string
proxyURL string
socksURL string
}

0 comments on commit 278219d

Please sign in to comment.