diff --git a/ktor-client/ktor-client-core/common/src/io/ktor/client/request/forms/formDsl.kt b/ktor-client/ktor-client-core/common/src/io/ktor/client/request/forms/formDsl.kt index ef14c5b5364..dbaf867391f 100644 --- a/ktor-client/ktor-client-core/common/src/io/ktor/client/request/forms/formDsl.kt +++ b/ktor-client/ktor-client-core/common/src/io/ktor/client/request/forms/formDsl.kt @@ -9,6 +9,7 @@ import io.ktor.http.content.* import io.ktor.utils.io.* import io.ktor.utils.io.core.* import kotlinx.io.* +import kotlinx.io.Buffer import kotlin.contracts.* /** @@ -49,8 +50,10 @@ public fun formData(vararg values: FormPart<*>): List { PartData.BinaryItem({ ByteReadPacket(value) }, {}, partHeaders.build()) } is Source -> { - partHeaders.append(HttpHeaders.ContentLength, value.remaining.toString()) - PartData.BinaryItem({ value.copy() }, { value.close() }, partHeaders.build()) + if (value is Buffer) { + partHeaders.append(HttpHeaders.ContentLength, value.remaining.toString()) + } + PartData.BinaryItem({ value.peek() }, { value.close() }, partHeaders.build()) } is InputProvider -> { val size = value.size diff --git a/ktor-client/ktor-client-core/common/test/MultiPartFormDataContentTest.kt b/ktor-client/ktor-client-core/common/test/MultiPartFormDataContentTest.kt index c46a0f0b75d..1eef14ec38e 100644 --- a/ktor-client/ktor-client-core/common/test/MultiPartFormDataContentTest.kt +++ b/ktor-client/ktor-client-core/common/test/MultiPartFormDataContentTest.kt @@ -3,17 +3,22 @@ */ import io.ktor.client.request.forms.* +import io.ktor.http.Headers +import io.ktor.http.HttpHeaders import io.ktor.test.dispatcher.* import io.ktor.utils.io.* import io.ktor.utils.io.charsets.* import kotlinx.coroutines.* +import kotlinx.coroutines.test.runTest import kotlinx.io.* +import kotlinx.io.files.Path +import kotlin.random.Random import kotlin.test.* class MultiPartFormDataContentTest { @Test - fun testMultiPartFormDataContentHasCorrectPrefix() = testSuspend { + fun testMultiPartFormDataContentHasCorrectPrefix() = runTest { val formData = MultiPartFormDataContent( formData { append("Hello", "World") @@ -33,7 +38,7 @@ class MultiPartFormDataContentTest { } @Test - fun testEmptyByteReadChannel() = testSuspend { + fun testEmptyByteReadChannel() = runTest { val data = MultiPartFormDataContent( formData { append("channel", ChannelProvider { ByteReadChannel.Empty }) @@ -55,7 +60,7 @@ class MultiPartFormDataContentTest { } @Test - fun testByteReadChannelWithString() = testSuspend { + fun testByteReadChannelWithString() = runTest { val content = "body" val data = MultiPartFormDataContent( formData { @@ -79,7 +84,7 @@ class MultiPartFormDataContentTest { } @Test - fun testNumberQuoted() = testSuspend { + fun testNumberQuoted() = runTest { val data = MultiPartFormDataContent( formData { append("not_a_forty_two", 1337) @@ -102,7 +107,7 @@ class MultiPartFormDataContentTest { } @Test - fun testBooleanQuoted() = testSuspend { + fun testBooleanQuoted() = runTest { val data = MultiPartFormDataContent( formData { append("is_forty_two", false) @@ -125,7 +130,7 @@ class MultiPartFormDataContentTest { } @Test - fun testStringsList() = testSuspend { + fun testStringsList() = runTest { val data = MultiPartFormDataContent( formData { append("platforms[]", listOf("windows", "linux", "osx")) @@ -158,7 +163,7 @@ class MultiPartFormDataContentTest { } @Test - fun testStringsArray() = testSuspend { + fun testStringsArray() = runTest { val data = MultiPartFormDataContent( formData { append("platforms[]", arrayOf("windows", "linux", "osx")) @@ -191,7 +196,7 @@ class MultiPartFormDataContentTest { } @Test - fun testStringsListBadKey() = testSuspend { + fun testStringsListBadKey() = runTest { val attempt = { MultiPartFormDataContent( formData { @@ -206,7 +211,7 @@ class MultiPartFormDataContentTest { } @Test - fun testByteReadChannelOverBufferSize() = testSuspend { + fun testByteReadChannelOverBufferSize() = runTest { val body = ByteArray(4089) { 'k'.code.toByte() } val data = MultiPartFormDataContent( formData { @@ -228,6 +233,36 @@ class MultiPartFormDataContentTest { ) } + @Test + fun testFileContentFromSource() = runTest { + val expected = "This content should appear in the multipart body." + val fileSource = try { + with(kotlinx.io.files.SystemFileSystem) { + val file = Path(kotlinx.io.files.SystemTemporaryDirectory, "temp${Random.nextInt(1000, 9999)}.txt") + sink(file).buffered().use { it.writeString(expected) } + source(file).buffered() + } + } catch (_: Throwable) { + // filesystem is not supported for web platforms (yet) + return@runTest + } + val data = MultiPartFormDataContent( + formData { + append( + key = "key", + value = fileSource, + headers = Headers.build { + append(HttpHeaders.ContentType, "text/plain") + append(HttpHeaders.ContentDisposition, "filename=\"file.txt\"") + }, + ) + } + ) + assertTrue("File contents should be present in the multipart body.") { + data.readString().contains(expected) + } + } + private suspend fun MultiPartFormDataContent.readString(charset: Charset = Charsets.UTF_8): String { val bytes = readBytes() return bytes.decodeToString(0, 0 + bytes.size) diff --git a/ktor-io/common/src/io/ktor/utils/io/core/ByteReadPacket.kt b/ktor-io/common/src/io/ktor/utils/io/core/ByteReadPacket.kt index 4fc4b885125..6fad093c460 100644 --- a/ktor-io/common/src/io/ktor/utils/io/core/ByteReadPacket.kt +++ b/ktor-io/common/src/io/ktor/utils/io/core/ByteReadPacket.kt @@ -49,8 +49,15 @@ public fun Source.readAvailable(out: kotlinx.io.Buffer): Int { return result.toInt() } +/** + * Returns a copy of the current buffer attached to this Source. + */ +@Deprecated( + "Use peek() or buffer.copy() instead, depending on your use case.", + ReplaceWith("peek()", "kotlinx.io.Source") +) @OptIn(InternalIoApi::class) -public fun Source.copy(): Source = buffer.copy() +public fun Source.copy(): Source = peek() @OptIn(InternalIoApi::class) public fun Source.readShortLittleEndian(): Short { diff --git a/ktor-io/jvm/src/io/ktor/utils/io/streams/Streams.kt b/ktor-io/jvm/src/io/ktor/utils/io/streams/Streams.kt index d68c8d5cd69..bc7148b9280 100644 --- a/ktor-io/jvm/src/io/ktor/utils/io/streams/Streams.kt +++ b/ktor-io/jvm/src/io/ktor/utils/io/streams/Streams.kt @@ -18,7 +18,7 @@ public fun Source.inputStream(): InputStream = asInputStream() @OptIn(InternalIoApi::class) public fun OutputStream.writePacket(packet: Source) { - packet.buffer.copyTo(this) + packet.transferTo(this.asSink()) } public fun OutputStream.writePacket(block: Sink.() -> Unit) { diff --git a/ktor-server/ktor-server-jetty-jakarta/jvm/test/io/ktor/tests/server/jetty/jakarta/JettyIdleTimeoutTest.kt b/ktor-server/ktor-server-jetty-jakarta/jvm/test/io/ktor/tests/server/jetty/jakarta/JettyIdleTimeoutTest.kt index 99cc90a6461..d1a4ee5e51b 100644 --- a/ktor-server/ktor-server-jetty-jakarta/jvm/test/io/ktor/tests/server/jetty/jakarta/JettyIdleTimeoutTest.kt +++ b/ktor-server/ktor-server-jetty-jakarta/jvm/test/io/ktor/tests/server/jetty/jakarta/JettyIdleTimeoutTest.kt @@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase repeat(repeatCount) { - getOutputStream().writePacket(request.copy()) + getOutputStream().writePacket(request.peek()) getOutputStream().write(body) getOutputStream().flush() } diff --git a/ktor-shared/ktor-websockets/jvm/src/io/ktor/websocket/internals/BytePacketUtils.kt b/ktor-shared/ktor-websockets/jvm/src/io/ktor/websocket/internals/BytePacketUtils.kt index 828476dca29..63df6a0bd41 100644 --- a/ktor-shared/ktor-websockets/jvm/src/io/ktor/websocket/internals/BytePacketUtils.kt +++ b/ktor-shared/ktor-websockets/jvm/src/io/ktor/websocket/internals/BytePacketUtils.kt @@ -7,8 +7,9 @@ package io.ktor.websocket.internals import io.ktor.utils.io.core.* import kotlinx.io.* +@OptIn(InternalIoApi::class) internal fun Source.endsWith(data: ByteArray): Boolean { - copy().apply { + buffer.copy().apply { discard(remaining - data.size) return readByteArray().contentEquals(data) } diff --git a/ktor-utils/common/src/io/ktor/util/Base64.kt b/ktor-utils/common/src/io/ktor/util/Base64.kt index 6d56fe00cf2..24177342f38 100644 --- a/ktor-utils/common/src/io/ktor/util/Base64.kt +++ b/ktor-utils/common/src/io/ktor/util/Base64.kt @@ -104,7 +104,7 @@ public fun String.decodeBase64Bytes(): ByteArray = buildPacket { public fun Source.decodeBase64Bytes(): Input = buildPacket { val data = ByteArray(4) - while (remaining > 0) { + while (!exhausted()) { val read = readAvailable(data) val chunk = data.foldIndexed(0) { index, result, current -> diff --git a/ktor-utils/common/src/io/ktor/util/ByteChannels.kt b/ktor-utils/common/src/io/ktor/util/ByteChannels.kt index 91403a67a8c..59e89c71536 100644 --- a/ktor-utils/common/src/io/ktor/util/ByteChannels.kt +++ b/ktor-utils/common/src/io/ktor/util/ByteChannels.kt @@ -5,7 +5,6 @@ package io.ktor.util import io.ktor.utils.io.* -import io.ktor.utils.io.core.* import io.ktor.utils.io.pool.* import kotlinx.coroutines.* @@ -64,8 +63,8 @@ public fun ByteReadChannel.copyToBoth(first: ByteWriteChannel, second: ByteWrite while (!isClosedForRead && (!first.isClosedForWrite || !second.isClosedForWrite)) { readRemaining(CHUNK_BUFFER_SIZE).use { try { - first.writePacket(it.copy()) - second.writePacket(it.copy()) + first.writePacket(it.peek()) + second.writePacket(it.peek()) } catch (cause: Throwable) { this@copyToBoth.cancel(cause) first.close(cause)