Skip to content

Commit

Permalink
Test webSocketCloseTimeout (square#8317)
Browse files Browse the repository at this point in the history
* Test webSocketCloseTimeout

* Spotless
  • Loading branch information
squarejesse authored Apr 1, 2024
1 parent 6874f11 commit 79aa6fc
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 36 deletions.
36 changes: 36 additions & 0 deletions okhttp-testing-support/src/main/kotlin/okhttp3/FailingCall.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3

import okio.Timeout

open class FailingCall : Call {
override fun request(): Request = error("unexpected")

override fun execute(): Response = error("unexpected")

override fun enqueue(responseCallback: Callback): Unit = error("unexpected")

override fun cancel(): Unit = error("unexpected")

override fun isExecuted(): Boolean = error("unexpected")

override fun isCanceled(): Boolean = error("unexpected")

override fun timeout(): Timeout = error("unexpected")

override fun clone(): Call = error("unexpected")
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class RealWebSocket(
private val key: String

/** Non-null for client web sockets. These can be canceled. */
private var call: Call? = null
internal var call: Call? = null

/** This task processes the outgoing queues. Call [runWriter] to after enqueueing. */
private var writerTask: Task? = null
Expand Down
43 changes: 20 additions & 23 deletions okhttp/src/test/java/okhttp3/KotlinSourceModernTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,24 @@ class KotlinSourceModernTest {

@Test
fun call() {
val call: Call = newCall()
val call: Call =
object : Call {
override fun request(): Request = TODO()

override fun execute(): Response = TODO()

override fun enqueue(responseCallback: Callback) = TODO()

override fun cancel() = TODO()

override fun isExecuted(): Boolean = TODO()

override fun isCanceled(): Boolean = TODO()

override fun timeout(): Timeout = TODO()

override fun clone(): Call = TODO()
}
}

@Test
Expand Down Expand Up @@ -734,7 +751,7 @@ class KotlinSourceModernTest {

@Test
fun loggingEventListener() {
var loggingEventListener: EventListener = LoggingEventListener.Factory().create(newCall())
var loggingEventListener: EventListener = LoggingEventListener.Factory().create(FailingCall())
}

@Test
Expand All @@ -745,7 +762,7 @@ class KotlinSourceModernTest {
object : LoggingEventListener.Factory() {
override fun create(call: Call): EventListener = TODO()
}
val eventListener: EventListener = factory.create(newCall())
val eventListener: EventListener = factory.create(FailingCall())
}

@Test
Expand Down Expand Up @@ -1284,26 +1301,6 @@ class KotlinSourceModernTest {
}
}

private fun newCall(): Call {
return object : Call {
override fun request(): Request = TODO()

override fun execute(): Response = TODO()

override fun enqueue(responseCallback: Callback) = TODO()

override fun cancel() = TODO()

override fun isExecuted(): Boolean = TODO()

override fun isCanceled(): Boolean = TODO()

override fun timeout(): Timeout = TODO()

override fun clone(): Call = TODO()
}
}

private fun newCookieHandler(): CookieHandler {
return object : CookieHandler() {
override fun put(
Expand Down
74 changes: 62 additions & 12 deletions okhttp/src/test/java/okhttp3/internal/ws/RealWebSocketTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import java.net.ProtocolException
import java.net.SocketTimeoutException
import java.util.Random
import kotlin.test.assertFailsWith
import okhttp3.FailingCall
import okhttp3.Headers
import okhttp3.Headers.Companion.headersOf
import okhttp3.Protocol
Expand Down Expand Up @@ -58,8 +59,8 @@ class RealWebSocketTest {

@BeforeEach
fun setUp() {
client.initWebSocket(random, 0)
server.initWebSocket(random, 0)
client.initWebSocket(random)
server.initWebSocket(random)
}

@AfterEach
Expand Down Expand Up @@ -197,6 +198,43 @@ class RealWebSocketTest {
client.listener.assertClosed(1000, "Goodbye!")
}

@Test
fun clientCloseCancelsConnectionAfterTimeout() {
client.webSocket!!.close(1000, "Hello!")
taskFaker.runTasks()
// Note: we don't process server frames so our client 'close' doesn't receive a server 'close'.
assertThat(client.canceled).isFalse()

taskFaker.advanceUntil(ns(RealWebSocket.CANCEL_AFTER_CLOSE_MILLIS - 1))
assertThat(client.canceled).isFalse()

taskFaker.advanceUntil(ns(RealWebSocket.CANCEL_AFTER_CLOSE_MILLIS))
assertThat(client.canceled).isTrue()

client.processNextFrame() // This won't get a frame, but it will get a closed pipe.
client.listener.assertFailure(IOException::class.java, "canceled")
taskFaker.runTasks()
}

@Test
fun clientCloseCancelsConnectionAfterCustomTimeout() {
client.initWebSocket(random, webSocketCloseTimeout = 5_000)
client.webSocket!!.close(1000, "Hello!")
taskFaker.runTasks()
// Note: we don't process server frames so our client 'close' doesn't receive a server 'close'.
assertThat(client.canceled).isFalse()

taskFaker.advanceUntil(ns(4_999))
assertThat(client.canceled).isFalse()

taskFaker.advanceUntil(ns(5_000))
assertThat(client.canceled).isTrue()

client.processNextFrame() // This won't get a frame, but it will get a closed pipe.
client.listener.assertFailure(IOException::class.java, "canceled")
taskFaker.runTasks()
}

@Test
fun serverCloseClosesConnection() {
server.webSocket!!.close(1000, "Hello!")
Expand Down Expand Up @@ -333,7 +371,7 @@ class RealWebSocketTest {

@Test
fun pingOnInterval() {
client.initWebSocket(random, 500)
client.initWebSocket(random, pingIntervalMillis = 500)
taskFaker.advanceUntil(ns(500L))
server.processNextFrame() // Ping.
client.processNextFrame() // Pong.
Expand All @@ -347,7 +385,7 @@ class RealWebSocketTest {

@Test
fun unacknowledgedPingFailsConnection() {
client.initWebSocket(random, 500)
client.initWebSocket(random, pingIntervalMillis = 500)

// Don't process the ping and pong frames!
taskFaker.advanceUntil(ns(500L))
Expand All @@ -360,7 +398,7 @@ class RealWebSocketTest {

@Test
fun unexpectedPongsDoNotInterfereWithFailureDetection() {
client.initWebSocket(random, 500)
client.initWebSocket(random, pingIntervalMillis = 500)

// At 0ms the server sends 3 unexpected pongs. The client accepts 'em and ignores em.
server.webSocket!!.pong("pong 1".encodeUtf8())
Expand Down Expand Up @@ -401,8 +439,8 @@ class RealWebSocketTest {
@Test
fun messagesCompressedWhenConfigured() {
val headers = headersOf("Sec-WebSocket-Extensions", "permessage-deflate")
client.initWebSocket(random, 0, headers)
server.initWebSocket(random, 0, headers)
client.initWebSocket(random, responseHeaders = headers)
server.initWebSocket(random, responseHeaders = headers)
val message = repeat('a', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt())
server.webSocket!!.send(message)
taskFaker.runTasks()
Expand All @@ -415,8 +453,8 @@ class RealWebSocketTest {
@Test
fun smallMessagesNotCompressed() {
val headers = headersOf("Sec-WebSocket-Extensions", "permessage-deflate")
client.initWebSocket(random, 0, headers)
server.initWebSocket(random, 0, headers)
client.initWebSocket(random, responseHeaders = headers)
server.initWebSocket(random, responseHeaders = headers)
val message = repeat('a', RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE.toInt() - 1)
server.webSocket!!.send(message)
taskFaker.runTasks()
Expand All @@ -437,11 +475,13 @@ class RealWebSocketTest {
val listener = WebSocketRecorder(name)
var webSocket: RealWebSocket? = null
var closed = false
var canceled = false

fun initWebSocket(
random: Random?,
pingIntervalMillis: Int,
pingIntervalMillis: Int = 0,
responseHeaders: Headers? = headersOf(),
webSocketCloseTimeout: Long = RealWebSocket.CANCEL_AFTER_CLOSE_MILLIS,
) {
val url = "http://example.com/websocket"
val response =
Expand All @@ -463,8 +503,17 @@ class RealWebSocketTest {
responseHeaders,
),
RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE,
RealWebSocket.CANCEL_AFTER_CLOSE_MILLIS,
)
webSocketCloseTimeout,
).apply {
if (client) {
call =
object : FailingCall() {
override fun cancel() {
this@TestStreams.cancel()
}
}
}
}
webSocket!!.initReaderAndWriter(name, this)
}

Expand Down Expand Up @@ -497,6 +546,7 @@ class RealWebSocketTest {
}

override fun cancel() {
canceled = true
sourcePipe.cancel()
sinkPipe.cancel()
}
Expand Down

0 comments on commit 79aa6fc

Please sign in to comment.