diff --git a/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt b/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt index d5340c6b..f88c7be6 100644 --- a/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt +++ b/conformance/google-java/src/test/kotlin/com/connectrpc/conformance/Conformance.kt @@ -34,6 +34,7 @@ import com.connectrpc.okhttp.ConnectOkHttpClient import com.connectrpc.protocols.NetworkProtocol import com.google.protobuf.ByteString import com.google.protobuf.Empty +import com.google.protobuf.empty import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking @@ -528,6 +529,179 @@ class Conformance( } } + @Test + fun emptyUnaryCallback(): Unit = runBlocking { + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.emptyCall(empty {}) { response -> + response.failure { + fail("expected error to be null") + } + response.success { success -> + assertThat(success.message).isEqualTo(empty {}) + countDownLatch.countDown() + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun largeUnaryCallback(): Unit = runBlocking { + val size = 314159 + val message = simpleRequest { + responseSize = size + payload = payload { + body = ByteString.copyFrom(ByteArray(size)) + } + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall(message) { response -> + response.failure { + fail("expected error to be null") + } + response.success { success -> + assertThat(success.message.payload?.body?.toByteArray()?.size).isEqualTo(size) + countDownLatch.countDown() + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun customMetadataCallback(): Unit = runBlocking { + val size = 314159 + val leadingKey = "x-grpc-test-echo-initial" + val leadingValue = "test_initial_metadata_value" + val trailingKey = "x-grpc-test-echo-trailing-bin" + val trailingValue = byteArrayOf(0xab.toByte(), 0xab.toByte(), 0xab.toByte()) + val headers = + mapOf( + leadingKey to listOf(leadingValue), + trailingKey to listOf(b64Encode(trailingValue)) + ) + val message = simpleRequest { + responseSize = size + payload = payload { body = ByteString.copyFrom(ByteArray(size)) } + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall(message, headers) { response -> + assertThat(response.code).isEqualTo(Code.OK) + assertThat(response.headers[leadingKey]).containsExactly(leadingValue) + assertThat(response.trailers[trailingKey]).containsExactly(b64Encode(trailingValue)) + response.failure { + fail("expected error to be null") + } + response.success { success -> + assertThat(success.message.payload!!.body!!.size()).isEqualTo(size) + countDownLatch.countDown() + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun statusCodeAndMessageCallback(): Unit = runBlocking { + val message = simpleRequest { + responseStatus = echoStatus { + code = Code.UNKNOWN.value + message = "test status message" + } + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall(message) { response -> + assertThat(response.code).isEqualTo(Code.UNKNOWN) + response.failure { errorResponse -> + assertThat(errorResponse.error).isNotNull() + assertThat(errorResponse.code).isEqualTo(Code.UNKNOWN) + assertThat(errorResponse.error.message).isEqualTo("test status message") + countDownLatch.countDown() + } + response.success { + fail("unexpected success") + } + } + + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun specialStatusCallback(): Unit = runBlocking { + val statusMessage = + "\\t\\ntest with whitespace\\r\\nand Unicode BMP ☺ and non-BMP \uD83D\uDE08\\t\\n" + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall( + simpleRequest { + responseStatus = echoStatus { + code = 2 + message = statusMessage + } + } + ) { response -> + response.failure { errorResponse -> + val error = errorResponse.error + assertThat(error.code).isEqualTo(Code.UNKNOWN) + assertThat(response.code).isEqualTo(Code.UNKNOWN) + assertThat(error.message).isEqualTo(statusMessage) + countDownLatch.countDown() + } + response.success { + fail("unexpected success") + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun unimplementedMethodCallback(): Unit = runBlocking { + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unimplementedCall(empty {}) { response -> + assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED) + countDownLatch.countDown() + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun unimplementedServiceCallback(): Unit = runBlocking { + val countDownLatch = CountDownLatch(1) + unimplementedServiceClient.unimplementedCall(empty {}) { response -> + assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED) + countDownLatch.countDown() + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + @Test + fun failUnaryCallback(): Unit = runBlocking { + val expectedErrorDetail = errorDetail { + reason = "soirée 🎉" + domain = "connect-conformance" + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.failUnaryCall(simpleRequest {}) { response -> + assertThat(response.code).isEqualTo(Code.RESOURCE_EXHAUSTED) + response.failure { errorResponse -> + val error = errorResponse.error + assertThat(error.code).isEqualTo(Code.RESOURCE_EXHAUSTED) + assertThat(error.message).isEqualTo("soirée 🎉") + val connectErrorDetails = error.unpackedDetails(ErrorDetail::class) + assertThat(connectErrorDetails).containsExactly(expectedErrorDetail) + countDownLatch.countDown() + } + response.success { + fail("unexpected success") + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + private fun b64Encode(trailingValue: ByteArray): String { return String(Base64.getEncoder().encode(trailingValue)) }