Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add azure openai support #214

Merged
merged 26 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
35ca745
feat: add azure openai support
ttys3 Mar 31, 2023
4ef2708
chore: refine config
ttys3 Mar 31, 2023
350e961
chore: make config options like the python one
ttys3 Mar 31, 2023
61287b2
chore: adjust config struct field order
ttys3 Mar 31, 2023
9f06a7b
test: fix tests
ttys3 Mar 31, 2023
41a5ae5
style: make the linter happy
ttys3 Mar 31, 2023
8cdf9a8
fix: support Azure API Key authentication in sendRequest
ttys3 Mar 31, 2023
f5157ea
chore: check error in CreateChatCompletionStream
ttys3 Mar 31, 2023
a8a6b7e
chore: pass tests
ttys3 Mar 31, 2023
ca1315e
chore: try pass tests again
ttys3 Mar 31, 2023
5da05cf
chore: change ClientConfig back due to this lib does not like WithXxx…
ttys3 Apr 3, 2023
94438d4
chore: revert fix to CreateChatCompletionStream() due to cause tests …
ttys3 Apr 3, 2023
37eecb8
chore: at least add some comment about the required fields
ttys3 Apr 3, 2023
60cafde
chore: re order ClientConfig fields
ttys3 Apr 3, 2023
006ac90
chore: add DefaultAzure()
ttys3 Apr 3, 2023
84ea881
chore: set default api_version the same as py one "2023-03-15-preview"
ttys3 Apr 3, 2023
becc1bb
style: fixup typo
ttys3 Apr 3, 2023
f5c822d
test: add api_internal_test.go
ttys3 Apr 3, 2023
ee4bd5d
style: make lint happy
ttys3 Apr 3, 2023
416dc26
chore: add constant AzureAPIKeyHeader
ttys3 Apr 3, 2023
27be572
chore: use AzureAPIKeyHeader for api-key header, fix azure base url a…
ttys3 Apr 3, 2023
ebef9a9
test: add TestAzureFullURL, TestRequestAuthHeader and TestOpenAIFullURL
ttys3 Apr 3, 2023
862ce64
test: simplify TestRequestAuthHeader
ttys3 Apr 3, 2023
115595b
test: refine TestOpenAIFullURL
ttys3 Apr 3, 2023
59461ec
chore: refine comments
ttys3 Apr 3, 2023
222aa40
feat: DefaultAzureConfig
ttys3 Apr 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
)

// Client is OpenAI GPT-3 API client.
Expand Down Expand Up @@ -39,7 +40,13 @@ func NewOrgClient(authToken, org string) *Client {

func (c *Client) sendRequest(req *http.Request, v interface{}) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}

// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
Expand Down Expand Up @@ -83,6 +90,15 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
}

func (c *Client) fullURL(suffix string) string {
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
}

// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
}

Expand All @@ -100,7 +116,14 @@ func (c *Client) newStreamRequest(
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))

// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
return req, nil
}
133 changes: 133 additions & 0 deletions api_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package openai

import (
"context"
"testing"
)

func TestOpenAIFullURL(t *testing.T) {
cases := []struct {
Name string
Suffix string
Expect string
}{
{
"ChatCompletionsURL",
"/chat/completions",
"https://api.openai.com/v1/chat/completions",
},
{
"CompletionsURL",
"/completions",
"https://api.openai.com/v1/completions",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultConfig("dummy")
cli := NewClientWithConfig(az)
actual := cli.fullURL(c.Suffix)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
t.Logf("Full URL: %s", actual)
})
}
}

func TestRequestAuthHeader(t *testing.T) {
cases := []struct {
Name string
APIType APIType
HeaderKey string
Token string
Expect string
}{
{
"OpenAIDefault",
"",
"Authorization",
"dummy-token-openai",
"Bearer dummy-token-openai",
},
{
"OpenAI",
APITypeOpenAI,
"Authorization",
"dummy-token-openai",
"Bearer dummy-token-openai",
},
{
"AzureAD",
APITypeAzureAD,
"Authorization",
"dummy-token-azure",
"Bearer dummy-token-azure",
},
{
"Azure",
APITypeAzure,
AzureAPIKeyHeader,
"dummy-api-key-here",
"dummy-api-key-here",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultConfig(c.Token)
az.APIType = c.APIType

cli := NewClientWithConfig(az)
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
if err != nil {
t.Errorf("Failed to create request: %v", err)
}
actual := req.Header.Get(c.HeaderKey)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
t.Logf("%s: %s", c.HeaderKey, actual)
})
}
}

func TestAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Engine string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-03-15-preview",
},
{
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-03-15-preview",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions")
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
t.Logf("Full URL: %s", actual)
})
}
}
45 changes: 39 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,61 @@ import (
)

const (
apiURLv1 = "https://api.openai.com/v1"
openaiAPIURLv1 = "https://api.openai.com/v1"
defaultEmptyMessagesLimit uint = 300

azureAPIPrefix = "openai"
azureDeploymentsPrefix = "deployments"
)

type APIType string

const (
APITypeOpenAI APIType = "OPEN_AI"
sashabaranov marked this conversation as resolved.
Show resolved Hide resolved
APITypeAzure APIType = "AZURE"
APITypeAzureAD APIType = "AZURE_AD"
)

const AzureAPIKeyHeader = "api-key"

// ClientConfig is a configuration of a client.
type ClientConfig struct {
authToken string

HTTPClient *http.Client
BaseURL string
OrgID string
APIType APIType
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
Engine string // required when APIType is APITypeAzure or APITypeAzureAD

BaseURL string
OrgID string
HTTPClient *http.Client

EmptyMessagesLimit uint
}

func DefaultConfig(authToken string) ClientConfig {
ttys3 marked this conversation as resolved.
Show resolved Hide resolved
return ClientConfig{
authToken: authToken,
BaseURL: openaiAPIURLv1,
APIType: APITypeOpenAI,
OrgID: "",

HTTPClient: &http.Client{},
BaseURL: apiURLv1,

EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
}

func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
return ClientConfig{
authToken: apiKey,
BaseURL: baseURL,
OrgID: "",
authToken: authToken,
APIType: APITypeAzure,
APIVersion: "2023-03-15-preview",
Engine: engine,

HTTPClient: &http.Client{},

EmptyMessagesLimit: defaultEmptyMessagesLimit,
}
Expand Down