diff --git a/ktor-http/common/src/io/ktor/http/auth/AuthScheme.kt b/ktor-http/common/src/io/ktor/http/auth/AuthScheme.kt index a850008c340..475c206d04c 100644 --- a/ktor-http/common/src/io/ktor/http/auth/AuthScheme.kt +++ b/ktor-http/common/src/io/ktor/http/auth/AuthScheme.kt @@ -49,7 +49,7 @@ public object AuthScheme { /** * Bearer Authentication described in the RFC-6749 & RFC6750: * - * see https://tools.ietf.org/html/rfc6749 + * see https://tools.ietf.org/html/rfc6749 * & https://tools.ietf.org/html/rfc6750 */ public const val Bearer: String = "Bearer" diff --git a/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt b/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt index 9e79138813c..58e346c9950 100644 --- a/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt +++ b/ktor-server/ktor-server-test-host/jvm/test/TestApplicationTestJvm.kt @@ -4,9 +4,13 @@ package io.ktor.tests.server.testing +import io.ktor.client.network.sockets.* +import io.ktor.client.plugins.* import io.ktor.client.plugins.websocket.* import io.ktor.client.request.* import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.http.content.* import io.ktor.server.application.* import io.ktor.server.config.* import io.ktor.server.response.* @@ -14,6 +18,7 @@ import io.ktor.server.routing.* import io.ktor.server.testing.* import io.ktor.server.websocket.* import io.ktor.util.* +import io.ktor.utils.io.* import io.ktor.websocket.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* @@ -283,6 +288,54 @@ class TestApplicationTestJvm { } assertEquals("WebSocket connection failed", error.message) } + + private fun testSocketTimeoutWrite(timeout: Long, expectException: Boolean) = testApplication { + routing { + post { + call.respond(HttpStatusCode.OK, call.request.receiveChannel().readRemaining().toString()) + } + } + + val clientWithTimeout = createClient { + install(HttpTimeout) { + socketTimeoutMillis = timeout + } + } + + val body = object : OutgoingContent.WriteChannelContent() { + override suspend fun writeTo(channel: ByteWriteChannel) { + channel.writeAvailable("Hello".toByteArray()) + channel.flush() + delay(300) + channel.writeAvailable("World".toByteArray()) + channel.flush() + } + } + + if (expectException) { + assertFailsWith { + clientWithTimeout.post("/") { + setBody(body) + } + } + } else { + clientWithTimeout.post("/") { + setBody(body) + }.apply { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun testSocketTimeoutWriteElapsed() { + testSocketTimeoutWrite(100, true) + } + + @Test + fun testSocketTimeoutWriteNotElapsed() { + testSocketTimeoutWrite(1000, false) + } } class TestClass(val value: Int) : Serializable diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt index 7cd9d0935a0..f65304fdfef 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationEngine.kt @@ -6,6 +6,7 @@ package io.ktor.server.testing import io.ktor.client.* import io.ktor.client.engine.* +import io.ktor.client.plugins.* import io.ktor.http.* import io.ktor.server.application.* import io.ktor.server.engine.* @@ -203,7 +204,7 @@ class TestApplicationEngine( setup: TestApplicationRequest.() -> Unit ): TestApplicationCall { val callJob = GlobalScope.async(coroutineContext) { - handleRequestNonBlocking(closeRequest, setup) + handleRequestNonBlocking(closeRequest, timeoutAttributes = null, setup) } return runBlocking { callJob.await() } @@ -211,6 +212,7 @@ class TestApplicationEngine( internal suspend fun handleRequestNonBlocking( closeRequest: Boolean = true, + timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration? = null, setup: TestApplicationRequest.() -> Unit ): TestApplicationCall { val job = Job(testEngineJob) @@ -220,6 +222,9 @@ class TestApplicationEngine( setup = { processRequest(setup) }, context = Dispatchers.IOBridge + job ) + if (timeoutAttributes != null) { + call.attributes.put(timeoutAttributesKey, timeoutAttributes) + } val context = SupervisorJob(job) + CoroutineName("request") withContext(coroutineContext + context) { @@ -306,3 +311,5 @@ fun TestApplicationEngine.cookiesSession(callback: () -> Unit) { callback() } } + +internal val timeoutAttributesKey = AttributeKey("TimeoutAttributes") diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt index bb30e11a464..508cb7a517f 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/TestApplicationResponse.kt @@ -26,6 +26,9 @@ public class TestApplicationResponse( call: TestApplicationCall, private val readResponse: Boolean = false ) : BaseApplicationResponse(call), CoroutineScope by call { + private val scope: CoroutineScope get() = this + + private val timeoutAttributes get() = call.attributes.getOrNull(timeoutAttributesKey) /** * Gets a response body text content. Could be blocking. Remains `null` until response appears. @@ -76,7 +79,6 @@ public class TestApplicationResponse( } @Suppress("DEPRECATION") - @OptIn(DelicateCoroutinesApi::class) override suspend fun responseChannel(): ByteWriteChannel { val result = ByteChannel(autoFlush = true) @@ -84,8 +86,12 @@ public class TestApplicationResponse( launchResponseJob(result) } - val job = GlobalScope.reader(responseJob ?: EmptyCoroutineContext) { - channel.copyAndClose(result, Long.MAX_VALUE) + val job = scope.reader(responseJob ?: EmptyCoroutineContext) { + val readJob = launch { + channel.copyAndClose(result, Long.MAX_VALUE) + } + + configureSocketTimeoutIfNeeded(timeoutAttributes, readJob) { channel.totalBytesRead } } if (responseJob == null) { diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt index f573e07189e..401c09d1401 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/Utils.kt @@ -4,7 +4,12 @@ package io.ktor.server.testing +import io.ktor.client.plugins.* +import io.ktor.client.request.* import io.ktor.http.* +import io.ktor.util.* +import io.ktor.utils.io.* +import kotlinx.coroutines.* /** * [on] function receiver object @@ -35,3 +40,44 @@ fun TestApplicationResponse.contentType(): ContentType { val contentTypeHeader = requireNotNull(headers[HttpHeaders.ContentType]) return ContentType.parse(contentTypeHeader) } + +internal fun CoroutineScope.configureSocketTimeoutIfNeeded( + timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration?, + job: Job, + extract: () -> Long +) { + val socketTimeoutMillis = timeoutAttributes?.socketTimeoutMillis + if (socketTimeoutMillis != null) { + socketTimeoutKiller(socketTimeoutMillis, job, extract) + } +} + +internal fun CoroutineScope.socketTimeoutKiller(socketTimeoutMillis: Long, job: Job, extract: () -> Long) { + val killJob = launch { + var cur = extract() + while (job.isActive) { + delay(socketTimeoutMillis) + val next = extract() + if (cur == next) { + throw io.ktor.network.sockets.SocketTimeoutException("Socket timeout elapsed") + } + cur = next + } + } + job.invokeOnCompletion { + killJob.cancel() + } +} + +@OptIn(InternalAPI::class) +internal fun Throwable.mapToKtor(data: HttpRequestData): Throwable { + return when { + this is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException(data, this) + cause?.rootCause is io.ktor.network.sockets.SocketTimeoutException -> SocketTimeoutException( + data, + cause?.rootCause + ) + + else -> this + } +} diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt index 872c807f3b8..225df2ab395 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/src/io/ktor/server/testing/client/TestHttpClientEngine.kt @@ -6,6 +6,7 @@ package io.ktor.server.testing.client import io.ktor.client.call.* import io.ktor.client.engine.* +import io.ktor.client.plugins.* import io.ktor.client.request.* import io.ktor.http.* import io.ktor.http.content.* @@ -41,22 +42,26 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie @OptIn(InternalAPI::class) override suspend fun execute(data: HttpRequestData): HttpResponseData { - app.start() - if (data.isUpgradeRequest()) { - val (testServerCall, session) = with(data) { - bridge.runWebSocketRequest(url.fullPath, headers, body, callContext()) + try { + app.start() + if (data.isUpgradeRequest()) { + val (testServerCall, session) = with(data) { + bridge.runWebSocketRequest(url.fullPath, headers, body, callContext()) + } + return with(testServerCall.response) { + httpResponseData(session) + } } - return with(testServerCall.response) { - httpResponseData(session) - } - } - val testServerCall = with(data) { - runRequest(method, url, headers, body, url.protocol) - } + val testServerCall = with(data) { + runRequest(method, url, headers, body, url.protocol, data.getCapabilityOrNull(HttpTimeout)) + } - return with(testServerCall.response) { - httpResponseData(ByteReadChannel(byteContent ?: byteArrayOf())) + return with(testServerCall.response) { + httpResponseData(ByteReadChannel(byteContent ?: byteArrayOf())) + } + } catch (cause: Throwable) { + throw cause.mapToKtor(data) } } @@ -65,9 +70,10 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie url: Url, headers: Headers, content: OutgoingContent, - protocol: URLProtocol + protocol: URLProtocol, + timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration? = null ): TestApplicationCall { - return app.handleRequestNonBlocking { + return app.handleRequestNonBlocking(timeoutAttributes = timeoutAttributes) { this.uri = url.fullPath this.port = url.port this.method = method @@ -75,7 +81,7 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie this.protocol = protocol.name if (content !is OutgoingContent.NoContent) { - bodyChannel = content.toByteReadChannel() + bodyChannel = content.toByteReadChannel(timeoutAttributes) } } } @@ -112,14 +118,20 @@ class TestHttpClientEngine(override val config: TestHttpClientConfig) : HttpClie } } - @Suppress("DEPRECATION") - private fun OutgoingContent.toByteReadChannel(): ByteReadChannel = when (this) { - is OutgoingContent.NoContent -> ByteReadChannel.Empty - is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes()) - is OutgoingContent.ReadChannelContent -> readFrom() - is OutgoingContent.WriteChannelContent -> writer(coroutineContext) { - writeTo(channel) - }.channel - is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this) - } + private fun OutgoingContent.toByteReadChannel( + timeoutAttributes: HttpTimeout.HttpTimeoutCapabilityConfiguration? + ): ByteReadChannel = + when (this) { + is OutgoingContent.NoContent -> ByteReadChannel.Empty + is OutgoingContent.ByteArrayContent -> ByteReadChannel(bytes()) + is OutgoingContent.ReadChannelContent -> readFrom() + is OutgoingContent.WriteChannelContent -> writer(coroutineContext) { + val job = launch { + writeTo(channel) + } + + configureSocketTimeoutIfNeeded(timeoutAttributes, job) { channel.totalBytesWritten } + }.channel + is OutgoingContent.ProtocolUpgrade -> throw UnsupportedContentTypeException(this) + } } diff --git a/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt b/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt index 8f208cd8e5c..560a34c68c5 100644 --- a/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt +++ b/ktor-server/ktor-server-test-host/jvmAndNix/test/TestApplicationTest.kt @@ -5,6 +5,7 @@ package io.ktor.tests.server.testing import io.ktor.client.* +import io.ktor.client.network.sockets.* import io.ktor.client.plugins.* import io.ktor.client.request.* import io.ktor.client.statement.* @@ -21,6 +22,7 @@ import io.ktor.server.testing.* import io.ktor.server.testing.client.* import io.ktor.util.* import io.ktor.utils.io.* +import io.ktor.utils.io.core.* import kotlinx.coroutines.* import kotlin.coroutines.* import kotlin.test.* @@ -398,6 +400,49 @@ class TestApplicationTest { } } + private fun testSocketTimeoutRead(timeout: Long, expectException: Boolean) = testApplication { + routing { + get { + call.respond( + HttpStatusCode.OK, + object : OutgoingContent.WriteChannelContent() { + override suspend fun writeTo(channel: ByteWriteChannel) { + channel.writeAvailable("Hello".toByteArray()) + channel.flush() + delay(300) + } + } + ) + } + } + + val clientWithTimeout = createClient { + install(HttpTimeout) { + socketTimeoutMillis = timeout + } + } + + if (expectException) { + assertFailsWith { + clientWithTimeout.get("/") + } + } else { + clientWithTimeout.get("/").apply { + assertEquals(HttpStatusCode.OK, status) + } + } + } + + @Test + fun testSocketTimeoutReadElapsed() { + testSocketTimeoutRead(100, true) + } + + @Test + fun testSocketTimeoutReadNotElapsed() { + testSocketTimeoutRead(1000, false) + } + class MyElement(val data: String) : CoroutineContext.Element { override val key: CoroutineContext.Key<*> get() = MyElement