From 87765633380c7cc83557c2a3d2cd72d56f527da1 Mon Sep 17 00:00:00 2001 From: Takeshi Yoneda Date: Thu, 30 Jan 2025 15:53:50 -0800 Subject: [PATCH] test: adds unit test coverage checks (#249) **Commit Message**: This introduces the unit test coverage checks in GHA. Note that this doesn't utilize the integration tests (test-extproc,test-controller and test-e2e), so the coverage will be checked solely on the unit tests without relying on the external binaries etc. One left TODO is to remove the exclusion of controller package which requires a bit of refactoring, so i will follow up in a separate PR. --------- Signed-off-by: Takeshi Yoneda --- .github/workflows/tests.yaml | 2 +- .testcoverage.yml | 21 ++ Makefile | 8 + Makefile.tools.mk | 8 +- filterapi/filterconfig_test.go | 11 + .../extensionserver/extension_server_test.go | 30 +++ internal/extproc/backendauth/auth_test.go | 47 ++++ internal/extproc/processor_test.go | 3 +- .../translator/openai_awsbedrock_test.go | 5 +- internal/extproc/translator/openai_openai.go | 6 +- internal/extproc/translator/translator.go | 20 +- internal/llmcostcel/cel.go | 19 +- internal/llmcostcel/cel_test.go | 37 +++ .../testupstreamlib/testupstream/main.go | 101 ++++---- .../testupstreamlib/testupstream/main_test.go | 243 ++++++++++++++++++ 15 files changed, 474 insertions(+), 87 deletions(-) create mode 100644 .testcoverage.yml create mode 100644 internal/extensionserver/extension_server_test.go create mode 100644 internal/extproc/backendauth/auth_test.go diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 1ebf8db9d..803699879 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -48,7 +48,7 @@ jobs: ~/go/bin key: unittest-${{ hashFiles('**/go.mod', '**/go.sum', '**/Makefile') }} - name: Run unit tests - run: make test + run: make test-coverage test_cel_validation: if: github.event_name == 'pull_request' || github.event_name == 'push' diff --git a/.testcoverage.yml b/.testcoverage.yml new file mode 100644 index 000000000..a8bfcaa16 --- /dev/null +++ b/.testcoverage.yml @@ -0,0 +1,21 @@ +# This is the configuration file for /~https://github.com/vladopajic/go-test-coverage + +profile: ./out/go-test-coverage.out +local-prefix: "github.com/envoyproxy/ai-gateway/" + +threshold: + file: 70 + # TODO: increase to 90. + package: 80 + # TODO: increase to 95. + total: 83 + +exclude: + paths: + - ^api/ + - ^examples/ + - apischema/ + - cmd/ + - tests/internal/envtest.go + # TODO: Remove this exclusion. + - internal/controller diff --git a/Makefile b/Makefile index 843253537..b9d8e6fe5 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,7 @@ help: @echo "All core targets needed for contributing:" @echo " precommit Run all necessary steps to prepare for a commit." @echo " test Run the unit tests for the codebase." + @echo " test-coverage Run the unit tests for the codebase with coverage check." @echo " test-cel Run the integration tests of CEL validation rules in API definitions with envtest." @echo " This will be needed when changing API definitions." @echo " test-extproc Run the integration tests for extproc without controller or k8s at all." @@ -162,6 +163,13 @@ test-e2e: kind @echo "Run E2E tests" @go test ./tests/e2e/... $(GO_TEST_ARGS) $(GO_TEST_E2E_ARGS) -tags test_e2e +# This runs the unit tests for the codebase with coverage check. +.PHONY: test-coverage +test-coverage: go-test-coverage + @mkdir -p $(OUTPUT_DIR) + @$(MAKE) test GO_TEST_ARGS="-coverprofile=$(OUTPUT_DIR)/go-test-coverage.out -covermode=atomic -coverpkg=./... $(GO_TEST_ARGS)" + @${GO_TEST_COVERAGE} --config=.testcoverage.yml + # This builds a binary for the given command under the internal/cmd directory. # # Example: diff --git a/Makefile.tools.mk b/Makefile.tools.mk index e6f239642..e0744432a 100644 --- a/Makefile.tools.mk +++ b/Makefile.tools.mk @@ -13,6 +13,7 @@ CODESPELL = $(LOCALBIN)/.venv/codespell@v2.3.0/bin/codespell YAMLLINT = $(LOCALBIN)/.venv/yamllint@1.35.1/bin/yamllint KIND ?= $(LOCALBIN)/kind CRD_REF_DOCS = $(LOCALBIN)/crd-ref-docs +GO_TEST_COVERAGE ?= $(LOCALBIN)/go-test-coverage ## Tool versions. CONTROLLER_TOOLS_VERSION ?= v0.17.1 @@ -23,6 +24,7 @@ GCI_VERSION ?= v0.13.5 EDITORCONFIG_CHECKER_VERSION ?= v3.1.2 KIND_VERSION ?= v0.26.0 CRD_REF_DOCS_VERSION ?= v0.1.0 +GO_TEST_COVERAGE_VERSION ?= v2.11.4 .PHONY: golangci-lint golangci-lint: $(GOLANGCI_LINT) @@ -59,11 +61,15 @@ kind: $(KIND) $(KIND): $(LOCALBIN) $(call go-install-tool,$(KIND),sigs.k8s.io/kind,$(KIND_VERSION)) -.PHONY: +.PHONY: crd-ref-docs crd-ref-docs: $(CRD_REF_DOCS) $(CRD_REF_DOCS): $(LOCALBIN) $(call go-install-tool,$(CRD_REF_DOCS),github.com/elastic/crd-ref-docs,$(CRD_REF_DOCS_VERSION)) +.PHONY: go-test-coverage +go-test-coverage: $(GO_TEST_COVERAGE) +$(GO_TEST_COVERAGE): $(LOCALBIN) + $(call go-install-tool,$(GO_TEST_COVERAGE),github.com/vladopajic/go-test-coverage/v2,$(GO_TEST_COVERAGE_VERSION)) .bin/.venv/%: mkdir -p $(@D) diff --git a/filterapi/filterconfig_test.go b/filterapi/filterconfig_test.go index aca35fe17..89913c91f 100644 --- a/filterapi/filterconfig_test.go +++ b/filterapi/filterconfig_test.go @@ -85,4 +85,15 @@ rules: require.Equal(t, "apikey.txt", cfg.Rules[0].Backends[0].Auth.APIKey.Filename) require.Equal(t, "aws.txt", cfg.Rules[0].Backends[1].Auth.AWSAuth.CredentialFileName) require.Equal(t, "us-east-1", cfg.Rules[0].Backends[1].Auth.AWSAuth.Region) + + t.Run("not found", func(t *testing.T) { + _, err := filterapi.UnmarshalConfigYaml("not-found.yaml") + require.Error(t, err) + }) + t.Run("invalid", func(t *testing.T) { + const invalidConfig = `{wefaf3q20,9u,f02` + require.NoError(t, os.WriteFile(configPath, []byte(invalidConfig), 0o600)) + _, err := filterapi.UnmarshalConfigYaml(configPath) + require.Error(t, err) + }) } diff --git a/internal/extensionserver/extension_server_test.go b/internal/extensionserver/extension_server_test.go new file mode 100644 index 000000000..8066bb8bb --- /dev/null +++ b/internal/extensionserver/extension_server_test.go @@ -0,0 +1,30 @@ +package extensionserver + +import ( + "context" + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + logger := logr.Discard() + s := New(logger) + require.NotNil(t, s) +} + +func TestCheck(t *testing.T) { + logger := logr.Discard() + s := New(logger) + _, err := s.Check(context.Background(), nil) + require.NoError(t, err) +} + +func TestWatch(t *testing.T) { + logger := logr.Discard() + s := New(logger) + err := s.Watch(nil, nil) + require.Error(t, err) + require.Equal(t, "rpc error: code = Unimplemented desc = Watch is not implemented", err.Error()) +} diff --git a/internal/extproc/backendauth/auth_test.go b/internal/extproc/backendauth/auth_test.go new file mode 100644 index 000000000..90d0ab977 --- /dev/null +++ b/internal/extproc/backendauth/auth_test.go @@ -0,0 +1,47 @@ +package backendauth + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/filterapi" +) + +func TestNewHandler(t *testing.T) { + awsFile := t.TempDir() + "/awstest" + err := os.WriteFile(awsFile, []byte(` +[default] +aws_access_key_id = test +aws_secret_access_key = test +`), 0o600) + require.NoError(t, err) + + apiKeyFile := t.TempDir() + "/apikey" + err = os.WriteFile(apiKeyFile, []byte("TEST"), 0o600) + require.NoError(t, err) + + for _, tt := range []struct { + name string + config *filterapi.BackendAuth + }{ + { + name: "AWSAuth", + config: &filterapi.BackendAuth{AWSAuth: &filterapi.AWSAuth{ + Region: "us-west-2", CredentialFileName: awsFile, + }}, + }, + { + name: "APIKey", + config: &filterapi.BackendAuth{ + APIKey: &filterapi.APIKeyAuth{Filename: apiKeyFile}, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + _, err := NewHandler(tt.config) + require.NoError(t, err) + }) + } +} diff --git a/internal/extproc/processor_test.go b/internal/extproc/processor_test.go index 862f84751..23ece77e2 100644 --- a/internal/extproc/processor_test.go +++ b/internal/extproc/processor_test.go @@ -3,6 +3,7 @@ package extproc import ( "context" "errors" + "io" "log/slog" "testing" @@ -71,7 +72,7 @@ func TestProcessor_ProcessResponseBody(t *testing.T) { require.NoError(t, err) celProgUint, err := llmcostcel.NewProgram("uint(9999)") require.NoError(t, err) - p := &Processor{translator: mt, logger: slog.Default(), config: &processorConfig{ + p := &Processor{translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), config: &processorConfig{ metadataNamespace: "ai_gateway_llm_ns", requestCosts: []processorConfigRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}}, diff --git a/internal/extproc/translator/openai_awsbedrock_test.go b/internal/extproc/translator/openai_awsbedrock_test.go index 9af959867..22bb23268 100644 --- a/internal/extproc/translator/openai_awsbedrock_test.go +++ b/internal/extproc/translator/openai_awsbedrock_test.go @@ -846,9 +846,8 @@ func TestOpenAIToAWSBedrockTranslator_ResponseError(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - body, err := json.Marshal(tt.input) + _, err := json.Marshal(tt.input) require.NoError(t, err) - fmt.Println(string(body)) o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} hm, bm, err := o.ResponseError(tt.responseHeaders, tt.input) @@ -1026,7 +1025,6 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) t.Run(tt.name, func(t *testing.T) { body, err := json.Marshal(tt.input) require.NoError(t, err) - fmt.Println(string(body)) o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} hm, bm, usedToken, err := o.ResponseBody(nil, bytes.NewBuffer(body), false) @@ -1157,7 +1155,6 @@ func TestOpenAIToAWSBedrockTranslatorExtractAmazonEventStreamEvents(t *testing.T var texts []string var usage *awsbedrock.TokenUsage for _, event := range o.events { - t.Log(event.String()) if delta := event.Delta; delta != nil && delta.Text != nil && *delta.Text != "" { texts = append(texts, *event.Delta.Text) } diff --git a/internal/extproc/translator/openai_openai.go b/internal/extproc/translator/openai_openai.go index ce16d55ba..14b29a4b9 100644 --- a/internal/extproc/translator/openai_openai.go +++ b/internal/extproc/translator/openai_openai.go @@ -25,7 +25,6 @@ func newOpenAIToOpenAITranslator(path string) (Translator, error) { // openAIToOpenAITranslatorV1ChatCompletion implements [Translator] for /v1/chat/completions. type openAIToOpenAITranslatorV1ChatCompletion struct { - defaultTranslator stream bool buffered []byte bufferingDone bool @@ -83,6 +82,11 @@ func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseError(respHeaders map return nil, nil, nil } +// ResponseHeaders implements [Translator.ResponseHeaders]. +func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseHeaders(map[string]string) (headerMutation *extprocv3.HeaderMutation, err error) { + return nil, nil +} + // ResponseBody implements [Translator.ResponseBody]. func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, _ bool) ( headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, err error, diff --git a/internal/extproc/translator/translator.go b/internal/extproc/translator/translator.go index aa2b1e90e..fb041da30 100644 --- a/internal/extproc/translator/translator.go +++ b/internal/extproc/translator/translator.go @@ -77,7 +77,7 @@ type Translator interface { // ResponseBody translates the response body. // - `body` is the response body either chunk or the entire body, depending on the context. // - This returns `headerMutation` and `bodyMutation` that can be nil to indicate no mutation. - // - This returns `usedToken` that is extracted from the body and will be used to do token rate limiting. + // - This returns `tokenUsage` that is extracted from the body and will be used to do token rate limiting. // // TODO: this is coupled with "LLM" specific. Once we have another use case, we need to refactor this. ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) ( @@ -98,24 +98,6 @@ type Translator interface { ) } -// defaultTranslator is a no-op translator that implements [Translator]. -type defaultTranslator struct{} - -// RequestBody implements [Translator.RequestBody]. -func (d *defaultTranslator) RequestBody(*extprocv3.HttpBody) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, *extprocv3http.ProcessingMode, string, error) { - return nil, nil, nil, "", nil -} - -// ResponseHeaders implements [Translator.ResponseBody]. -func (d *defaultTranslator) ResponseHeaders(map[string]string) (*extprocv3.HeaderMutation, error) { - return nil, nil -} - -// ResponseBody implements [Translator.ResponseBody]. -func (d *defaultTranslator) ResponseBody(io.Reader, bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, uint32, error) { - return nil, nil, 0, nil -} - func setContentLength(headers *extprocv3.HeaderMutation, body []byte) { headers.SetHeaders = append(headers.SetHeaders, &corev3.HeaderValueOption{ Header: &corev3.HeaderValue{ diff --git a/internal/llmcostcel/cel.go b/internal/llmcostcel/cel.go index 8c7331343..bd67d3f60 100644 --- a/internal/llmcostcel/cel.go +++ b/internal/llmcostcel/cel.go @@ -18,9 +18,11 @@ const ( celTotalTokensKey = "total_tokens" ) -// NewProgram creates a new CEL program from the given expression. -func NewProgram(expr string) (prog cel.Program, err error) { - env, err := cel.NewEnv( +var env *cel.Env + +func init() { + var err error + env, err = cel.NewEnv( cel.Variable(celModelNameKey, cel.StringType), cel.Variable(celBackendKey, cel.StringType), cel.Variable(celInputTokensKey, cel.UintType), @@ -28,8 +30,12 @@ func NewProgram(expr string) (prog cel.Program, err error) { cel.Variable(celTotalTokensKey, cel.UintType), ) if err != nil { - return nil, fmt.Errorf("cannot create CEL environment: %w", err) + panic(fmt.Sprintf("cannot create CEL environment: %v", err)) } +} + +// NewProgram creates a new CEL program from the given expression. +func NewProgram(expr string) (prog cel.Program, err error) { ast, issues := env.Compile(expr) if issues != nil && issues.Err() != nil { err := issues.Err() @@ -57,12 +63,9 @@ func EvaluateProgram(prog cel.Program, modelName, backend string, inputTokens, o celOutputTokensKey: outputTokens, celTotalTokensKey: totalTokens, }) - if err != nil { + if err != nil || out == nil { return 0, fmt.Errorf("failed to evaluate CEL expression: %w", err) } - if out == nil { - return 0, fmt.Errorf("CEL expression result is nil") - } switch out.Type() { case cel.IntType: diff --git a/internal/llmcostcel/cel_test.go b/internal/llmcostcel/cel_test.go index 1e30da1b8..9cae420ba 100644 --- a/internal/llmcostcel/cel_test.go +++ b/internal/llmcostcel/cel_test.go @@ -8,6 +8,10 @@ import ( ) func TestNewProgram(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + _, err := NewProgram("1 +") + require.Error(t, err) + }) t.Run("int", func(t *testing.T) { _, err := NewProgram("1 + 1") require.NoError(t, err) @@ -28,6 +32,39 @@ func TestNewProgram(t *testing.T) { require.Equal(t, uint64(3), v) }) + t.Run("uint", func(t *testing.T) { + _, err := NewProgram("uint(1)-uint(1200)") + require.ErrorContains(t, err, "failed to evaluate CEL expression: failed to evaluate CEL expression: unsigned integer overflow") + }) + + t.Run("ensure concurrency safety", func(t *testing.T) { + // Ensure that the program can be evaluated concurrently. + var wg sync.WaitGroup + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + _, err := NewProgram("model == 'cool_model' ? input_tokens * output_tokens : total_tokens") + require.NoError(t, err) + }() + } + wg.Wait() + }) +} + +func TestEvaluateProgram(t *testing.T) { + t.Run("signed integer negative", func(t *testing.T) { + prog, err := NewProgram("int(input_tokens) - int(output_tokens)") + require.NoError(t, err) + _, err = EvaluateProgram(prog, "cool_model", "cool_backend", 100, 2000, 3) + require.ErrorContains(t, err, "CEL expression result is negative (-1900)") + }) + t.Run("unsigned integer overflow", func(t *testing.T) { + prog, err := NewProgram("input_tokens - output_tokens") + require.NoError(t, err) + _, err = EvaluateProgram(prog, "cool_model", "cool_backend", 100, 2000, 3) + require.ErrorContains(t, err, "failed to evaluate CEL expression: unsigned integer overflow") + }) t.Run("ensure concurrency safety", func(t *testing.T) { prog, err := NewProgram("model == 'cool_model' ? input_tokens * output_tokens : total_tokens") require.NoError(t, err) diff --git a/tests/internal/testupstreamlib/testupstream/main.go b/tests/internal/testupstreamlib/testupstream/main.go index 6457061f8..6301767d1 100644 --- a/tests/internal/testupstreamlib/testupstream/main.go +++ b/tests/internal/testupstreamlib/testupstream/main.go @@ -60,13 +60,13 @@ func handler(w http.ResponseWriter, r *http.Request) { } if v := r.Header.Get(testupstreamlib.ExpectedHostKey); v != "" { if r.Host != v { - fmt.Printf("unexpected host: got %q, expected %q\n", r.Host, v) + logger.Printf("unexpected host: got %q, expected %q\n", r.Host, v) http.Error(w, "unexpected host: got "+r.Host+", expected "+v, http.StatusBadRequest) return } - fmt.Println("host matched:", v) + logger.Println("host matched:", v) } else { - fmt.Println("no expected host: got", r.Host) + logger.Println("no expected host: got", r.Host) } if v := r.Header.Get(testupstreamlib.ExpectedHeadersKey); v != "" { expectedHeaders, err := base64.StdEncoding.DecodeString(v) @@ -212,7 +212,6 @@ func handler(w http.ResponseWriter, r *http.Request) { switch r.Header.Get(testupstreamlib.ResponseTypeKey) { case "sse": w.Header().Set("Content-Type", "text/event-stream") - w.WriteHeader(status) expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(testupstreamlib.ResponseBodyHeaderKey)) if err != nil { @@ -221,6 +220,7 @@ func handler(w http.ResponseWriter, r *http.Request) { return } + w.WriteHeader(status) for _, line := range bytes.Split(expResponseBody, []byte("\n")) { line := string(line) if line == "" { @@ -244,7 +244,6 @@ func handler(w http.ResponseWriter, r *http.Request) { r.Context().Done() case "aws-event-stream": w.Header().Set("Content-Type", "application/vnd.amazon.eventstream") - w.WriteHeader(status) expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(testupstreamlib.ResponseBodyHeaderKey)) if err != nil { @@ -253,6 +252,7 @@ func handler(w http.ResponseWriter, r *http.Request) { return } + w.WriteHeader(status) e := eventstream.NewEncoder() for _, line := range bytes.Split(expResponseBody, []byte("\n")) { // Write each line as a chunk with AWS Event Stream format. @@ -281,7 +281,6 @@ func handler(w http.ResponseWriter, r *http.Request) { r.Context().Done() default: w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) var responseBody []byte if expResponseBody := r.Header.Get(testupstreamlib.ResponseBodyHeaderKey); expResponseBody == "" { @@ -301,61 +300,59 @@ func handler(w http.ResponseWriter, r *http.Request) { } } - if _, err := w.Write(responseBody); err != nil { - logger.Println("failed to write the response body") - } + w.WriteHeader(status) + _, _ = w.Write(responseBody) logger.Println("response sent:", string(responseBody)) } } -var ( - r = rand.New(rand.NewSource(uint64(time.Now().UnixNano()))) - chatCompletionFakeResponses = []string{ - `This is a test.`, - `The quick brown fox jumps over the lazy dog.`, - `Lorem ipsum dolor sit amet, consectetur adipiscing elit.`, - `To be or not to be, that is the question.`, - `All your base are belong to us.`, - `I am the bone of my sword.`, - `I am the master of my fate.`, - `I am the captain of my soul.`, - `I am the master of my fate, I am the captain of my soul.`, - `I am the bone of my sword, steel is my body, and fire is my blood.`, - `The quick brown fox jumps over the lazy dog.`, - `Lorem ipsum dolor sit amet, consectetur adipiscing elit.`, - `To be or not to be, that is the question.`, - `All your base are belong to us.`, - `Omae wa mou shindeiru.`, - `Nani?`, - `I am inevitable.`, - `May the Force be with you.`, - `Houston, we have a problem.`, - `I'll be back.`, - `You can't handle the truth!`, - `Here's looking at you, kid.`, - `Go ahead, make my day.`, - `I see dead people.`, - `Hasta la vista, baby.`, - `You're gonna need a bigger boat.`, - `E.T. phone home.`, - `I feel the need - the need for speed.`, - `I'm king of the world!`, - `Show me the money!`, - `You had me at hello.`, - `I'm the king of the world!`, - `To infinity and beyond!`, - `You're a wizard, Harry.`, - `I solemnly swear that I am up to no good.`, - `Mischief managed.`, - `Expecto Patronum!`, - } -) +var chatCompletionFakeResponses = []string{ + `This is a test.`, + `The quick brown fox jumps over the lazy dog.`, + `Lorem ipsum dolor sit amet, consectetur adipiscing elit.`, + `To be or not to be, that is the question.`, + `All your base are belong to us.`, + `I am the bone of my sword.`, + `I am the master of my fate.`, + `I am the captain of my soul.`, + `I am the master of my fate, I am the captain of my soul.`, + `I am the bone of my sword, steel is my body, and fire is my blood.`, + `The quick brown fox jumps over the lazy dog.`, + `Lorem ipsum dolor sit amet, consectetur adipiscing elit.`, + `To be or not to be, that is the question.`, + `All your base are belong to us.`, + `Omae wa mou shindeiru.`, + `Nani?`, + `I am inevitable.`, + `May the Force be with you.`, + `Houston, we have a problem.`, + `I'll be back.`, + `You can't handle the truth!`, + `Here's looking at you, kid.`, + `Go ahead, make my day.`, + `I see dead people.`, + `Hasta la vista, baby.`, + `You're gonna need a bigger boat.`, + `E.T. phone home.`, + `I feel the need - the need for speed.`, + `I'm king of the world!`, + `Show me the money!`, + `You had me at hello.`, + `I'm the king of the world!`, + `To infinity and beyond!`, + `You're a wizard, Harry.`, + `I solemnly swear that I am up to no good.`, + `Mischief managed.`, + `Expecto Patronum!`, +} func getFakeResponse(path string) ([]byte, error) { switch path { case "/v1/chat/completions": const template = `{"choices":[{"message":{"content":"%s"}}]}` - msg := fmt.Sprintf(template, chatCompletionFakeResponses[r.Intn(len(chatCompletionFakeResponses))]) + msg := fmt.Sprintf(template, + chatCompletionFakeResponses[rand.New(rand.NewSource(uint64(time.Now().UnixNano()))). //nolint:gosec + Intn(len(chatCompletionFakeResponses))]) return []byte(msg), nil default: return nil, fmt.Errorf("unknown path: %s", path) diff --git a/tests/internal/testupstreamlib/testupstream/main_test.go b/tests/internal/testupstreamlib/testupstream/main_test.go index 57c13734f..2ed58eea4 100644 --- a/tests/internal/testupstreamlib/testupstream/main_test.go +++ b/tests/internal/testupstreamlib/testupstream/main_test.go @@ -6,8 +6,10 @@ import ( "encoding/base64" "fmt" "io" + "log" "net" "net/http" + "os" "strings" "testing" "time" @@ -19,6 +21,11 @@ import ( "github.com/envoyproxy/ai-gateway/tests/internal/testupstreamlib" ) +func TestMain(m *testing.M) { + logger = log.New(io.Discard, "", 0) + os.Exit(m.Run()) +} + func Test_main(t *testing.T) { t.Setenv("TESTUPSTREAM_ID", "aaaaaaaaa") t.Setenv("STREAMING_INTERVAL", "200ms") @@ -179,6 +186,31 @@ func Test_main(t *testing.T) { require.Equal(t, "aaaaaaaaa", response.Header.Get("testupstream-id")) }) + t.Run("invalid response body", func(t *testing.T) { + for _, eventType := range []string{"sse", "aws-event-stream"} { + t.Run(eventType, func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/v1/chat/completions", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + request.Header.Set(testupstreamlib.ResponseTypeKey, eventType) + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/v1/chat/completions"))) + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ResponseBodyHeaderKey, "09i,30qg9i4,gq03,gq0") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + } + }) + t.Run("fake response", func(t *testing.T) { t.Parallel() request, err := http.NewRequest("GET", @@ -207,6 +239,24 @@ func Test_main(t *testing.T) { require.Contains(t, chatCompletionFakeResponses, chat.Choices[0].Message.Content) }) + t.Run("fake response for unknown path", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/foo", nil) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/foo"))) + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + t.Run("aws-event-stream", func(t *testing.T) { t.Parallel() request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", nil) @@ -245,4 +295,197 @@ func Test_main(t *testing.T) { _, err = decoder.Decode(response.Body, nil) require.Equal(t, io.EOF, err) }) + + t.Run("expected host not match", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/"))) + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedHostKey, + base64.StdEncoding.EncodeToString([]byte("example.com"))) + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("expected host match", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/v1/chat/completions", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Host = "localhost" + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedHostKey, "localhost") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusOK, response.StatusCode) + }) + + t.Run("expected headers invalid encoding", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/"))) + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedHeadersKey, "fewoamfwoajfum092um3f") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("expected headers invalid pairs", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/"))) + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedHeadersKey, + base64.StdEncoding.EncodeToString([]byte("x-baz"))) // Missing value. + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("expected headers not match", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/"))) + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedHeadersKey, + base64.StdEncoding.EncodeToString([]byte("x-foo:bar,x-baz:qux"))) + + request.Header.Set("x-foo", "not-bar") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("non expected headers invalid encoding", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, + base64.StdEncoding.EncodeToString([]byte("/"))) + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.NonExpectedRequestHeadersKey, "fewoamfwoajfum092um3f") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("expected test upstream id", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/v1/chat/completions", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedTestUpstreamIDKey, "aaaaaaaaa") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusOK, response.StatusCode) + }) + + t.Run("expected test upstream id not match", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/v1/chat/completions", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, + base64.StdEncoding.EncodeToString([]byte("expected request body"))) + request.Header.Set(testupstreamlib.ExpectedTestUpstreamIDKey, "bbbbbbbbb") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("expected path invalid encoding", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedPathHeaderKey, "fewoamfwoajfum092um3f") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) + + t.Run("expected request body invalid encoding", func(t *testing.T) { + t.Parallel() + request, err := http.NewRequest("GET", + "http://"+l.Addr().String()+"/", bytes.NewBuffer([]byte("expected request body"))) + require.NoError(t, err) + + request.Header.Set(testupstreamlib.ExpectedRequestBodyHeaderKey, "fewoamfwoajfum092um3f") + + response, err := http.DefaultClient.Do(request) + require.NoError(t, err) + defer func() { + _ = response.Body.Close() + }() + require.Equal(t, http.StatusBadRequest, response.StatusCode) + }) }