diff --git a/internal/grpctest/grpctest.go b/internal/grpctest/grpctest.go
index b92e17dc362e..b8bc385807c0 100644
--- a/internal/grpctest/grpctest.go
+++ b/internal/grpctest/grpctest.go
@@ -20,6 +20,7 @@ package grpctest
 
 import (
+	"context"
 	"reflect"
 	"strings"
 	"sync/atomic"
@@ -58,15 +59,19 @@ func (Tester) Setup(t *testing.T) {
 	// completely addressed, and this can be turned back on as soon as this issue is
 	// fixed.
 	leakcheck.SetTrackingBufferPool(logger{t: t})
+	leakcheck.TrackTimers()
 }
 
 // Teardown performs a leak check.
 func (Tester) Teardown(t *testing.T) {
 	leakcheck.CheckTrackingBufferPool()
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	leakcheck.CheckTimers(ctx, logger{t: t})
 	if atomic.LoadUint32(&lcFailed) == 1 {
 		return
 	}
-	leakcheck.CheckGoroutines(logger{t: t}, 10*time.Second)
+	leakcheck.CheckGoroutines(ctx, logger{t: t})
 	if atomic.LoadUint32(&lcFailed) == 1 {
 		t.Log("Goroutine leak check disabled for future tests")
 	}
diff --git a/internal/internal.go b/internal/internal.go
index 13e1f386b1cb..2ce012cda135 100644
--- a/internal/internal.go
+++ b/internal/internal.go
@@ -259,6 +259,13 @@ var (
 	// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
 	// testing purposes.
 	SetBufferPoolingThresholdForTesting any // func(int)
+
+	// TimeAfterFunc is used to create timers. During tests the function is
+	// replaced to track allocated timers and fail the test if a timer isn't
+	// cancelled.
+	TimeAfterFunc = func(d time.Duration, f func()) Timer {
+		return time.AfterFunc(d, f)
+	}
 )
 
 // HealthChecker defines the signature of the client-side LB channel health
@@ -300,3 +307,9 @@ type EnforceSubConnEmbedding interface {
 type EnforceClientConnEmbedding interface {
 	enforceClientConnEmbedding()
 }
+
+// Timer is an interface to allow injecting different time.Timer implementations
+// during tests.
+type Timer interface {
+	Stop() bool
+}
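Note: internal.TimeAfterFunc is a test seam, not a new timer API. A minimal sketch of how a gRPC-internal call site would go through it (scheduleDeadline is a hypothetical helper, not part of this patch):

```go
package main

import (
	"time"

	"google.golang.org/grpc/internal"
)

// scheduleDeadline routes timer creation through internal.TimeAfterFunc
// instead of calling time.AfterFunc directly, so that leakcheck.TrackTimers
// can swap in a tracking implementation during tests.
func scheduleDeadline(d time.Duration, onExpiry func()) internal.Timer {
	return internal.TimeAfterFunc(d, onExpiry)
}

func main() {
	t := scheduleDeadline(time.Hour, func() {})
	// Timers must be stopped or allowed to fire; otherwise the tracking
	// pool added below would report them as leaked at test teardown.
	t.Stop()
}
```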
diff --git a/internal/leakcheck/leakcheck.go b/internal/leakcheck/leakcheck.go
index 90830848830b..2927fb377403 100644
--- a/internal/leakcheck/leakcheck.go
+++ b/internal/leakcheck/leakcheck.go
@@ -24,12 +24,16 @@ package leakcheck
 
 import (
+	"context"
+	"fmt"
 	"runtime"
 	"runtime/debug"
 	"slices"
+	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
 
+	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/internal/mem"
 )
@@ -53,6 +55,7 @@ func init() {
 }
 
 var globalPool swappableBufferPool
+var globalTimerTracker *timerFactory
 
 type swappableBufferPool struct {
 	atomic.Pointer[mem.BufferPool]
@@ -81,7 +84,7 @@ func SetTrackingBufferPool(logger Logger) {
 
 // CheckTrackingBufferPool undoes the effects of SetTrackingBufferPool, and fails
 // unit tests if not all buffers were returned. It is invalid to invoke this
-// method without previously having invoked SetTrackingBufferPool.
+// function without previously having invoked SetTrackingBufferPool.
 func CheckTrackingBufferPool() {
 	p := (*globalPool.Load()).(*trackingBufferPool)
 	p.lock.Lock()
@@ -148,24 +151,9 @@ type trackingBufferPool struct {
 func (p *trackingBufferPool) Get(length int) *[]byte {
 	p.lock.Lock()
 	defer p.lock.Unlock()
-	p.bufferCount++
-	buf := p.pool.Get(length)
-
-	var stackBuf [16]uintptr
-	var stack []uintptr
-	skip := 2
-	for {
-		n := runtime.Callers(skip, stackBuf[:])
-		stack = append(stack, stackBuf[:n]...)
-		if n < len(stackBuf) {
-			break
-		}
-		skip += len(stackBuf)
-	}
-	p.allocatedBuffers[buf] = stack
-
+	p.bufferCount++
+	buf := p.pool.Get(length)
+	p.allocatedBuffers[buf] = currentStack(2)
 	return buf
 }
@@ -257,12 +245,11 @@ type Logger interface {
 // CheckGoroutines looks at the currently-running goroutines and checks if there
 // are any interesting (created by gRPC) goroutines leaked. It waits up to 10
 // seconds in the error cases.
-func CheckGoroutines(logger Logger, timeout time.Duration) {
+func CheckGoroutines(ctx context.Context, logger Logger) {
 	// Loop, waiting for goroutines to shut down.
 	// Wait up to timeout, but finish as quickly as possible.
-	deadline := time.Now().Add(timeout)
 	var leaked []string
-	for time.Now().Before(deadline) {
+	for ctx.Err() == nil {
 		if leaked = interestingGoroutines(); len(leaked) == 0 {
 			return
 		}
@@ -279,13 +266,6 @@ type LeakChecker struct {
 	logger Logger
 }
 
-// Check executes the leak check tests, failing the unit test if any buffer or
-// goroutine leaks are detected.
-func (lc *LeakChecker) Check() {
-	CheckTrackingBufferPool()
-	CheckGoroutines(lc.logger, 10*time.Second)
-}
-
 // NewLeakChecker offers a convenient way to set up the leak checks for a
 // specific unit test. It can be used as follows, at the beginning of tests:
 //
@@ -298,3 +278,119 @@ func NewLeakChecker(logger Logger) *LeakChecker {
 	SetTrackingBufferPool(logger)
 	return &LeakChecker{logger: logger}
 }
+
+type timerFactory struct {
+	mu              sync.Mutex
+	allocatedTimers map[internal.Timer][]uintptr
+}
+
+func (tf *timerFactory) timeAfterFunc(d time.Duration, f func()) internal.Timer {
+	tf.mu.Lock()
+	defer tf.mu.Unlock()
+	ch := make(chan internal.Timer, 1)
+	timer := time.AfterFunc(d, func() {
+		f()
+		tf.remove(<-ch)
+	})
+	ch <- timer
+	tf.allocatedTimers[timer] = currentStack(2)
+	return &trackingTimer{
+		Timer:  timer,
+		parent: tf,
+	}
+}
+
+func (tf *timerFactory) remove(timer internal.Timer) {
+	tf.mu.Lock()
+	defer tf.mu.Unlock()
+	delete(tf.allocatedTimers, timer)
+}
+
+func (tf *timerFactory) pendingTimers() []string {
+	tf.mu.Lock()
+	defer tf.mu.Unlock()
+	leaked := []string{}
+	for _, stack := range tf.allocatedTimers {
+		leaked = append(leaked, fmt.Sprintf("Allocated timer never cancelled:\n%s", traceToString(stack)))
+	}
+	return leaked
+}
+
+type trackingTimer struct {
+	internal.Timer
+	parent *timerFactory
+}
+
+func (t *trackingTimer) Stop() bool {
+	t.parent.remove(t.Timer)
+	return t.Timer.Stop()
+}
+
+// TrackTimers replaces internal.TimeAfterFunc with one that tracks timer
+// creations, stoppages and expirations. CheckTimers should then be invoked at
+// the end of the test to validate that all timers created have either executed
+// or been cancelled.
+func TrackTimers() {
+	globalTimerTracker = &timerFactory{
+		allocatedTimers: make(map[internal.Timer][]uintptr),
+	}
+	internal.TimeAfterFunc = globalTimerTracker.timeAfterFunc
+}
+
+// CheckTimers undoes the effects of TrackTimers, and fails unit tests if not
+// all timers were cancelled or executed. It is invalid to invoke this function
+// without previously having invoked TrackTimers.
+func CheckTimers(ctx context.Context, logger Logger) {
+	tt := globalTimerTracker
+
+	// Loop, waiting for timers to be cancelled.
+	// Wait until the context expires, but finish as quickly as possible.
+	var leaked []string
+	for ctx.Err() == nil {
+		if leaked = tt.pendingTimers(); len(leaked) == 0 {
+			return
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+	for _, g := range leaked {
+		logger.Errorf("Leaked timers: %v", g)
+	}
+
+	// Reset the internal function.
+	internal.TimeAfterFunc = func(d time.Duration, f func()) internal.Timer {
+		return time.AfterFunc(d, f)
+	}
+}
+
+func currentStack(skip int) []uintptr {
+	var stackBuf [16]uintptr
+	var stack []uintptr
+	skip++
+	for {
+		n := runtime.Callers(skip, stackBuf[:])
+		stack = append(stack, stackBuf[:n]...)
+		if n < len(stackBuf) {
+			break
+		}
+		skip += len(stackBuf)
+	}
+	return stack
+}
+
+func traceToString(stack []uintptr) string {
+	frames := runtime.CallersFrames(stack)
+	var trace strings.Builder
+	for {
+		f, ok := frames.Next()
+		if !ok {
+			break
+		}
+		trace.WriteString(f.Function)
+		trace.WriteString("\n\t")
+		trace.WriteString(f.File)
+		trace.WriteString(":")
+		trace.WriteString(strconv.Itoa(f.Line))
+		trace.WriteString("\n")
+	}
+	return trace.String()
+}
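Note on timeAfterFunc above: the one-element channel is what lets a timer deregister itself. time.AfterFunc may invoke the callback before the assignment to timer completes, so the callback blocks on the channel until the handle has been published. A standalone sketch of the same handshake (illustrative only, not library code):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Buffered so the send below never blocks, even if the callback has
	// already started running.
	ch := make(chan *time.Timer, 1)
	t := time.AfterFunc(time.Nanosecond, func() {
		tm := <-ch // block until the handle is published
		fmt.Printf("timer %p can now deregister itself\n", tm)
	})
	ch <- t // publish the handle to the callback
	time.Sleep(10 * time.Millisecond)
}
```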
diff --git a/internal/leakcheck/leakcheck_test.go b/internal/leakcheck/leakcheck_test.go
index a0f67ffb1b44..26ecfc145799 100644
--- a/internal/leakcheck/leakcheck_test.go
+++ b/internal/leakcheck/leakcheck_test.go
@@ -19,10 +19,14 @@
 package leakcheck
 
 import (
+	"context"
 	"fmt"
 	"strings"
+	"sync"
 	"testing"
 	"time"
+
+	"google.golang.org/grpc/internal"
 )
 
 type testLogger struct {
@@ -47,12 +51,16 @@ func TestCheck(t *testing.T) {
 		t.Error("blah")
 	}
 	e := &testLogger{}
-	CheckGoroutines(e, time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	CheckGoroutines(ctx, e)
 	if e.errorCount != leakCount {
 		t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
 		t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
 	}
-	CheckGoroutines(t, 3*time.Second)
+	ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+	CheckGoroutines(ctx, t)
 }
 
 func ignoredTestingLeak(d time.Duration) {
@@ -70,10 +78,55 @@ func TestCheckRegisterIgnore(t *testing.T) {
 		t.Error("blah")
 	}
 	e := &testLogger{}
-	CheckGoroutines(e, time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	CheckGoroutines(ctx, e)
 	if e.errorCount != leakCount {
 		t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
 		t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
 	}
-	CheckGoroutines(t, 3*time.Second)
+	ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+	CheckGoroutines(ctx, t)
+}
+
+// TestTrackTimers verifies that only leaked timers are reported; expired and
+// stopped timers are ignored.
+func TestTrackTimers(t *testing.T) {
+	TrackTimers()
+	const leakCount = 3
+	for i := 0; i < leakCount; i++ {
+		internal.TimeAfterFunc(2*time.Second, func() {
+			t.Logf("Timer %d fired.", i)
+		})
+	}
+	wg := sync.WaitGroup{}
+	// Let a couple of timers expire.
+	for i := 0; i < 2; i++ {
+		wg.Add(1)
+		internal.TimeAfterFunc(time.Millisecond, func() {
+			wg.Done()
+		})
+	}
+	wg.Wait()
+
+	// Stop a couple of timers.
+	for i := 0; i < leakCount; i++ {
+		timer := internal.TimeAfterFunc(time.Hour, func() {
+			t.Error("Timer fired before test ended.")
+		})
+		timer.Stop()
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	e := &testLogger{}
+	CheckTimers(ctx, e)
+	if e.errorCount != leakCount {
+		t.Errorf("CheckTimers found %v leaks, want %v leaks", e.errorCount, leakCount)
+		t.Logf("leaked timers:\n%v", strings.Join(e.errors, "\n"))
+	}
+	ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+	CheckTimers(ctx, t)
 }
diff --git a/internal/transport/client_stream.go b/internal/transport/client_stream.go
index 8ed347c54195..ccc0e017e5e5 100644
--- a/internal/transport/client_stream.go
+++ b/internal/transport/client_stream.go
@@ -59,7 +59,7 @@ func (s *ClientStream) Read(n int) (mem.BufferSlice, error) {
 	return b, err
 }
 
-// Close closes the stream and popagates err to any readers.
+// Close closes the stream and propagates err to any readers.
 func (s *ClientStream) Close(err error) {
 	var (
 		rst bool
diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go
index 997b0a59b586..7e53eb1735ef 100644
--- a/internal/transport/http2_server.go
+++ b/internal/transport/http2_server.go
@@ -35,6 +35,7 @@ import (
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/hpack"
 
+	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/grpcutil"
 	"google.golang.org/grpc/internal/pretty"
@@ -598,6 +599,22 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
 	if len(t.activeStreams) == 1 {
 		t.idle = time.Time{}
 	}
+	// Start a timer to close the stream on reaching the deadline.
+	if timeoutSet {
+		// We need to wait for s.cancel to be updated before calling
+		// t.closeStream to avoid data races.
+		cancelUpdated := make(chan struct{})
+		timer := internal.TimeAfterFunc(timeout, func() {
+			<-cancelUpdated
+			t.closeStream(s, true, http2.ErrCodeCancel, false)
+		})
+		oldCancel := s.cancel
+		s.cancel = func() {
+			oldCancel()
+			timer.Stop()
+		}
+		close(cancelUpdated)
+	}
 	t.mu.Unlock()
 	if channelz.IsOn() {
 		t.channelz.SocketMetrics.StreamsStarted.Add(1)
@@ -1274,7 +1291,6 @@ func (t *http2Server) Close(err error) {
 
 // deleteStream deletes the stream s from transport's active streams.
 func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
-
 	t.mu.Lock()
 	if _, ok := t.activeStreams[s.id]; ok {
 		delete(t.activeStreams, s.id)
@@ -1324,7 +1340,10 @@ func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCo
 	// called to interrupt the potential blocking on other goroutines.
 	s.cancel()
 
-	s.swapState(streamDone)
+	oldState := s.swapState(streamDone)
+	if oldState == streamDone {
+		return
+	}
 	t.deleteStream(s, eosReceived)
 
 	t.controlBuf.put(&cleanupStream{
diff --git a/internal/transport/server_stream.go b/internal/transport/server_stream.go
index a22a90151494..cf8da0b52d0a 100644
--- a/internal/transport/server_stream.go
+++ b/internal/transport/server_stream.go
@@ -35,8 +35,10 @@ type ServerStream struct {
 	*Stream // Embed for common stream functionality.
 
 	st      internalServerTransport
-	ctxDone <-chan struct{}    // closed at the end of stream. Cache of ctx.Done() (for performance)
-	cancel  context.CancelFunc // invoked at the end of stream to cancel ctx.
+	ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
+	// cancel is invoked at the end of stream to cancel ctx. It also stops the
+	// timer for monitoring the rpc deadline if configured.
+	cancel func()
 
 	// Holds compressor names passed in grpc-accept-encoding metadata from the
 	// client.
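Note on the cancelUpdated channel in operateHeaders (http2_server.go above): the timer callback can run on another goroutine as soon as the timer is created, and it reaches s.cancel through t.closeStream, so it must not proceed until the wrapped cancel function has been published. A simplified sketch of that synchronization (names and types are illustrative, not the transport's actual ones, and time.AfterFunc stands in for internal.TimeAfterFunc):

```go
package main

import "time"

type stream struct {
	cancel func()
}

func armDeadline(s *stream, timeout time.Duration, closeStream func()) {
	cancelUpdated := make(chan struct{})
	timer := time.AfterFunc(timeout, func() {
		<-cancelUpdated // wait until s.cancel below is published
		closeStream()   // closeStream invokes s.cancel, which stops the timer
	})
	oldCancel := s.cancel // assumed non-nil, as it is for transport streams
	s.cancel = func() {
		oldCancel()
		timer.Stop()
	}
	close(cancelUpdated) // publish: the callback may now proceed
}

func main() {
	s := &stream{cancel: func() {}}
	armDeadline(s, 10*time.Millisecond, func() { s.cancel() })
	time.Sleep(50 * time.Millisecond)
}
```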
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 1220b16b909a..8b1219597912 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -834,7 +834,9 @@ func (s) TestGracefulClose(t *testing.T) {
 	server.lis.Close()
 	// Check for goroutine leaks (i.e. GracefulClose with an active stream
 	// doesn't eventually close the connection when that stream completes).
-	leakcheck.CheckGoroutines(t, 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	leakcheck.CheckGoroutines(ctx, t)
 	leakcheck.CheckTrackingBufferPool()
 	// Correctly clean up the server
 	server.stop()
@@ -895,6 +897,8 @@ func (s) TestGracefulClose(t *testing.T) {
 func (s) TestLargeMessageSuspension(t *testing.T) {
 	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
+	defer ct.Close(fmt.Errorf("closed manually by test"))
+	defer server.stop()
 	callHdr := &CallHdr{
 		Host:   "localhost",
 		Method: "foo.Large",
@@ -906,12 +910,6 @@ func (s) TestLargeMessageSuspension(t *testing.T) {
 	if err != nil {
 		t.Fatalf("failed to open stream: %v", err)
 	}
-	// Launch a goroutine similar to the stream monitoring goroutine in
-	// stream.go to keep track of context timeout and call CloseStream.
-	go func() {
-		<-ctx.Done()
-		s.Close(ContextErr(ctx.Err()))
-	}()
 	// Write should not be done successfully due to flow control.
 	msg := make([]byte, initialWindowSize*8)
 	s.Write(nil, newBufferSlice(msg), &WriteOptions{})
@@ -919,12 +917,14 @@ func (s) TestLargeMessageSuspension(t *testing.T) {
 	if err != errStreamDone {
 		t.Fatalf("Write got %v, want io.EOF", err)
 	}
-	expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
-	if _, err := s.readTo(make([]byte, 8)); err.Error() != expectedErr.Error() {
-		t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
+	// The server will send an RST stream frame on observing the deadline
+	// expiration, making the client stream fail with a DeadlineExceeded status.
+	if _, err := s.readTo(make([]byte, 8)); err != io.EOF {
+		t.Fatalf("Read got unexpected error: %v, want %v", err, io.EOF)
+	}
+	if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want {
+		t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want)
 	}
-	ct.Close(fmt.Errorf("closed manually by test"))
-	server.stop()
 }
 
 func (s) TestMaxStreams(t *testing.T) {
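The new test below exercises the server-side deadline timer by withholding the RST_STREAM frame that a well-behaved client would send when its deadline expires. For contrast, a minimal sketch of the conformant client-side cancellation at the frame level (illustrative only; a real client also serializes framer writes and reuses a single framer):

```go
package example

import (
	"net"

	"golang.org/x/net/http2"
)

// cancelStream is a hypothetical helper: on deadline expiry, a well-behaved
// client resets its own stream with RST_STREAM(CANCEL), so the server never
// needs its fallback timer.
func cancelStream(conn net.Conn, streamID uint32) error {
	framer := http2.NewFramer(conn, conn)
	return framer.WriteRSTStream(streamID, http2.ErrCodeCancel)
}
```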
+	// Tests a scenario in which the client doesn't send an RST frame even
+	// after the configured deadline is reached. The test verifies that the
+	// server sends an RST stream only after the deadline is reached.
+	go func() { // Launch a reader for this misbehaving client.
+		for {
+			frame, err := framer.ReadFrame()
+			if err != nil {
+				return
+			}
+			switch frame := frame.(type) {
+			case *http2.PingFrame:
+				// Write a ping ack back so that the server's BDP estimation works right.
+				mu.Lock()
+				framer.WritePing(true, frame.Data)
+				mu.Unlock()
+			case *http2.RSTStreamFrame:
+				if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeCancel {
+					t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeCancel", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
+				}
+				rstTimeChan <- time.Now()
+				return
+			default:
+				// Do nothing.
+			}
+		}
+	}()
+	// Create a stream.
+	var buf bytes.Buffer
+	henc := hpack.NewEncoder(&buf)
+	if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil {
+		t.Fatalf("Error while encoding header: %v", err)
+	}
+	if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil {
+		t.Fatalf("Error while encoding header: %v", err)
+	}
+	if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil {
+		t.Fatalf("Error while encoding header: %v", err)
+	}
+	if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil {
+		t.Fatalf("Error while encoding header: %v", err)
+	}
+	if err := henc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: "10m"}); err != nil {
+		t.Fatalf("Error while encoding header: %v", err)
+	}
+	mu.Lock()
+	startTime := time.Now()
+	if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
+		mu.Unlock()
+		t.Fatalf("Error while writing headers: %v", err)
+	}
+	mu.Unlock()
+
+	// Test server behavior for deadline expiration.
+	var rstTime time.Time
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Test timed out.")
+	case rstTime = <-rstTimeChan:
+	}
+
+	if got, want := rstTime.Sub(startTime), 10*time.Millisecond; got < want {
+		t.Fatalf("RST frame received earlier than expected by duration: %v", want-got)
+	}
+}