Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) #6708

Merged
merged 3 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
ID: http2.SettingMaxFrameSize,
Val: http2MaxFrameLen,
}}
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
// permitted in the HTTP2 spec.
maxStreams := config.MaxStreams
if maxStreams == 0 {
maxStreams = math.MaxUint32
} else {
if config.MaxStreams != math.MaxUint32 {
isettings = append(isettings, http2.Setting{
ID: http2.SettingMaxConcurrentStreams,
Val: maxStreams,
Val: config.MaxStreams,
})
}
dynamicWindow := true
Expand Down Expand Up @@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
maxStreams: maxStreams,
maxStreams: config.MaxStreams,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,
Expand Down
35 changes: 19 additions & 16 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
if err != nil {
return
}
if serverConfig.MaxStreams == 0 {
serverConfig.MaxStreams = math.MaxUint32
}
transport, err := NewServerTransport(conn, serverConfig)
if err != nil {
return
Expand Down Expand Up @@ -442,8 +445,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
return server
}

func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
}

func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
Expand Down Expand Up @@ -538,7 +541,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {

// Tests that when streamID > MaxStreamId, the current client transport drains.
func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
callHdr := &CallHdr{
Expand Down Expand Up @@ -583,7 +586,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
}

func (s) TestClientSendAndReceive(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -623,7 +626,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
}

func (s) TestClientErrorNotify(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
go server.stop()
// ct.reader should detect the error and activate ct.Error().
Expand Down Expand Up @@ -657,7 +660,7 @@ func performOneRPC(ct ClientTransport) {
}

func (s) TestClientMix(t *testing.T) {
s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
s, ct, cancel := setUp(t, 0, normal)
defer cancel()
time.AfterFunc(time.Second, s.stop)
go func(ct ClientTransport) {
Expand All @@ -671,7 +674,7 @@ func (s) TestClientMix(t *testing.T) {
}

func (s) TestLargeMessage(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -806,7 +809,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
// proceed until they complete naturally, while not allowing creation of new
// streams during this window.
func (s) TestGracefulClose(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
server, ct, cancel := setUp(t, 0, pingpong)
defer cancel()
defer func() {
// Stop the server's listener to make the server's goroutines terminate
Expand Down Expand Up @@ -872,7 +875,7 @@ func (s) TestGracefulClose(t *testing.T) {
}

func (s) TestLargeMessageSuspension(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
server, ct, cancel := setUp(t, 0, suspended)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -980,7 +983,7 @@ func (s) TestMaxStreams(t *testing.T) {
}

func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
server, ct, cancel := setUp(t, 0, suspended)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -1452,7 +1455,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
var encodingTestStatus = status.New(codes.Internal, "\n")

func (s) TestEncodingRequiredStatus(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -1480,7 +1483,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
}

func (s) TestInvalidHeaderField(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
server, ct, cancel := setUp(t, 0, invalidHeaderField)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand All @@ -1502,7 +1505,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
}

func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
server, ct, cancel := setUp(t, 0, invalidHeaderField)
defer cancel()
defer server.stop()
defer ct.Close(fmt.Errorf("closed manually by test"))
Expand Down Expand Up @@ -2170,7 +2173,7 @@ func (s) TestPingPong1MB(t *testing.T) {

// This is a stress-test of flow control logic.
func runPingPongTest(t *testing.T, msgSize int) {
server, client, cancel := setUp(t, 0, 0, pingpong)
server, client, cancel := setUp(t, 0, pingpong)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))
Expand Down Expand Up @@ -2252,7 +2255,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
}
}()

server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
defer ct.Close(fmt.Errorf("closed manually by test"))
defer server.stop()
Expand Down Expand Up @@ -2611,7 +2614,7 @@ func TestConnectionError_Unwrap(t *testing.T) {

func (s) TestPeerSetInServerContext(t *testing.T) {
// create client and server transports.
server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
server, client, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))
Expand Down
69 changes: 48 additions & 21 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,6 @@ type serviceInfo struct {
mdata interface{}
}

type serverWorkerData struct {
st transport.ServerTransport
wg *sync.WaitGroup
stream *transport.Stream
}

// Server is a gRPC server to serve RPC requests.
type Server struct {
opts serverOptions
Expand All @@ -145,7 +139,7 @@ type Server struct {
channelzID *channelz.Identifier
czData *channelzData

serverWorkerChannel chan *serverWorkerData
serverWorkerChannel chan func()
}

type serverOptions struct {
Expand Down Expand Up @@ -177,6 +171,7 @@ type serverOptions struct {
}

var defaultServerOptions = serverOptions{
maxConcurrentStreams: math.MaxUint32,
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
maxSendMessageSize: defaultServerMaxSendMessageSize,
connectionTimeout: 120 * time.Second,
Expand Down Expand Up @@ -387,6 +382,9 @@ func MaxSendMsgSize(m int) ServerOption {
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
if n == 0 {
n = math.MaxUint32
}
return newFuncServerOption(func(o *serverOptions) {
o.maxConcurrentStreams = n
})
Expand Down Expand Up @@ -567,24 +565,19 @@ const serverWorkerResetThreshold = 1 << 16
// [1] /~https://github.com/golang/go/issues/18138
func (s *Server) serverWorker() {
for completed := 0; completed < serverWorkerResetThreshold; completed++ {
data, ok := <-s.serverWorkerChannel
f, ok := <-s.serverWorkerChannel
if !ok {
return
}
s.handleSingleStream(data)
f()
}
go s.serverWorker()
}

func (s *Server) handleSingleStream(data *serverWorkerData) {
defer data.wg.Done()
s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
}

// initServerWorkers creates worker goroutines and a channel to process incoming
// connections to reduce the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() {
s.serverWorkerChannel = make(chan *serverWorkerData)
s.serverWorkerChannel = make(chan func())
for i := uint32(0); i < s.opts.numServerWorkers; i++ {
go s.serverWorker()
}
Expand Down Expand Up @@ -943,21 +936,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
defer st.Close(errors.New("finished serving streams for the server transport"))
var wg sync.WaitGroup

streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(func(stream *transport.Stream) {
wg.Add(1)

streamQuota.acquire()
f := func() {
defer streamQuota.release()
defer wg.Done()
s.handleStream(st, stream, s.traceInfo(st, stream))
}

if s.opts.numServerWorkers > 0 {
data := &serverWorkerData{st: st, wg: &wg, stream: stream}
select {
case s.serverWorkerChannel <- data:
case s.serverWorkerChannel <- f:
return
default:
// If all stream workers are busy, fallback to the default code path.
}
}
go func() {
defer wg.Done()
s.handleStream(st, stream, s.traceInfo(st, stream))
}()
go f()
}, func(ctx context.Context, method string) context.Context {
if !EnableTracing {
return ctx
Expand Down Expand Up @@ -2052,3 +2050,32 @@ func validateSendCompressor(name, clientCompressors string) error {
}
return fmt.Errorf("client does not support compressor %q", name)
}

// atomicSemaphore implements a blocking, counting semaphore. acquire should be
// called synchronously; release may be called asynchronously.
type atomicSemaphore struct {
n int64
wait chan struct{}
}

func (q *atomicSemaphore) acquire() {
if atomic.AddInt64(&q.n, -1) < 0 {
// We ran out of quota. Block until a release happens.
<-q.wait
}
}

func (q *atomicSemaphore) release() {
// N.B. the "<= 0" check below should allow for this to work with multiple
// concurrent calls to acquire, but also note that with synchronous calls to
// acquire, as our system does, n will never be less than -1. There are
// fairness issues (queuing) to consider if this was to be generalized.
if atomic.AddInt64(&q.n, 1) <= 0 {
// An acquire was waiting on us. Unblock it.
q.wait <- struct{}{}
}
}

func newHandlerQuota(n uint32) *atomicSemaphore {
return &atomicSemaphore{n: int64(n), wait: make(chan struct{}, 1)}
}
Loading