Skip to content

Commit

Permalink
keep MsQuicConnection alive when streams are pending (#52800)
Browse files Browse the repository at this point in the history
* keep MsQuicConnection alive when streams are pending

* remove extra file

* fix gchandle

* feedback from review

* feedback from review

* feedback from review
  • Loading branch information
wfurt authored Jun 10, 2021
1 parent 74374a0 commit 68484ea
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ internal static class MsQuicStatusCodes
internal static uint InternalError => OperatingSystem.IsWindows() ? Windows.InternalError : Posix.InternalError;
internal static uint InvalidState => OperatingSystem.IsWindows() ? Windows.InvalidState : Posix.InvalidState;
internal static uint HandshakeFailure => OperatingSystem.IsWindows() ? Windows.HandshakeFailure : Posix.HandshakeFailure;
internal static uint UserCanceled => OperatingSystem.IsWindows() ? Windows.UserCanceled : Posix.UserCanceled;

// TODO return better error messages here.
public static string GetError(uint status) => OperatingSystem.IsWindows() ? Windows.GetError(status) : Posix.GetError(status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ internal sealed class MsQuicConnection : QuicConnectionProvider
private readonly SafeMsQuicConfigurationHandle? _configuration;

private readonly State _state = new State();
private GCHandle _stateHandle;
private bool _disposed;
private int _disposed;

private IPEndPoint? _localEndPoint;
private readonly EndPoint _remoteEndPoint;
Expand All @@ -43,6 +42,7 @@ internal sealed class MsQuicConnection : QuicConnectionProvider
internal sealed class State
{
public SafeMsQuicConnectionHandle Handle = null!; // set inside of MsQuicConnection ctor.
public GCHandle StateGCHandle;

// These exists to prevent GC of the MsQuicConnection in the middle of an async op (Connect or Shutdown).
public MsQuicConnection? Connection;
Expand All @@ -59,6 +59,8 @@ internal sealed class State

public bool Connected;
public long AbortErrorCode = -1;
public int StreamCount;
private bool _closing;

// Queue for accepted streams.
// Backlog limit is managed by MsQuic so it can be unbounded here.
Expand All @@ -67,30 +69,83 @@ internal sealed class State
SingleReader = true,
SingleWriter = true,
});

public void RemoveStream(MsQuicStream stream)
{
bool releaseHandles;
lock (this)
{
StreamCount--;
Debug.Assert(StreamCount >= 0);
releaseHandles = _closing && StreamCount == 0;
}

if (releaseHandles)
{
Handle?.Dispose();
StateGCHandle.Free();
}
}

public bool TryQueueNewStream(SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags)
{
var stream = new MsQuicStream(this, streamHandle, flags);
if (AcceptQueue.Writer.TryWrite(stream))
{
return true;
}
else
{
stream.Dispose();
return false;
}
}

public bool TryAddStream(MsQuicStream stream)
{
lock (this)
{
if (_closing)
{
return false;
}

StreamCount++;
return true;
}
}

// This is called under lock from connection dispose
public void SetClosing()
{
lock (this)
{
_closing = true;
}
}
}

// constructor for inbound connections
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
_state.Connected = true;
_localEndPoint = localEndPoint;
_remoteEndPoint = remoteEndPoint;
_remoteCertificateRequired = false;
_isServer = true;

_stateHandle = GCHandle.Alloc(_state);

try
{
MsQuicApi.Api.SetCallbackHandlerDelegate(
_state.Handle,
s_connectionDelegate,
GCHandle.ToIntPtr(_stateHandle));
GCHandle.ToIntPtr(_state.StateGCHandle));
}
catch
{
_stateHandle.Free();
_state.StateGCHandle.Free();
throw;
}

Expand All @@ -113,7 +168,7 @@ public MsQuicConnection(QuicClientConnectionOptions options)
_remoteCertificateValidationCallback = options.ClientAuthenticationOptions.RemoteCertificateValidationCallback;
}

_stateHandle = GCHandle.Alloc(_state);
_state.StateGCHandle = GCHandle.Alloc(_state);
try
{
// this handle is ref counted by MsQuic, so safe to dispose here.
Expand All @@ -122,14 +177,14 @@ public MsQuicConnection(QuicClientConnectionOptions options)
uint status = MsQuicApi.Api.ConnectionOpenDelegate(
MsQuicApi.Api.Registration,
s_connectionDelegate,
GCHandle.ToIntPtr(_stateHandle),
GCHandle.ToIntPtr(_state.StateGCHandle),
out _state.Handle);

QuicExceptionHelpers.ThrowIfFailed(status, "Could not open the connection.");
}
catch
{
_stateHandle.Free();
_state.StateGCHandle.Free();
throw;
}

Expand Down Expand Up @@ -224,9 +279,13 @@ private static uint HandleEventShutdownComplete(State state, ref ConnectionEvent
private static uint HandleEventNewStream(State state, ref ConnectionEvent connectionEvent)
{
var streamHandle = new SafeMsQuicStreamHandle(connectionEvent.Data.PeerStreamStarted.Stream);
var stream = new MsQuicStream(state, streamHandle, connectionEvent.Data.PeerStreamStarted.Flags);
if (!state.TryQueueNewStream(streamHandle, connectionEvent.Data.PeerStreamStarted.Flags))
{
// This will call StreamCloseDelegate and free the stream.
// We will return Success to the MsQuic to prevent double free.
streamHandle.Dispose();
}

state.AcceptQueue.Writer.TryWrite(stream);
return MsQuicStatusCodes.Success;
}

Expand Down Expand Up @@ -598,17 +657,45 @@ public override void Dispose()
Dispose(false);
}

private async Task FlushAcceptQueue()
{
_state.AcceptQueue.Writer.TryComplete();
await foreach (MsQuicStream item in _state.AcceptQueue.Reader.ReadAllAsync().ConfigureAwait(false))
{
item.Dispose();
}
}

private void Dispose(bool disposing)
{
if (_disposed)
int disposed = Interlocked.Exchange(ref _disposed, 1);
if (disposed != 0)
{
return;
}

bool releaseHandles = false;
lock (_state)
{
_state.Connection = null;
if (_state.StreamCount == 0)
{
releaseHandles = true;
}
else
{
// We have pending streams so we need to defer cleanup until last one is gone.
_state.SetClosing();
}
}

FlushAcceptQueue().GetAwaiter().GetResult();
_configuration?.Dispose();
_state?.Handle?.Dispose();
if (_stateHandle.IsAllocated) _stateHandle.Free();
_disposed = true;
if (releaseHandles)
{
_state!.Handle?.Dispose();
if (_state.StateGCHandle.IsAllocated) _state.StateGCHandle.Free();
}
}

// TODO: this appears abortive and will cause prior successfully shutdown and closed streams to drop data.
Expand All @@ -622,7 +709,7 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell

private void ThrowIfDisposed()
{
if (_disposed)
if (_disposed == 1)
{
throw new ObjectDisposedException(nameof(MsQuicStream));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ internal MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHa
throw;
}

if (!connectionState.TryAddStream(this))
{
_stateHandle.Free();
throw new ObjectDisposedException(nameof(QuicConnection));
}

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(
Expand Down Expand Up @@ -133,6 +139,13 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F
throw;
}

if (!connectionState.TryAddStream(this))
{
_state.Handle?.Dispose();
_stateHandle.Free();
throw new ObjectDisposedException(nameof(QuicConnection));
}

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(
Expand Down Expand Up @@ -321,7 +334,6 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> destination, Cance
{
shouldComplete = true;
}

state.ReadState = ReadState.Aborted;
}

Expand Down Expand Up @@ -557,6 +569,8 @@ private void Dispose(bool disposing)
Marshal.FreeHGlobal(_state.SendQuicBuffers);
if (_stateHandle.IsAllocated) _stateHandle.Free();
CleanupSendState(_state);
Debug.Assert(_state.ConnectionState != null);
_state.ConnectionState?.RemoveStream(this);

if (NetEventSource.Log.IsEnabled())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@ await RunClientServer(
await (new[] { t1, t2 }).WhenAllOrAnyFailed(millisecondsTimeout: 1000000);
}

[ActiveIssue("/~https://github.com/dotnet/runtime/issues/52048")]
[Fact]
public async Task ManagedAVE_MinimalFailingTest()
{
Expand All @@ -461,6 +460,32 @@ async Task GetStreamIdWithoutStartWorks()
// TODO: stream that is opened by client but left unaccepted by server may cause AccessViolationException in its Finalizer
}

await GetStreamIdWithoutStartWorks().WaitAsync(TimeSpan.FromSeconds(15));

GC.Collect();
}

[Fact]
public async Task DisposingConnection_OK()
{
async Task GetStreamIdWithoutStartWorks()
{
using QuicListener listener = CreateQuicListener();
using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint);

ValueTask clientTask = clientConnection.ConnectAsync();
using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;

using QuicStream clientStream = clientConnection.OpenBidirectionalStream();
Assert.Equal(0, clientStream.StreamId);

// Dispose all connections before the streams;
clientConnection.Dispose();
serverConnection.Dispose();
listener.Dispose();
}

await GetStreamIdWithoutStartWorks();

GC.Collect();
Expand Down

0 comments on commit 68484ea

Please sign in to comment.