Skip to content

Commit

Permalink
KTOR-8210 Fix copy() and Source multipart processing (#4686)
Browse files Browse the repository at this point in the history
* KTOR-8210 Use peek over copy; avoid assumptions of Source contents fully buffered

* Revert "KTOR-8210 Use peek over copy; avoid assumptions of Source contents fully buffered"

This reverts commit 37bf04a.

* KTOR-8210 Fix copy() and multipart source content processing

* fixup! KTOR-8210 Fix copy() and multipart source content processing

* Increase Jetty test timeout to resolve flakiness
  • Loading branch information
bjhham authored Feb 21, 2025
1 parent a309ce8 commit 9ba9388
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.ktor.http.content.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.io.*
import kotlinx.io.Buffer
import kotlin.contracts.*

/**
Expand Down Expand Up @@ -49,8 +50,10 @@ public fun formData(vararg values: FormPart<*>): List<PartData> {
PartData.BinaryItem({ ByteReadPacket(value) }, {}, partHeaders.build())
}
is Source -> {
partHeaders.append(HttpHeaders.ContentLength, value.remaining.toString())
PartData.BinaryItem({ value.copy() }, { value.close() }, partHeaders.build())
if (value is Buffer) {
partHeaders.append(HttpHeaders.ContentLength, value.remaining.toString())
}
PartData.BinaryItem({ value.peek() }, { value.close() }, partHeaders.build())
}
is InputProvider -> {
val size = value.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
*/

import io.ktor.client.request.forms.*
import io.ktor.http.Headers
import io.ktor.http.HttpHeaders
import io.ktor.test.dispatcher.*
import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.*
import kotlinx.coroutines.*
import kotlinx.coroutines.test.runTest
import kotlinx.io.*
import kotlinx.io.files.Path
import kotlin.random.Random
import kotlin.test.*

class MultiPartFormDataContentTest {

@Test
fun testMultiPartFormDataContentHasCorrectPrefix() = testSuspend {
fun testMultiPartFormDataContentHasCorrectPrefix() = runTest {
val formData = MultiPartFormDataContent(
formData {
append("Hello", "World")
Expand All @@ -33,7 +38,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testEmptyByteReadChannel() = testSuspend {
fun testEmptyByteReadChannel() = runTest {
val data = MultiPartFormDataContent(
formData {
append("channel", ChannelProvider { ByteReadChannel.Empty })
Expand All @@ -55,7 +60,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testByteReadChannelWithString() = testSuspend {
fun testByteReadChannelWithString() = runTest {
val content = "body"
val data = MultiPartFormDataContent(
formData {
Expand All @@ -79,7 +84,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testNumberQuoted() = testSuspend {
fun testNumberQuoted() = runTest {
val data = MultiPartFormDataContent(
formData {
append("not_a_forty_two", 1337)
Expand All @@ -102,7 +107,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testBooleanQuoted() = testSuspend {
fun testBooleanQuoted() = runTest {
val data = MultiPartFormDataContent(
formData {
append("is_forty_two", false)
Expand All @@ -125,7 +130,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testStringsList() = testSuspend {
fun testStringsList() = runTest {
val data = MultiPartFormDataContent(
formData {
append("platforms[]", listOf("windows", "linux", "osx"))
Expand Down Expand Up @@ -158,7 +163,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testStringsArray() = testSuspend {
fun testStringsArray() = runTest {
val data = MultiPartFormDataContent(
formData {
append("platforms[]", arrayOf("windows", "linux", "osx"))
Expand Down Expand Up @@ -191,7 +196,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testStringsListBadKey() = testSuspend {
fun testStringsListBadKey() = runTest {
val attempt = {
MultiPartFormDataContent(
formData {
Expand All @@ -206,7 +211,7 @@ class MultiPartFormDataContentTest {
}

@Test
fun testByteReadChannelOverBufferSize() = testSuspend {
fun testByteReadChannelOverBufferSize() = runTest {
val body = ByteArray(4089) { 'k'.code.toByte() }
val data = MultiPartFormDataContent(
formData {
Expand All @@ -228,6 +233,36 @@ class MultiPartFormDataContentTest {
)
}

@Test
fun testFileContentFromSource() = runTest {
val expected = "This content should appear in the multipart body."
val fileSource = try {
with(kotlinx.io.files.SystemFileSystem) {
val file = Path(kotlinx.io.files.SystemTemporaryDirectory, "temp${Random.nextInt(1000, 9999)}.txt")
sink(file).buffered().use { it.writeString(expected) }
source(file).buffered()
}
} catch (_: Throwable) {
// filesystem is not supported for web platforms (yet)
return@runTest
}
val data = MultiPartFormDataContent(
formData {
append(
key = "key",
value = fileSource,
headers = Headers.build {
append(HttpHeaders.ContentType, "text/plain")
append(HttpHeaders.ContentDisposition, "filename=\"file.txt\"")
},
)
}
)
assertTrue("File contents should be present in the multipart body.") {
data.readString().contains(expected)
}
}

private suspend fun MultiPartFormDataContent.readString(charset: Charset = Charsets.UTF_8): String {
val bytes = readBytes()
return bytes.decodeToString(0, 0 + bytes.size)
Expand Down
9 changes: 8 additions & 1 deletion ktor-io/common/src/io/ktor/utils/io/core/ByteReadPacket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,15 @@ public fun Source.readAvailable(out: kotlinx.io.Buffer): Int {
return result.toInt()
}

/**
* Returns a copy of the current buffer attached to this Source.
*/
@Deprecated(
"Use peek() or buffer.copy() instead, depending on your use case.",
ReplaceWith("peek()", "kotlinx.io.Source")
)
@OptIn(InternalIoApi::class)
public fun Source.copy(): Source = buffer.copy()
public fun Source.copy(): Source = peek()

@OptIn(InternalIoApi::class)
public fun Source.readShortLittleEndian(): Short {
Expand Down
2 changes: 1 addition & 1 deletion ktor-io/jvm/src/io/ktor/utils/io/streams/Streams.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public fun Source.inputStream(): InputStream = asInputStream()

@OptIn(InternalIoApi::class)
public fun OutputStream.writePacket(packet: Source) {
packet.buffer.copyTo(this)
packet.transferTo(this.asSink())
}

public fun OutputStream.writePacket(block: Sink.() -> Unit) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase<JettyApplicationEngine, JettyApplica

override fun configure(configuration: JettyApplicationEngineBase.Configuration) {
super.configure(configuration)
configuration.idleTimeout = 10.milliseconds
configuration.idleTimeout = 100.milliseconds
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class JettyIdleTimeoutTest : EngineTestBase<JettyApplicationEngine, JettyApplica

override fun configure(configuration: JettyApplicationEngineBase.Configuration) {
super.configure(configuration)
configuration.idleTimeout = 10.milliseconds
configuration.idleTimeout = 100.milliseconds
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class HighLoadHttpGenerator(
private val request = RequestResponseBuilder().apply(builder).build()

private val requestByteBuffer = ByteBuffer.allocateDirect(request.remaining.toInt())!!.apply {
request.copy().readFully(this)
request.peek().readFully(this)
clear()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ abstract class SustainabilityTestSuite<TEngine : ApplicationEngine, TConfigurati
emptyLine()
}.build().use { request ->
repeat(repeatCount) {
getOutputStream().writePacket(request.copy())
getOutputStream().writePacket(request.peek())
getOutputStream().write(body)
getOutputStream().flush()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ package io.ktor.websocket.internals
import io.ktor.utils.io.core.*
import kotlinx.io.*

@OptIn(InternalIoApi::class)
internal fun Source.endsWith(data: ByteArray): Boolean {
copy().apply {
buffer.copy().apply {
discard(remaining - data.size)
return readByteArray().contentEquals(data)
}
Expand Down
2 changes: 1 addition & 1 deletion ktor-utils/common/src/io/ktor/util/Base64.kt
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public fun String.decodeBase64Bytes(): ByteArray = buildPacket {
public fun Source.decodeBase64Bytes(): Input = buildPacket {
val data = ByteArray(4)

while (remaining > 0) {
while (!exhausted()) {
val read = readAvailable(data)

val chunk = data.foldIndexed(0) { index, result, current ->
Expand Down
5 changes: 2 additions & 3 deletions ktor-utils/common/src/io/ktor/util/ByteChannels.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.ktor.util

import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import io.ktor.utils.io.pool.*
import kotlinx.coroutines.*

Expand Down Expand Up @@ -64,8 +63,8 @@ public fun ByteReadChannel.copyToBoth(first: ByteWriteChannel, second: ByteWrite
while (!isClosedForRead && (!first.isClosedForWrite || !second.isClosedForWrite)) {
readRemaining(CHUNK_BUFFER_SIZE).use {
try {
first.writePacket(it.copy())
second.writePacket(it.copy())
first.writePacket(it.peek())
second.writePacket(it.peek())
} catch (cause: Throwable) {
this@copyToBoth.cancel(cause)
first.close(cause)
Expand Down

0 comments on commit 9ba9388

Please sign in to comment.