Skip to content

Commit

Permalink
test: adds unit test coverage checks (#249)
Browse files Browse the repository at this point in the history
**Commit Message**:

This introduces the unit test coverage checks
in GHA. Note that this doesn't utilize the integration
tests (test-extproc,test-controller and test-e2e), so
the coverage will be checked solely on the unit tests
without relying on the external binaries etc.

One left TODO is to remove the exclusion of controller
package which requires a bit of refactoring, so i will
follow up in a separate PR.

---------

Signed-off-by: Takeshi Yoneda <t.y.mathetake@gmail.com>
  • Loading branch information
mathetake authored Jan 30, 2025
1 parent dd43bba commit 8776563
Show file tree
Hide file tree
Showing 15 changed files with 474 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
~/go/bin
key: unittest-${{ hashFiles('**/go.mod', '**/go.sum', '**/Makefile') }}
- name: Run unit tests
run: make test
run: make test-coverage

test_cel_validation:
if: github.event_name == 'pull_request' || github.event_name == 'push'
Expand Down
21 changes: 21 additions & 0 deletions .testcoverage.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# This is the configuration file for /~https://github.com/vladopajic/go-test-coverage

profile: ./out/go-test-coverage.out
local-prefix: "github.com/envoyproxy/ai-gateway/"

threshold:
file: 70
# TODO: increase to 90.
package: 80
# TODO: increase to 95.
total: 83

exclude:
paths:
- ^api/
- ^examples/
- apischema/
- cmd/
- tests/internal/envtest.go
# TODO: Remove this exclusion.
- internal/controller
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ help:
@echo "All core targets needed for contributing:"
@echo " precommit Run all necessary steps to prepare for a commit."
@echo " test Run the unit tests for the codebase."
@echo " test-coverage Run the unit tests for the codebase with coverage check."
@echo " test-cel Run the integration tests of CEL validation rules in API definitions with envtest."
@echo " This will be needed when changing API definitions."
@echo " test-extproc Run the integration tests for extproc without controller or k8s at all."
Expand Down Expand Up @@ -162,6 +163,13 @@ test-e2e: kind
@echo "Run E2E tests"
@go test ./tests/e2e/... $(GO_TEST_ARGS) $(GO_TEST_E2E_ARGS) -tags test_e2e

# This runs the unit tests for the codebase with coverage check.
.PHONY: test-coverage
test-coverage: go-test-coverage
@mkdir -p $(OUTPUT_DIR)
@$(MAKE) test GO_TEST_ARGS="-coverprofile=$(OUTPUT_DIR)/go-test-coverage.out -covermode=atomic -coverpkg=./... $(GO_TEST_ARGS)"
@${GO_TEST_COVERAGE} --config=.testcoverage.yml

# This builds a binary for the given command under the internal/cmd directory.
#
# Example:
Expand Down
8 changes: 7 additions & 1 deletion Makefile.tools.mk
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CODESPELL = $(LOCALBIN)/.venv/codespell@v2.3.0/bin/codespell
YAMLLINT = $(LOCALBIN)/.venv/yamllint@1.35.1/bin/yamllint
KIND ?= $(LOCALBIN)/kind
CRD_REF_DOCS = $(LOCALBIN)/crd-ref-docs
GO_TEST_COVERAGE ?= $(LOCALBIN)/go-test-coverage

## Tool versions.
CONTROLLER_TOOLS_VERSION ?= v0.17.1
Expand All @@ -23,6 +24,7 @@ GCI_VERSION ?= v0.13.5
EDITORCONFIG_CHECKER_VERSION ?= v3.1.2
KIND_VERSION ?= v0.26.0
CRD_REF_DOCS_VERSION ?= v0.1.0
GO_TEST_COVERAGE_VERSION ?= v2.11.4

.PHONY: golangci-lint
golangci-lint: $(GOLANGCI_LINT)
Expand Down Expand Up @@ -59,11 +61,15 @@ kind: $(KIND)
$(KIND): $(LOCALBIN)
$(call go-install-tool,$(KIND),sigs.k8s.io/kind,$(KIND_VERSION))

.PHONY:
.PHONY: crd-ref-docs
crd-ref-docs: $(CRD_REF_DOCS)
$(CRD_REF_DOCS): $(LOCALBIN)
$(call go-install-tool,$(CRD_REF_DOCS),github.com/elastic/crd-ref-docs,$(CRD_REF_DOCS_VERSION))

.PHONY: go-test-coverage
go-test-coverage: $(GO_TEST_COVERAGE)
$(GO_TEST_COVERAGE): $(LOCALBIN)
$(call go-install-tool,$(GO_TEST_COVERAGE),github.com/vladopajic/go-test-coverage/v2,$(GO_TEST_COVERAGE_VERSION))

.bin/.venv/%:
mkdir -p $(@D)
Expand Down
11 changes: 11 additions & 0 deletions filterapi/filterconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,15 @@ rules:
require.Equal(t, "apikey.txt", cfg.Rules[0].Backends[0].Auth.APIKey.Filename)
require.Equal(t, "aws.txt", cfg.Rules[0].Backends[1].Auth.AWSAuth.CredentialFileName)
require.Equal(t, "us-east-1", cfg.Rules[0].Backends[1].Auth.AWSAuth.Region)

t.Run("not found", func(t *testing.T) {
_, err := filterapi.UnmarshalConfigYaml("not-found.yaml")
require.Error(t, err)
})
t.Run("invalid", func(t *testing.T) {
const invalidConfig = `{wefaf3q20,9u,f02`
require.NoError(t, os.WriteFile(configPath, []byte(invalidConfig), 0o600))
_, err := filterapi.UnmarshalConfigYaml(configPath)
require.Error(t, err)
})
}
30 changes: 30 additions & 0 deletions internal/extensionserver/extension_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package extensionserver

import (
"context"
"testing"

"github.com/go-logr/logr"
"github.com/stretchr/testify/require"
)

func TestNew(t *testing.T) {
logger := logr.Discard()
s := New(logger)
require.NotNil(t, s)
}

func TestCheck(t *testing.T) {
logger := logr.Discard()
s := New(logger)
_, err := s.Check(context.Background(), nil)
require.NoError(t, err)
}

func TestWatch(t *testing.T) {
logger := logr.Discard()
s := New(logger)
err := s.Watch(nil, nil)
require.Error(t, err)
require.Equal(t, "rpc error: code = Unimplemented desc = Watch is not implemented", err.Error())
}
47 changes: 47 additions & 0 deletions internal/extproc/backendauth/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package backendauth

import (
"os"
"testing"

"github.com/stretchr/testify/require"

"github.com/envoyproxy/ai-gateway/filterapi"
)

func TestNewHandler(t *testing.T) {
awsFile := t.TempDir() + "/awstest"
err := os.WriteFile(awsFile, []byte(`
[default]
aws_access_key_id = test
aws_secret_access_key = test
`), 0o600)
require.NoError(t, err)

apiKeyFile := t.TempDir() + "/apikey"
err = os.WriteFile(apiKeyFile, []byte("TEST"), 0o600)
require.NoError(t, err)

for _, tt := range []struct {
name string
config *filterapi.BackendAuth
}{
{
name: "AWSAuth",
config: &filterapi.BackendAuth{AWSAuth: &filterapi.AWSAuth{
Region: "us-west-2", CredentialFileName: awsFile,
}},
},
{
name: "APIKey",
config: &filterapi.BackendAuth{
APIKey: &filterapi.APIKeyAuth{Filename: apiKeyFile},
},
},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := NewHandler(tt.config)
require.NoError(t, err)
})
}
}
3 changes: 2 additions & 1 deletion internal/extproc/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package extproc
import (
"context"
"errors"
"io"
"log/slog"
"testing"

Expand Down Expand Up @@ -71,7 +72,7 @@ func TestProcessor_ProcessResponseBody(t *testing.T) {
require.NoError(t, err)
celProgUint, err := llmcostcel.NewProgram("uint(9999)")
require.NoError(t, err)
p := &Processor{translator: mt, logger: slog.Default(), config: &processorConfig{
p := &Processor{translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), config: &processorConfig{
metadataNamespace: "ai_gateway_llm_ns",
requestCosts: []processorConfigRequestCost{
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
Expand Down
5 changes: 1 addition & 4 deletions internal/extproc/translator/openai_awsbedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,9 +846,8 @@ func TestOpenAIToAWSBedrockTranslator_ResponseError(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body, err := json.Marshal(tt.input)
_, err := json.Marshal(tt.input)
require.NoError(t, err)
fmt.Println(string(body))

o := &openAIToAWSBedrockTranslatorV1ChatCompletion{}
hm, bm, err := o.ResponseError(tt.responseHeaders, tt.input)
Expand Down Expand Up @@ -1026,7 +1025,6 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T)
t.Run(tt.name, func(t *testing.T) {
body, err := json.Marshal(tt.input)
require.NoError(t, err)
fmt.Println(string(body))

o := &openAIToAWSBedrockTranslatorV1ChatCompletion{}
hm, bm, usedToken, err := o.ResponseBody(nil, bytes.NewBuffer(body), false)
Expand Down Expand Up @@ -1157,7 +1155,6 @@ func TestOpenAIToAWSBedrockTranslatorExtractAmazonEventStreamEvents(t *testing.T
var texts []string
var usage *awsbedrock.TokenUsage
for _, event := range o.events {
t.Log(event.String())
if delta := event.Delta; delta != nil && delta.Text != nil && *delta.Text != "" {
texts = append(texts, *event.Delta.Text)
}
Expand Down
6 changes: 5 additions & 1 deletion internal/extproc/translator/openai_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ func newOpenAIToOpenAITranslator(path string) (Translator, error) {

// openAIToOpenAITranslatorV1ChatCompletion implements [Translator] for /v1/chat/completions.
type openAIToOpenAITranslatorV1ChatCompletion struct {
defaultTranslator
stream bool
buffered []byte
bufferingDone bool
Expand Down Expand Up @@ -83,6 +82,11 @@ func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseError(respHeaders map
return nil, nil, nil
}

// ResponseHeaders implements [Translator.ResponseHeaders].
func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseHeaders(map[string]string) (headerMutation *extprocv3.HeaderMutation, err error) {
return nil, nil
}

// ResponseBody implements [Translator.ResponseBody].
func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[string]string, body io.Reader, _ bool) (
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, tokenUsage LLMTokenUsage, err error,
Expand Down
20 changes: 1 addition & 19 deletions internal/extproc/translator/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ type Translator interface {
// ResponseBody translates the response body.
// - `body` is the response body either chunk or the entire body, depending on the context.
// - This returns `headerMutation` and `bodyMutation` that can be nil to indicate no mutation.
// - This returns `usedToken` that is extracted from the body and will be used to do token rate limiting.
// - This returns `tokenUsage` that is extracted from the body and will be used to do token rate limiting.
//
// TODO: this is coupled with "LLM" specific. Once we have another use case, we need to refactor this.
ResponseBody(respHeaders map[string]string, body io.Reader, endOfStream bool) (
Expand All @@ -98,24 +98,6 @@ type Translator interface {
)
}

// defaultTranslator is a no-op translator that implements [Translator].
type defaultTranslator struct{}

// RequestBody implements [Translator.RequestBody].
func (d *defaultTranslator) RequestBody(*extprocv3.HttpBody) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, *extprocv3http.ProcessingMode, string, error) {
return nil, nil, nil, "", nil
}

// ResponseHeaders implements [Translator.ResponseBody].
func (d *defaultTranslator) ResponseHeaders(map[string]string) (*extprocv3.HeaderMutation, error) {
return nil, nil
}

// ResponseBody implements [Translator.ResponseBody].
func (d *defaultTranslator) ResponseBody(io.Reader, bool) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, uint32, error) {
return nil, nil, 0, nil
}

func setContentLength(headers *extprocv3.HeaderMutation, body []byte) {
headers.SetHeaders = append(headers.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{
Expand Down
19 changes: 11 additions & 8 deletions internal/llmcostcel/cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,24 @@ const (
celTotalTokensKey = "total_tokens"
)

// NewProgram creates a new CEL program from the given expression.
func NewProgram(expr string) (prog cel.Program, err error) {
env, err := cel.NewEnv(
var env *cel.Env

func init() {
var err error
env, err = cel.NewEnv(
cel.Variable(celModelNameKey, cel.StringType),
cel.Variable(celBackendKey, cel.StringType),
cel.Variable(celInputTokensKey, cel.UintType),
cel.Variable(celOutputTokensKey, cel.UintType),
cel.Variable(celTotalTokensKey, cel.UintType),
)
if err != nil {
return nil, fmt.Errorf("cannot create CEL environment: %w", err)
panic(fmt.Sprintf("cannot create CEL environment: %v", err))
}
}

// NewProgram creates a new CEL program from the given expression.
func NewProgram(expr string) (prog cel.Program, err error) {
ast, issues := env.Compile(expr)
if issues != nil && issues.Err() != nil {
err := issues.Err()
Expand Down Expand Up @@ -57,12 +63,9 @@ func EvaluateProgram(prog cel.Program, modelName, backend string, inputTokens, o
celOutputTokensKey: outputTokens,
celTotalTokensKey: totalTokens,
})
if err != nil {
if err != nil || out == nil {
return 0, fmt.Errorf("failed to evaluate CEL expression: %w", err)
}
if out == nil {
return 0, fmt.Errorf("CEL expression result is nil")
}

switch out.Type() {
case cel.IntType:
Expand Down
37 changes: 37 additions & 0 deletions internal/llmcostcel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
)

func TestNewProgram(t *testing.T) {
t.Run("invalid", func(t *testing.T) {
_, err := NewProgram("1 +")
require.Error(t, err)
})
t.Run("int", func(t *testing.T) {
_, err := NewProgram("1 + 1")
require.NoError(t, err)
Expand All @@ -28,6 +32,39 @@ func TestNewProgram(t *testing.T) {
require.Equal(t, uint64(3), v)
})

t.Run("uint", func(t *testing.T) {
_, err := NewProgram("uint(1)-uint(1200)")
require.ErrorContains(t, err, "failed to evaluate CEL expression: failed to evaluate CEL expression: unsigned integer overflow")
})

t.Run("ensure concurrency safety", func(t *testing.T) {
// Ensure that the program can be evaluated concurrently.
var wg sync.WaitGroup
wg.Add(100)
for i := 0; i < 100; i++ {
go func() {
defer wg.Done()
_, err := NewProgram("model == 'cool_model' ? input_tokens * output_tokens : total_tokens")
require.NoError(t, err)
}()
}
wg.Wait()
})
}

func TestEvaluateProgram(t *testing.T) {
t.Run("signed integer negative", func(t *testing.T) {
prog, err := NewProgram("int(input_tokens) - int(output_tokens)")
require.NoError(t, err)
_, err = EvaluateProgram(prog, "cool_model", "cool_backend", 100, 2000, 3)
require.ErrorContains(t, err, "CEL expression result is negative (-1900)")
})
t.Run("unsigned integer overflow", func(t *testing.T) {
prog, err := NewProgram("input_tokens - output_tokens")
require.NoError(t, err)
_, err = EvaluateProgram(prog, "cool_model", "cool_backend", 100, 2000, 3)
require.ErrorContains(t, err, "failed to evaluate CEL expression: unsigned integer overflow")
})
t.Run("ensure concurrency safety", func(t *testing.T) {
prog, err := NewProgram("model == 'cool_model' ? input_tokens * output_tokens : total_tokens")
require.NoError(t, err)
Expand Down
Loading

0 comments on commit 8776563

Please sign in to comment.