Skip to content

Commit

Permalink
transport: Send RST stream from the server when deadline expires (#8071)
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal authored Feb 28, 2025
1 parent 7505bf2 commit 0d6e39f
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 49 deletions.
7 changes: 6 additions & 1 deletion internal/grpctest/grpctest.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package grpctest

import (
"context"
"reflect"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -58,15 +59,19 @@ func (Tester) Setup(t *testing.T) {
// completely addressed, and this can be turned back on as soon as this issue is
// fixed.
leakcheck.SetTrackingBufferPool(logger{t: t})
leakcheck.TrackTimers()
}

// Teardown performs a leak check.
func (Tester) Teardown(t *testing.T) {
leakcheck.CheckTrackingBufferPool()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
leakcheck.CheckTimers(ctx, logger{t: t})
if atomic.LoadUint32(&lcFailed) == 1 {
return
}
leakcheck.CheckGoroutines(logger{t: t}, 10*time.Second)
leakcheck.CheckGoroutines(ctx, logger{t: t})
if atomic.LoadUint32(&lcFailed) == 1 {
t.Log("Goroutine leak check disabled for future tests")
}
Expand Down
13 changes: 13 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ var (
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
// testing purposes.
SetBufferPoolingThresholdForTesting any // func(int)

// TimeAfterFunc is used to create timers. During tests the function is
// replaced to track allocated timers and fail the test if a timer isn't
// cancelled.
TimeAfterFunc = func(d time.Duration, f func()) Timer {
return time.AfterFunc(d, f)
}
)

// HealthChecker defines the signature of the client-side LB channel health
Expand Down Expand Up @@ -300,3 +307,9 @@ type EnforceSubConnEmbedding interface {
type EnforceClientConnEmbedding interface {
enforceClientConnEmbedding()
}

// Timer is an interface to allow injecting different time.Timer implementations
// during tests.
type Timer interface {
Stop() bool
}
150 changes: 123 additions & 27 deletions internal/leakcheck/leakcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
package leakcheck

import (
"context"
"fmt"
"runtime"
"runtime/debug"
"slices"
Expand Down Expand Up @@ -53,6 +55,7 @@ func init() {
}

var globalPool swappableBufferPool
var globalTimerTracker *timerFactory

type swappableBufferPool struct {
atomic.Pointer[mem.BufferPool]
Expand Down Expand Up @@ -81,7 +84,7 @@ func SetTrackingBufferPool(logger Logger) {

// CheckTrackingBufferPool undoes the effects of SetTrackingBufferPool, and fails
// unit tests if not all buffers were returned. It is invalid to invoke this
// method without previously having invoked SetTrackingBufferPool.
// function without previously having invoked SetTrackingBufferPool.
func CheckTrackingBufferPool() {
p := (*globalPool.Load()).(*trackingBufferPool)
p.lock.Lock()
Expand Down Expand Up @@ -148,24 +151,9 @@ type trackingBufferPool struct {
func (p *trackingBufferPool) Get(length int) *[]byte {
p.lock.Lock()
defer p.lock.Unlock()

p.bufferCount++

buf := p.pool.Get(length)

var stackBuf [16]uintptr
var stack []uintptr
skip := 2
for {
n := runtime.Callers(skip, stackBuf[:])
stack = append(stack, stackBuf[:n]...)
if n < len(stackBuf) {
break
}
skip += len(stackBuf)
}
p.allocatedBuffers[buf] = stack

p.allocatedBuffers[buf] = currentStack(2)
return buf
}

Expand Down Expand Up @@ -257,12 +245,11 @@ type Logger interface {
// CheckGoroutines looks at the currently-running goroutines and checks if there
// are any interesting (created by gRPC) goroutines leaked. It waits up to 10
// seconds in the error cases.
func CheckGoroutines(logger Logger, timeout time.Duration) {
func CheckGoroutines(ctx context.Context, logger Logger) {
// Loop, waiting for goroutines to shut down.
// Wait up to timeout, but finish as quickly as possible.
deadline := time.Now().Add(timeout)
var leaked []string
for time.Now().Before(deadline) {
for ctx.Err() == nil {
if leaked = interestingGoroutines(); len(leaked) == 0 {
return
}
Expand All @@ -279,13 +266,6 @@ type LeakChecker struct {
logger Logger
}

// Check executes the leak check tests, failing the unit test if any buffer or
// goroutine leaks are detected.
func (lc *LeakChecker) Check() {
CheckTrackingBufferPool()
CheckGoroutines(lc.logger, 10*time.Second)
}

// NewLeakChecker offers a convenient way to set up the leak checks for a
// specific unit test. It can be used as follows, at the beginning of tests:
//
Expand All @@ -298,3 +278,119 @@ func NewLeakChecker(logger Logger) *LeakChecker {
SetTrackingBufferPool(logger)
return &LeakChecker{logger: logger}
}

type timerFactory struct {
mu sync.Mutex
allocatedTimers map[internal.Timer][]uintptr
}

func (tf *timerFactory) timeAfterFunc(d time.Duration, f func()) internal.Timer {
tf.mu.Lock()
defer tf.mu.Unlock()
ch := make(chan internal.Timer, 1)
timer := time.AfterFunc(d, func() {
f()
tf.remove(<-ch)
})
ch <- timer
tf.allocatedTimers[timer] = currentStack(2)
return &trackingTimer{
Timer: timer,
parent: tf,
}
}

func (tf *timerFactory) remove(timer internal.Timer) {
tf.mu.Lock()
defer tf.mu.Unlock()
delete(tf.allocatedTimers, timer)
}

func (tf *timerFactory) pendingTimers() []string {
tf.mu.Lock()
defer tf.mu.Unlock()
leaked := []string{}
for _, stack := range tf.allocatedTimers {
leaked = append(leaked, fmt.Sprintf("Allocated timer never cancelled:\n%s", traceToString(stack)))
}
return leaked
}

type trackingTimer struct {
internal.Timer
parent *timerFactory
}

func (t *trackingTimer) Stop() bool {
t.parent.remove(t.Timer)
return t.Timer.Stop()
}

// TrackTimers replaces internal.TimerAfterFunc with one that tracks timer
// creations, stoppages and expirations. CheckTimers should then be invoked at
// the end of the test to validate that all timers created have either executed
// or are cancelled.
func TrackTimers() {
globalTimerTracker = &timerFactory{
allocatedTimers: make(map[internal.Timer][]uintptr),
}
internal.TimeAfterFunc = globalTimerTracker.timeAfterFunc
}

// CheckTimers undoes the effects of TrackTimers, and fails unit tests if not
// all timers were cancelled or executed. It is invalid to invoke this function
// without previously having invoked TrackTimers.
func CheckTimers(ctx context.Context, logger Logger) {
tt := globalTimerTracker

// Loop, waiting for timers to be cancelled.
// Wait up to timeout, but finish as quickly as possible.
var leaked []string
for ctx.Err() == nil {
if leaked = tt.pendingTimers(); len(leaked) == 0 {
return
}
time.Sleep(50 * time.Millisecond)
}
for _, g := range leaked {
logger.Errorf("Leaked timers: %v", g)
}

// Reset the internal function.
internal.TimeAfterFunc = func(d time.Duration, f func()) internal.Timer {
return time.AfterFunc(d, f)
}
}

func currentStack(skip int) []uintptr {
var stackBuf [16]uintptr
var stack []uintptr
skip++
for {
n := runtime.Callers(skip, stackBuf[:])
stack = append(stack, stackBuf[:n]...)
if n < len(stackBuf) {
break
}
skip += len(stackBuf)
}
return stack
}

func traceToString(stack []uintptr) string {
frames := runtime.CallersFrames(stack)
var trace strings.Builder
for {
f, ok := frames.Next()
if !ok {
break
}
trace.WriteString(f.Function)
trace.WriteString("\n\t")
trace.WriteString(f.File)
trace.WriteString(":")
trace.WriteString(strconv.Itoa(f.Line))
trace.WriteString("\n")
}
return trace.String()
}
61 changes: 57 additions & 4 deletions internal/leakcheck/leakcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
package leakcheck

import (
"context"
"fmt"
"strings"
"sync"
"testing"
"time"

"google.golang.org/grpc/internal"
)

type testLogger struct {
Expand All @@ -47,12 +51,16 @@ func TestCheck(t *testing.T) {
t.Error("blah")
}
e := &testLogger{}
CheckGoroutines(e, time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
CheckGoroutines(ctx, e)
if e.errorCount != leakCount {
t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
}
CheckGoroutines(t, 3*time.Second)
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
CheckGoroutines(ctx, t)
}

func ignoredTestingLeak(d time.Duration) {
Expand All @@ -70,10 +78,55 @@ func TestCheckRegisterIgnore(t *testing.T) {
t.Error("blah")
}
e := &testLogger{}
CheckGoroutines(e, time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
CheckGoroutines(ctx, e)
if e.errorCount != leakCount {
t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
}
CheckGoroutines(t, 3*time.Second)
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
CheckGoroutines(ctx, t)
}

// TestTrackTimers verifies that only leaked timers are reported and expired,
// stopped timers are ignored.
func TestTrackTimers(t *testing.T) {
TrackTimers()
const leakCount = 3
for i := 0; i < leakCount; i++ {
internal.TimeAfterFunc(2*time.Second, func() {
t.Logf("Timer %d fired.", i)
})
}
wg := sync.WaitGroup{}
// Let a couple of timers expire.
for i := 0; i < 2; i++ {
wg.Add(1)
internal.TimeAfterFunc(time.Millisecond, func() {
wg.Done()
})
}
wg.Wait()

// Stop a couple of timers.
for i := 0; i < leakCount; i++ {
t := internal.TimeAfterFunc(time.Hour, func() {
t.Error("Timer fired before test ended.")
})
t.Stop()
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
e := &testLogger{}
CheckTimers(ctx, e)
if e.errorCount != leakCount {
t.Errorf("CheckTimers found %v leaks, want %v leaks", e.errorCount, leakCount)
t.Logf("leaked timers:\n%v", strings.Join(e.errors, "\n"))
}
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
CheckTimers(ctx, t)
}
2 changes: 1 addition & 1 deletion internal/transport/client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *ClientStream) Read(n int) (mem.BufferSlice, error) {
return b, err
}

// Close closes the stream and popagates err to any readers.
// Close closes the stream and propagates err to any readers.
func (s *ClientStream) Close(err error) {
var (
rst bool
Expand Down
Loading

0 comments on commit 0d6e39f

Please sign in to comment.