diff --git a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java index c94c05ffafd..7f088509c04 100644 --- a/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java +++ b/netty/src/main/java/io/grpc/netty/AbstractNettyHandler.java @@ -16,10 +16,12 @@ package io.grpc.netty; +import static com.google.common.base.Preconditions.checkNotNull; import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Ticker; import io.grpc.ChannelLogger; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; @@ -44,6 +46,7 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { private boolean autoTuneFlowControlOn; private ChannelHandlerContext ctx; private boolean initialWindowSent = false; + private final Ticker ticker; private static final long BDP_MEASUREMENT_PING = 1234; @@ -54,7 +57,8 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { Http2Settings initialSettings, ChannelLogger negotiationLogger, boolean autoFlowControl, - PingLimiter pingLimiter) { + PingLimiter pingLimiter, + Ticker ticker) { super(channelUnused, decoder, encoder, initialSettings, negotiationLogger); // During a graceful shutdown, wait until all streams are closed. @@ -62,12 +66,13 @@ abstract class AbstractNettyHandler extends GrpcHttp2ConnectionHandler { // Extract the connection window from the settings if it was set. this.initialConnectionWindow = initialSettings.initialWindowSize() == null ? -1 : - initialSettings.initialWindowSize(); + initialSettings.initialWindowSize(); this.autoTuneFlowControlOn = autoFlowControl; if (pingLimiter == null) { pingLimiter = new AllowPingLimiter(); } this.flowControlPing = new FlowControlPinger(pingLimiter); + this.ticker = checkNotNull(ticker, "ticker"); } @Override @@ -131,14 +136,17 @@ void setAutoTuneFlowControl(boolean isOn) { final class FlowControlPinger { private static final int MAX_WINDOW_SIZE = 8 * 1024 * 1024; + public static final int MAX_BACKOFF = 10; private final PingLimiter pingLimiter; private int pingCount; private int pingReturn; private boolean pinging; private int dataSizeSincePing; - private float lastBandwidth; // bytes per second + private long lastBandwidth; // bytes per nanosecond private long lastPingTime; + private int lastTargetWindow; + private int pingFrequencyMultiplier; public FlowControlPinger(PingLimiter pingLimiter) { Preconditions.checkNotNull(pingLimiter, "pingLimiter"); @@ -157,10 +165,24 @@ public void onDataRead(int dataLength, int paddingLength) { if (!autoTuneFlowControlOn) { return; } - if (!isPinging() && pingLimiter.isPingAllowed()) { + + // Note that we are double counting around the ping initiation as the current data will be + // added at the end of this method, so will be available in the next check. This at worst + // causes us to send a ping one data packet earlier, but makes startup faster if there are + // small packets before big ones. + int dataForCheck = getDataSincePing() + dataLength + paddingLength; + // Need to double the data here to account for targetWindow being set to twice the data below + if (!isPinging() && pingLimiter.isPingAllowed() + && dataForCheck * 2 >= lastTargetWindow * pingFrequencyMultiplier) { setPinging(true); sendPing(ctx()); } + + if (lastTargetWindow == 0) { + lastTargetWindow = + decoder().flowController().initialWindowSize(connection().connectionStream()); + } + incrementDataSincePing(dataLength + paddingLength); } @@ -169,25 +191,32 @@ public void updateWindow() throws Http2Exception { return; } pingReturn++; - long elapsedTime = (System.nanoTime() - lastPingTime); + setPinging(false); + + long elapsedTime = (ticker.read() - lastPingTime); if (elapsedTime == 0) { elapsedTime = 1; } + long bandwidth = (getDataSincePing() * TimeUnit.SECONDS.toNanos(1)) / elapsedTime; - Http2LocalFlowController fc = decoder().flowController(); // Calculate new window size by doubling the observed BDP, but cap at max window int targetWindow = Math.min(getDataSincePing() * 2, MAX_WINDOW_SIZE); - setPinging(false); + Http2LocalFlowController fc = decoder().flowController(); int currentWindow = fc.initialWindowSize(connection().connectionStream()); - if (targetWindow > currentWindow && bandwidth > lastBandwidth) { - lastBandwidth = bandwidth; - int increase = targetWindow - currentWindow; - fc.incrementWindowSize(connection().connectionStream(), increase); - fc.initialWindowSize(targetWindow); - Http2Settings settings = new Http2Settings(); - settings.initialWindowSize(targetWindow); - frameWriter().writeSettings(ctx(), settings, ctx().newPromise()); + if (bandwidth <= lastBandwidth || targetWindow <= currentWindow) { + pingFrequencyMultiplier = Math.min(pingFrequencyMultiplier + 1, MAX_BACKOFF); + return; } + + pingFrequencyMultiplier = 0; // react quickly when size is changing + lastBandwidth = bandwidth; + lastTargetWindow = targetWindow; + int increase = targetWindow - currentWindow; + fc.incrementWindowSize(connection().connectionStream(), increase); + fc.initialWindowSize(targetWindow); + Http2Settings settings = new Http2Settings(); + settings.initialWindowSize(targetWindow); + frameWriter().writeSettings(ctx(), settings, ctx().newPromise()); } private boolean isPinging() { @@ -200,7 +229,7 @@ private void setPinging(boolean pingOut) { private void sendPing(ChannelHandlerContext ctx) { setDataSizeSincePing(0); - lastPingTime = System.nanoTime(); + lastPingTime = ticker.read(); encoder().writePing(ctx, false, BDP_MEASUREMENT_PING, ctx.newPromise()); pingCount++; } @@ -229,10 +258,12 @@ private void setDataSizeSincePing(int dataSize) { dataSizeSincePing = dataSize; } + // Only used in testing @VisibleForTesting void setDataSizeAndSincePing(int dataSize) { setDataSizeSincePing(dataSize); - lastPingTime = System.nanoTime() - TimeUnit.SECONDS.toNanos(1); + pingFrequencyMultiplier = 1; + lastPingTime = ticker.read() ; } } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 0b87f8ea119..da7fe84d9cb 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -23,6 +23,7 @@ import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ticker; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.InlineMe; @@ -738,7 +739,7 @@ public void run() { maxMessageSize, maxHeaderListSize, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls, options.getAuthority(), options.getUserAgent(), tooManyPingsRunnable, transportTracerFactory.create(), options.getEagAttributes(), - localSocketPicker, channelLogger, useGetForSafeMethods); + localSocketPicker, channelLogger, useGetForSafeMethods, Ticker.systemTicker()); return transport; } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 2fe4d65bfa1..55337935e3b 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -24,6 +24,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import com.google.common.base.Ticker; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.InternalChannelz; @@ -143,7 +144,8 @@ static NettyClientHandler newHandler( TransportTracer transportTracer, Attributes eagAttributes, String authority, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger, + Ticker ticker) { Preconditions.checkArgument(maxHeaderListSize > 0, "maxHeaderListSize must be positive"); Http2HeadersDecoder headersDecoder = new GrpcHttp2ClientHeadersDecoder(maxHeaderListSize); Http2FrameReader frameReader = new DefaultHttp2FrameReader(headersDecoder); @@ -169,7 +171,8 @@ static NettyClientHandler newHandler( transportTracer, eagAttributes, authority, - negotiationLogger); + negotiationLogger, + ticker); } @VisibleForTesting @@ -187,7 +190,8 @@ static NettyClientHandler newHandler( TransportTracer transportTracer, Attributes eagAttributes, String authority, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger, + Ticker ticker) { Preconditions.checkNotNull(connection, "connection"); Preconditions.checkNotNull(frameReader, "frameReader"); Preconditions.checkNotNull(lifecycleManager, "lifecycleManager"); @@ -237,7 +241,8 @@ static NettyClientHandler newHandler( eagAttributes, authority, autoFlowControl, - pingCounter); + pingCounter, + ticker); } private NettyClientHandler( @@ -253,9 +258,10 @@ private NettyClientHandler( Attributes eagAttributes, String authority, boolean autoFlowControl, - PingLimiter pingLimiter) { + PingLimiter pingLimiter, + Ticker ticker) { super(/* channelUnused= */ null, decoder, encoder, settings, - negotiationLogger, autoFlowControl, pingLimiter); + negotiationLogger, autoFlowControl, pingLimiter, ticker); this.lifecycleManager = lifecycleManager; this.keepAliveManager = keepAliveManager; this.stopwatchFactory = stopwatchFactory; diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index dbfa8cf7cab..689dd847d5e 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -23,6 +23,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.base.Ticker; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; @@ -102,6 +103,7 @@ class NettyClientTransport implements ConnectionClientTransport { private final LocalSocketPicker localSocketPicker; private final ChannelLogger channelLogger; private final boolean useGetForSafeMethods; + private final Ticker ticker; NettyClientTransport( SocketAddress address, ChannelFactory channelFactory, @@ -112,7 +114,8 @@ class NettyClientTransport implements ConnectionClientTransport { boolean keepAliveWithoutCalls, String authority, @Nullable String userAgent, Runnable tooManyPingsRunnable, TransportTracer transportTracer, Attributes eagAttributes, LocalSocketPicker localSocketPicker, ChannelLogger channelLogger, - boolean useGetForSafeMethods) { + boolean useGetForSafeMethods, Ticker ticker) { + this.negotiator = Preconditions.checkNotNull(negotiator, "negotiator"); this.negotiationScheme = this.negotiator.scheme(); this.remoteAddress = Preconditions.checkNotNull(address, "address"); @@ -137,6 +140,7 @@ class NettyClientTransport implements ConnectionClientTransport { this.logId = InternalLogId.allocate(getClass(), remoteAddress.toString()); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.useGetForSafeMethods = useGetForSafeMethods; + this.ticker = Preconditions.checkNotNull(ticker, "ticker"); } @Override @@ -225,7 +229,8 @@ public Runnable start(Listener transportListener) { transportTracer, eagAttributes, authorityString, - channelLogger); + channelLogger, + ticker); ChannelHandler negotiationHandler = negotiator.newHandler(handler); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 62dd50ce65e..6382471f46a 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -34,6 +34,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.google.common.base.Ticker; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; @@ -190,7 +191,8 @@ static NettyServerHandler newHandler( maxConnectionAgeGraceInNanos, permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, - eagAttributes); + eagAttributes, + Ticker.systemTicker()); } static NettyServerHandler newHandler( @@ -212,7 +214,8 @@ static NettyServerHandler newHandler( long maxConnectionAgeGraceInNanos, boolean permitKeepAliveWithoutCalls, long permitKeepAliveTimeInNanos, - Attributes eagAttributes) { + Attributes eagAttributes, + Ticker ticker) { Preconditions.checkArgument(maxStreams > 0, "maxStreams must be positive: %s", maxStreams); Preconditions.checkArgument(flowControlWindow > 0, "flowControlWindow must be positive: %s", flowControlWindow); @@ -245,6 +248,10 @@ static NettyServerHandler newHandler( settings.maxConcurrentStreams(maxStreams); settings.maxHeaderListSize(maxHeaderListSize); + if (ticker == null) { + ticker = Ticker.systemTicker(); + } + return new NettyServerHandler( channelUnused, connection, @@ -258,7 +265,7 @@ static NettyServerHandler newHandler( maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, - eagAttributes); + eagAttributes, ticker); } private NettyServerHandler( @@ -278,9 +285,10 @@ private NettyServerHandler( long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, - Attributes eagAttributes) { + Attributes eagAttributes, + Ticker ticker) { super(channelUnused, decoder, encoder, settings, new ServerChannelLogger(), - autoFlowControl, null); + autoFlowControl, null, ticker); final MaxConnectionIdleManager maxConnectionIdleManager; if (maxConnectionIdleInNanos == MAX_CONNECTION_IDLE_NANOS_DISABLED) { @@ -325,7 +333,6 @@ public void onStreamClosed(Http2Stream stream) { this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); this.transportTracer = checkNotNull(transportTracer, "transportTracer"); - // Set the frame listener on the decoder. decoder().frameListener(new FrameListener()); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index d47942858a3..5ec82446cd6 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -54,9 +54,11 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.Attributes; +import io.grpc.CallOptions; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusException; +import io.grpc.internal.AbstractStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.ClientTransport; @@ -68,6 +70,7 @@ import io.grpc.internal.StreamListener; import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ClientHeadersDecoder; +import io.grpc.testing.TestMethodDescriptors; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; @@ -118,7 +121,7 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase setKeepaliveManagerFor = ImmutableList.of("cancelShouldSucceed", @@ -136,12 +139,31 @@ public class NettyClientHandlerTest extends NettyHandlerTestBase streamListenerMessageQueue = new LinkedList<>(); + private NettyClientStream stream; @Override protected void manualSetUp() throws Exception { setUp(); } + @Override + protected AbstractStream stream() throws Exception { + if (stream == null) { + stream = new NettyClientStream(streamTransportState, + TestMethodDescriptors.voidMethod(), + new Metadata(), + channel(), + AsciiString.of("localhost"), + AsciiString.of("http"), + AsciiString.of("agent"), + StatsTraceContext.NOOP, + transportTracer, + CallOptions.DEFAULT, + false); + } + return stream; + } + /** * Set up for test. */ @@ -201,7 +223,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio // Create a new stream with id 3. ChannelFuture createFuture = enqueue( newCreateStreamCommand(grpcHeaders, streamTransportState)); - assertEquals(3, streamTransportState.id()); + assertEquals(STREAM_ID, streamTransportState.id()); // Cancel the stream. cancelStream(Status.CANCELLED); @@ -212,7 +234,7 @@ public void cancelBufferedStreamShouldChangeClientStreamStatus() throws Exceptio @Test public void createStreamShouldSucceed() throws Exception { createStream(); - verifyWrite().writeHeaders(eq(ctx()), eq(3), eq(grpcHeaders), eq(0), + verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(grpcHeaders), eq(0), eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class)); } @@ -221,7 +243,7 @@ public void cancelShouldSucceed() throws Exception { createStream(); cancelStream(Status.CANCELLED); - verifyWrite().writeRstStream(eq(ctx()), eq(3), eq(Http2Error.CANCEL.code()), + verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verify(mockKeepAliveManager, times(1)).onTransportIdle(); // onStreamClosed @@ -233,7 +255,7 @@ public void cancelDeadlineExceededShouldSucceed() throws Exception { createStream(); cancelStream(Status.DEADLINE_EXCEEDED); - verifyWrite().writeRstStream(eq(ctx()), eq(3), eq(Http2Error.CANCEL.code()), + verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); } @@ -262,7 +284,7 @@ public void cancelTwiceShouldSucceed() throws Exception { cancelStream(Status.CANCELLED); - verifyWrite().writeRstStream(any(ChannelHandlerContext.class), eq(3), + verifyWrite().writeRstStream(any(ChannelHandlerContext.class), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); ChannelFuture future = cancelStream(Status.CANCELLED); @@ -275,7 +297,7 @@ public void cancelTwiceDifferentReasons() throws Exception { cancelStream(Status.DEADLINE_EXCEEDED); - verifyWrite().writeRstStream(eq(ctx()), eq(3), eq(Http2Error.CANCEL.code()), + verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.CANCEL.code()), any(ChannelPromise.class)); ChannelFuture future = cancelStream(Status.CANCELLED); @@ -291,7 +313,7 @@ public void sendFrameShouldSucceed() throws Exception { = enqueue(new SendGrpcFrameCommand(streamTransportState, content(), true)); assertTrue(future.isSuccess()); - verifyWrite().writeData(eq(ctx()), eq(3), eq(content()), eq(0), eq(true), + verifyWrite().writeData(eq(ctx()), eq(STREAM_ID), eq(content()), eq(0), eq(true), any(ChannelPromise.class)); verify(mockKeepAliveManager, times(1)).onTransportActive(); // onStreamActive verifyNoMoreInteractions(mockKeepAliveManager); @@ -313,7 +335,7 @@ public void inboundShouldForwardToStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC) .set(as("magic"), as("value")); - ByteBuf headersFrame = headersFrame(3, headers); + ByteBuf headersFrame = headersFrame(STREAM_ID, headers); channelRead(headersFrame); ArgumentCaptor captor = ArgumentCaptor.forClass(Metadata.class); verify(streamListener).headersRead(captor.capture()); @@ -323,7 +345,7 @@ public void inboundShouldForwardToStream() throws Exception { streamTransportState.requestMessagesFromDeframerForTesting(1); // Create a data frame and then trigger the handler to read it. - ByteBuf frame = grpcDataFrame(3, false, contentAsArray()); + ByteBuf frame = grpcDataFrame(STREAM_ID, false, contentAsArray()); channelRead(frame); InputStream message = streamListenerMessageQueue.poll(); assertArrayEquals(ByteBufUtil.getBytes(content()), ByteStreams.toByteArray(message)); @@ -580,7 +602,7 @@ public void close() throws SecurityException { public void cancelStreamShouldCreateAndThenFailBufferedStream() throws Exception { receiveMaxConcurrentStreams(0); enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); - assertEquals(3, streamTransportState.id()); + assertEquals(STREAM_ID, streamTransportState.id()); cancelStream(Status.CANCELLED); verify(streamListener).closed(eq(Status.CANCELLED), same(PROCESSED), any(Metadata.class)); } @@ -627,7 +649,7 @@ public void connectionWindowShouldBeOverridden() throws Exception { public void createIncrementsIdsForActualAndBufferdStreams() throws Exception { receiveMaxConcurrentStreams(2); enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); - assertEquals(3, streamTransportState.id()); + assertEquals(STREAM_ID, streamTransportState.id()); streamTransportState = new TransportStateImpl( handler(), @@ -766,7 +788,7 @@ public void oustandingUserPingShouldNotInteractWithDataPing() throws Exception { ArgumentCaptor captor = ArgumentCaptor.forClass(long.class); verifyWrite().writePing(eq(ctx()), eq(false), captor.capture(), any(ChannelPromise.class)); long payload = captor.getValue(); - channelRead(grpcDataFrame(3, false, contentAsArray())); + channelRead(grpcDataFrame(STREAM_ID, false, contentAsArray())); long pingData = handler().flowControlPing().payload(); channelRead(pingFrame(true, pingData)); @@ -789,18 +811,18 @@ public void bdpPingAvoidsTooManyPingsOnSpecialServers() throws Exception { Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - channelRead(headersFrame(3, headers)); - channelRead(dataFrame(3, false, content())); + channelRead(headersFrame(STREAM_ID, headers)); + channelRead(dataFrame(STREAM_ID, false, content())); verifyWrite().writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + channelRead(dataFrame(STREAM_ID, false, content())); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); + channelRead(dataFrame(STREAM_ID, false, content())); // No ping was sent - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); } @Test @@ -820,26 +842,26 @@ public void bdpPingAllowedAfterSendingData() throws Exception { Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - channelRead(headersFrame(3, headers)); - channelRead(dataFrame(3, false, content())); + channelRead(headersFrame(STREAM_ID, headers)); + channelRead(dataFrame(STREAM_ID, false, content())); verifyWrite().writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + channelRead(dataFrame(STREAM_ID, false, content())); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(pingFrame(true, 1234)); - channelRead(dataFrame(3, false, content())); + channelRead(dataFrame(STREAM_ID, false, content())); // No ping was sent - verifyWrite(times(2)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); channelRead(windowUpdate(0, 2024)); - channelRead(windowUpdate(3, 2024)); + channelRead(windowUpdate(STREAM_ID, 2024)); assertTrue(future.isDone()); assertTrue(future.isSuccess()); // But now one is sent - channelRead(dataFrame(3, false, content())); - verifyWrite(times(3)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); + channelRead(dataFrame(STREAM_ID, false, content())); + verifyWrite(times(1)).writePing(eq(ctx()), eq(false), eq(1234L), any(ChannelPromise.class)); } @Override @@ -869,7 +891,7 @@ protected void makeStream() throws Exception { // both client- and server-side. Http2Headers headers = new DefaultHttp2Headers().status(STATUS_OK) .set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC); - ByteBuf headersFrame = headersFrame(3, headers); + ByteBuf headersFrame = headersFrame(STREAM_ID, headers); channelRead(headersFrame); } @@ -928,7 +950,8 @@ public Stopwatch get() { transportTracer, Attributes.EMPTY, "someauthority", - null); + null, + fakeClock().getTicker()); } @Override diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 018ca9b6594..5f47c7b14c5 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -36,6 +36,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import com.google.common.base.Ticker; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; @@ -196,7 +197,7 @@ public void setSoLingerChannelOption() throws IOException { newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority, null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, - new SocketPicker(), new FakeChannelLogger(), false); + new SocketPicker(), new FakeChannelLogger(), false, Ticker.systemTicker()); transports.add(transport); callMeMaybe(transport.start(clientTransportListener)); @@ -448,7 +449,7 @@ public void failingToConstructChannelShouldFailGracefully() throws Exception { newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority, null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker(), - new FakeChannelLogger(), false); + new FakeChannelLogger(), false, Ticker.systemTicker()); transports.add(transport); // Should not throw @@ -763,7 +764,8 @@ private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int max negotiator, false, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize, keepAliveTimeNano, keepAliveTimeoutNano, false, authority, userAgent, tooManyPingsRunnable, - new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false); + new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false, + Ticker.systemTicker()); transports.add(transport); return transport; } diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index b59c20f8d72..fbab1ca5fae 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -30,6 +30,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.grpc.InternalChannelz.TransportStats; +import io.grpc.internal.AbstractStream; import io.grpc.internal.FakeClock; import io.grpc.internal.MessageFramer; import io.grpc.internal.StatsTraceContext; @@ -64,8 +65,10 @@ import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; import java.io.ByteArrayInputStream; +import java.nio.ByteBuffer; import java.util.concurrent.Delayed; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -80,6 +83,7 @@ @RunWith(JUnit4.class) public abstract class NettyHandlerTestBase { + protected static final int STREAM_ID = 3; private ByteBuf content; private EmbeddedChannel channel; @@ -328,6 +332,8 @@ protected final Http2Connection connection() { return handler().connection(); } + protected abstract AbstractStream stream() throws Exception; + @CanIgnoreReturnValue protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) { ChannelFuture future = writeQueue.enqueue(command, true); @@ -415,18 +421,15 @@ public void windowUpdateMatchesTarget() throws Exception { AbstractNettyHandler handler = (AbstractNettyHandler) handler(); handler.setAutoTuneFlowControl(true); - ByteBuf data = ctx().alloc().buffer(1024); - while (data.isWritable()) { - data.writeLong(1111); - } - int length = data.readableBytes(); - ByteBuf frame = dataFrame(3, false, data.copy()); + byte[] data = initXkbBuffer(1); + int wireSize = data.length + 5; // 5 is the size of the header + ByteBuf frame = grpcDataFrame(3, false, data); channelRead(frame); - int accumulator = length; + int accumulator = wireSize; // 40 is arbitrary, any number large enough to trigger a window update would work for (int i = 0; i < 40; i++) { - channelRead(dataFrame(3, false, data.copy())); - accumulator += length; + channelRead(grpcDataFrame(3, false, data)); + accumulator += wireSize; } long pingData = handler.flowControlPing().payload(); channelRead(pingFrame(true, pingData)); @@ -444,8 +447,10 @@ public void windowShouldNotExceedMaxWindowSize() throws Exception { Http2Stream connectionStream = connection().connectionStream(); Http2LocalFlowController localFlowController = connection().local().flowController(); int maxWindow = handler.flowControlPing().maxWindow(); + fakeClock.forwardTime(10, TimeUnit.SECONDS); handler.flowControlPing().setDataSizeAndSincePing(maxWindow); + fakeClock.forwardTime(1, TimeUnit.SECONDS); long payload = handler.flowControlPing().payload(); channelRead(pingFrame(true, payload)); @@ -501,4 +506,124 @@ public void transportTracer_windowUpdate_local() throws Exception { assertEquals(flowControlWindow + 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE, connection().local().flowController().windowSize(connection().connectionStream())); } + + private AbstractNettyHandler setupPingTest() throws Exception { + this.flowControlWindow = 1024 * 64; + manualSetUp(); + makeStream(); + + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + return handler; + } + + @Test + public void bdpPingLimitOutstanding() throws Exception { + AbstractNettyHandler handler = setupPingTest(); + long pingData = handler.flowControlPing().payload(); + + byte[] data1KbBuf = initXkbBuffer(1); + byte[] data40KbBuf = initXkbBuffer(40); + + readXCopies(1, data1KbBuf); // should initiate a ping + + readXCopies(1, data40KbBuf); // no ping, already active + fakeClock().forwardTime(20, TimeUnit.MILLISECONDS); + readPingAck(pingData); + assertEquals(1, handler.flowControlPing().getPingCount()); + assertEquals(1, handler.flowControlPing().getPingReturn()); + + readXCopies(4, data40KbBuf); // initiate ping + assertEquals(2, handler.flowControlPing().getPingCount()); + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + readPingAck(pingData); + + readXCopies(1, data1KbBuf); // ping again since had 160K data since last ping started + assertEquals(3, handler.flowControlPing().getPingCount()); + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + readPingAck(pingData); + + fakeClock.forwardTime(1, TimeUnit.MILLISECONDS); + readXCopies(1, data1KbBuf); // no ping, too little data + assertEquals(3, handler.flowControlPing().getPingCount()); + } + + @Test + public void testPingBackoff() throws Exception { + AbstractNettyHandler handler = setupPingTest(); + long pingData = handler.flowControlPing().payload(); + byte[] data40KbBuf = initXkbBuffer(40); + + handler.flowControlPing().setDataSizeAndSincePing(200000); + + for (int i = 0; i <= 10; i++) { + int beforeCount = handler.flowControlPing().getPingCount(); + // should resize on 0 + readXCopies(6, data40KbBuf); // initiate ping on i= {0, 1, 3, 6, 10} + int afterCount = handler.flowControlPing().getPingCount(); + fakeClock().forwardNanos(200); + if (afterCount > beforeCount) { + readPingAck(pingData); // should increase backoff multiplier + } + } + assertEquals(6, handler.flowControlPing().getPingCount()); + } + + @Test + public void bdpPingWindowResizing() throws Exception { + this.flowControlWindow = 1024 * 8; + manualSetUp(); + makeStream(); + + AbstractNettyHandler handler = (AbstractNettyHandler) handler(); + handler.setAutoTuneFlowControl(true); + Http2LocalFlowController localFlowController = connection().local().flowController(); + long pingData = handler.flowControlPing().payload(); + int initialWindowSize = localFlowController.initialWindowSize(); + byte[] data1Kb = initXkbBuffer(1); + byte[] data10Kb = initXkbBuffer(10); + + readXCopies(1, data1Kb); // initiate ping + fakeClock().forwardNanos(2); + readPingAck(pingData); // should not resize window because of small target window + assertEquals(initialWindowSize, localFlowController.initialWindowSize()); + + readXCopies(2, data10Kb); // initiate ping on first + fakeClock().forwardNanos(200); + readPingAck(pingData); // should resize window + int windowSizeA = localFlowController.initialWindowSize(); + Assert.assertNotEquals(initialWindowSize, windowSizeA); + + readXCopies(3, data10Kb); // initiate ping w/ first 10K packet + fakeClock().forwardNanos(5000); + readPingAck(pingData); // should not resize window as bandwidth didn't increase + Assert.assertEquals(windowSizeA, localFlowController.initialWindowSize()); + + readXCopies(6, data10Kb); // initiate ping with fist packet + fakeClock().forwardNanos(100); + readPingAck(pingData); // should resize window + int windowSizeB = localFlowController.initialWindowSize(); + Assert.assertNotEquals(windowSizeA, windowSizeB); + } + + private void readPingAck(long pingData) throws Exception { + channelRead(pingFrame(true, pingData)); + } + + private void readXCopies(int copies, byte[] data) throws Exception { + for (int i = 0; i < copies; i++) { + channelRead(grpcDataFrame(STREAM_ID, false, data)); // buffer it + stream().request(1); // consume it + } + } + + private byte[] initXkbBuffer(int multiple) { + ByteBuffer data = ByteBuffer.allocate(1024 * multiple); + + for (int i = 0; i < multiple * 1024 / 4; i++) { + data.putInt(4 * i, 1111); + } + return data.array(); + } + } diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 72c267a4825..926ce8261a4 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -60,6 +60,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.StreamTracer; +import io.grpc.internal.AbstractStream; import io.grpc.internal.GrpcUtil; import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; @@ -112,8 +113,6 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase