From 93e886043f4b12b8f17467260e99d6c96cb2df23 Mon Sep 17 00:00:00 2001 From: smdn Date: Fri, 4 Feb 2022 17:16:03 +0900 Subject: [PATCH] override Stream.ReadAsync(Memory, CancellationToken) --- .../LineOrientedStream.cs | 102 +++++++++++++----- .../LineOrientedStream.cs | 81 ++++++++++++++ 2 files changed, 157 insertions(+), 26 deletions(-) diff --git a/src/Smdn.Fundamental.Stream.LineOriented/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs b/src/Smdn.Fundamental.Stream.LineOriented/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs index cae961feb..f8ce26473 100644 --- a/src/Smdn.Fundamental.Stream.LineOriented/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs +++ b/src/Smdn.Fundamental.Stream.LineOriented/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs @@ -456,66 +456,116 @@ CancellationToken cancellationToken if (count == 0L) return Task.FromResult(0); // do nothing + return ReadAsyncCore( + destination: buffer.AsMemory(offset, count), + cancellationToken: cancellationToken +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + ).AsTask(); +#else + ); +#endif + } + +#if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE + public override ValueTask ReadAsync( + Memory buffer, + CancellationToken cancellationToken = default + ) + { + CheckDisposed(); + + if (cancellationToken.IsCancellationRequested) +#if SYSTEM_THREADING_TASKS_VALUETASK_FROMCANCELED + return ValueTask.FromCanceled(cancellationToken); +#else +#if SYSTEM_THREADING_TASKS_TASK_FROMCANCELED + return new(Task.FromCanceled(cancellationToken)); +#else + return new(new Task(() => default, cancellationToken)); +#endif +#endif + + if (buffer.IsEmpty) + return new(0); // do nothing + return ReadAsyncCore( destination: buffer, - offset: offset, - count: count, cancellationToken: cancellationToken ); } +#endif - private async Task ReadAsyncCore( - byte[] destination, - int offset, - int count, + private async +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + ValueTask +#else + Task +#endif + ReadAsyncCore( + Memory destination, CancellationToken cancellationToken ) { - if (count <= bufRemain) { - Buffer.BlockCopy(buffer, bufOffset, destination, offset, count); - bufOffset += count; - bufRemain -= count; + if (destination.Length <= bufRemain) { + buffer.AsSpan(bufOffset, destination.Length).CopyTo(destination.Span); + bufOffset += destination.Length; + bufRemain -= destination.Length; - return count; + return destination.Length; } var read = 0; if (bufRemain != 0) { - Buffer.BlockCopy(buffer, bufOffset, destination, offset, bufRemain); + buffer.AsSpan(bufOffset, bufRemain).CopyTo(destination.Span); read = bufRemain; - offset += bufRemain; - count -= bufRemain; + + destination = destination.Slice(bufRemain); bufRemain = 0; } // read from base stream for (; ; ) { - if (count <= 0) + if (destination.IsEmpty) break; - var r = #if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE - await stream.ReadAsync( - destination.AsMemory(offset, count), + var r = await stream.ReadAsync( + destination, + cancellationToken + ).ConfigureAwait(false); #else #pragma warning disable CA1835 - await stream.ReadAsync( - destination, - offset, - count, -#pragma warning restore CA1835 -#endif + byte[] readBuffer = null; + int r = 0; + + try { + readBuffer = ArrayPool.Shared.Rent(destination.Length); + + r = await stream.ReadAsync( + readBuffer, + 0, + destination.Length, cancellationToken ).ConfigureAwait(false); + } + finally { + if (readBuffer is not null) { + if (0 < r) + readBuffer.AsMemory(0, r).CopyTo(destination); + + ArrayPool.Shared.Return(readBuffer); + } + } +#pragma warning restore CA1835 +#endif if (r <= 0) break; - offset += r; - count -= r; + destination = destination.Slice(r); read += r; } diff --git a/tests/Smdn/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs b/tests/Smdn/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs index b802796e9..32c85fad1 100644 --- a/tests/Smdn/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs +++ b/tests/Smdn/Smdn.IO.Streams.LineOriented/LineOrientedStream.cs @@ -250,6 +250,24 @@ public void TestRead_BufferEmpty(StreamType type) Assert.AreEqual(data, buffer); } +#if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE + [TestCase(StreamType.Strict)] + [TestCase(StreamType.Loose)] + public async Task TestReadAsync_ToMemory_BufferEmpty(StreamType type) + { + var data = new byte[] {0x40, 0x41, Ascii.Octets.CR, Ascii.Octets.LF, 0x42, 0x43, 0x44, Ascii.Octets.CR, Ascii.Octets.LF, 0x45, 0x46, 0x47}; + var stream = CreateStream(type, new MemoryStream(data), 8); + + Memory buffer = new byte[12]; + + Assert.AreEqual(12L, await stream.ReadAsync(buffer)); + + Assert.AreEqual(12L, stream.Position, "Position"); + + Assert.That(buffer, Is.EqualTo(data.AsMemory()), nameof(buffer)); + } +#endif + [TestCase(StreamType.Strict)] [TestCase(StreamType.Loose)] public void TestRead_LessThanBuffered(StreamType type) @@ -272,6 +290,30 @@ public void TestRead_LessThanBuffered(StreamType type) Assert.AreEqual(data.Slice(4, 4), buffer); } +#if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE + [TestCase(StreamType.Strict)] + [TestCase(StreamType.Loose)] + public async Task TestReadAsync_ToMemory_LessThanBuffered(StreamType type) + { + var data = new byte[] {0x40, 0x41, Ascii.Octets.CR, Ascii.Octets.LF, 0x42, 0x43, 0x44, Ascii.Octets.CR, Ascii.Octets.LF, 0x45, 0x46, 0x47}; + var stream = CreateStream(type, new MemoryStream(data), 16); + + var line = stream.ReadLine(true); + + Assert.AreEqual(4L, stream.Position, "Position"); + + Assert.AreEqual(data.Slice(0, 4), line); + + Memory buffer = new byte[4]; + + Assert.AreEqual(4, await stream.ReadAsync(buffer)); + + Assert.AreEqual(8L, stream.Position, "Position"); + + Assert.That(buffer, Is.EqualTo(data.AsMemory(4, 4)), nameof(buffer)); + } +#endif + [TestCase(StreamType.Strict)] [TestCase(StreamType.Loose)] public void TestRead_LongerThanBuffered(StreamType type) @@ -294,6 +336,30 @@ public void TestRead_LongerThanBuffered(StreamType type) Assert.AreEqual(data.Slice(4, 8), buffer.Slice(0, 8)); } +#if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE + [TestCase(StreamType.Strict)] + [TestCase(StreamType.Loose)] + public async Task TestReadAsync_ToMemory_LongerThanBuffered(StreamType type) + { + var data = new byte[] {0x40, 0x41, Ascii.Octets.CR, Ascii.Octets.LF, 0x42, 0x43, 0x44, Ascii.Octets.CR, Ascii.Octets.LF, 0x45, 0x46, 0x47}; + var stream = CreateStream(type, new MemoryStream(data), 8); + + var line = stream.ReadLine(true); + + Assert.AreEqual(4L, stream.Position, "Position"); + + Assert.AreEqual(data.Slice(0, 4), line); + + Memory buffer = new byte[10]; + + Assert.AreEqual(8, await stream.ReadAsync(buffer)); + + Assert.AreEqual(12L, stream.Position, "Position"); + + Assert.That(buffer.Slice(0, 8), Is.EqualTo(data.AsMemory(4, 8)), nameof(buffer)); + } +#endif + [TestCase(StreamType.Strict)] [TestCase(StreamType.Loose)] public void TestReadToStream_LessThanBuffered(StreamType type) @@ -368,6 +434,18 @@ public void TestReadAsync_LengthZero(StreamType type) Assert.AreEqual(0L, stream.Position, "Position"); } +#if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE + [TestCase(StreamType.Strict)] + [TestCase(StreamType.Loose)] + public async Task TestReadAsync_ToMemory_LengthZero(StreamType type) + { + var stream = CreateStream(type, new MemoryStream(), 8); + + Assert.AreEqual(0L, await stream.ReadAsync(Memory.Empty)); + Assert.AreEqual(0L, stream.Position, "Position"); + } +#endif + [TestCase(StreamType.Strict)] [TestCase(StreamType.Loose)] public void TestReadAsync_CancelledToken(StreamType type) @@ -548,6 +626,9 @@ public void TestClose(StreamType type) Assert.Throws(() => stream.ReadByte()); Assert.Throws(() => stream.Read(buffer, 0, 8)); Assert.Throws(() => stream.ReadAsync(buffer, 0, 8)); +#if SYSTEM_IO_STREAM_READASYNC_MEMORY_OF_BYTE + Assert.ThrowsAsync(async () => await stream.ReadAsync(Memory.Empty)); +#endif Assert.Throws(() => stream.Read(Stream.Null, 8)); Assert.Throws(() => stream.ReadAsync(Stream.Null, 8)); Assert.Throws(() => stream.Flush());