diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 7f3e993e..79aa94bc 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -71,7 +71,7 @@ class ProtocolClient( ) val unaryFunc = config.createInterceptorChain() val finalRequest = unaryFunc.requestFunction(unaryRequest) - val cancelable = httpClient.unary(finalRequest) { httpResponse -> + val cancelable = httpClient.unary(finalRequest) httpClientUnary@{ httpResponse -> val finalResponse = unaryFunc.responseFunction(httpResponse) val code = finalResponse.code val exception = finalResponse.cause?.setErrorParser(serializationStrategy.errorDetailParser()) @@ -84,20 +84,31 @@ class ProtocolClient( finalResponse.trailers, ), ) - } else { - val responseCodec = serializationStrategy.codec(methodSpec.responseClass) - val responseMessage = responseCodec.deserialize( - finalResponse.message, - ) + return@httpClientUnary + } + val responseCodec = serializationStrategy.codec(methodSpec.responseClass) + val responseMessage: Output + try { + responseMessage = responseCodec.deserialize(finalResponse.message) + } catch (e: Exception) { onResult( - ResponseMessage.Success( - responseMessage, - code, + ResponseMessage.Failure( + ConnectException(code = Code.INTERNAL_ERROR, exception = e), + Code.INTERNAL_ERROR, finalResponse.headers, finalResponse.trailers, ), ) + return@httpClientUnary } + onResult( + ResponseMessage.Success( + responseMessage, + code, + finalResponse.headers, + finalResponse.trailers, + ), + ) } return cancelable } catch (e: Exception) { diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index 30299d2c..fbc96c09 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -95,13 +95,33 @@ internal class ConnectInterceptor( val responseHeaders = response.headers.filter { entry -> !entry.key.startsWith("trailer-") } val compressionPool = clientConfig.compressionPool(responseHeaders[CONTENT_ENCODING]?.first()) + val responseBody = try { + compressionPool?.decompress(response.message.buffer) ?: response.message.buffer + } catch (e: Exception) { + return@UnaryFunction HTTPResponse( + code = Code.INTERNAL_ERROR, + message = Buffer(), + headers = responseHeaders, + trailers = trailers, + cause = ConnectException( + code = Code.INTERNAL_ERROR, + errorDetailParser = serializationStrategy.errorDetailParser(), + message = e.message, + exception = e, + ), + tracingInfo = response.tracingInfo, + ) + } + val message: Buffer val (code, exception) = if (response.code != Code.OK) { - val error = parseConnectUnaryException(code = response.code, response.headers, response.message.buffer) + val error = parseConnectUnaryException(code = response.code, response.headers, responseBody) + // We've already read the response body to parse an error - don't read again. + message = Buffer() error.code to error } else { + message = responseBody response.code to null } - val message = compressionPool?.decompress(response.message.buffer) ?: response.message.buffer HTTPResponse( code = code, message = message, @@ -122,18 +142,12 @@ internal class ConnectInterceptor( mutableMapOf(CONNECT_PROTOCOL_VERSION_KEY to listOf(CONNECT_PROTOCOL_VERSION_VALUE)) requestHeaders.putAll(request.headers) if (requestCompression != null) { - requestHeaders.put( - CONNECT_STREAMING_CONTENT_ENCODING, - listOf(requestCompression.compressionPool.name()), - ) + requestHeaders[CONNECT_STREAMING_CONTENT_ENCODING] = listOf(requestCompression.compressionPool.name()) } if (requestHeaders.keys.none { it.equals(USER_AGENT, ignoreCase = true) }) { requestHeaders[USER_AGENT] = listOf("connect-kotlin/${ConnectConstants.VERSION}") } - requestHeaders.put( - CONNECT_STREAMING_ACCEPT_ENCODING, - clientConfig.compressionPools().map { entry -> entry.name() }, - ) + requestHeaders[CONNECT_STREAMING_ACCEPT_ENCODING] = clientConfig.compressionPools().map { entry -> entry.name() } request.clone( url = request.url, contentType = request.contentType, @@ -246,7 +260,7 @@ internal class ConnectInterceptor( serializationStrategy.errorDetailParser(), errorJSON, ) - } catch (e: Throwable) { + } catch (e: Exception) { return ConnectException(code, serializationStrategy.errorDetailParser(), errorJSON) } val errorDetails = parseErrorDetails(errorPayloadJSON) diff --git a/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt b/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt index 5651df42..29b4736e 100644 --- a/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt +++ b/okhttp/src/test/kotlin/com/connectrpc/okhttp/MockWebServerTests.kt @@ -17,9 +17,11 @@ package com.connectrpc.okhttp import com.connectrpc.Code import com.connectrpc.ProtocolClientConfig import com.connectrpc.RequestCompression +import com.connectrpc.SerializationStrategy import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.eliza.v1.ElizaServiceClient import com.connectrpc.eliza.v1.sayRequest +import com.connectrpc.extensions.GoogleJavaJSONStrategy import com.connectrpc.extensions.GoogleJavaProtobufStrategy import com.connectrpc.impl.ProtocolClient import com.connectrpc.protocols.NetworkProtocol @@ -31,12 +33,16 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.Rule import org.junit.Test +/** + * Tests to exercise end to end failure cases not easily verified with conformance tests. + * Over time these may be moved to conformance tests. + */ class MockWebServerTests { @get:Rule val mockWebServerRule = MockWebServerRule() @Test - fun `compressed empty failure response is parsed correctly`() = runTest { + fun `invalid compressed failure response is handled correctly`() = runTest { mockWebServerRule.server.enqueue( MockResponse().apply { addHeader("accept-encoding", "gzip") @@ -45,9 +51,66 @@ class MockWebServerTests { setResponseCode(401) }, ) + val response = createClient().say(sayRequest { sentence = "hello" }) + mockWebServerRule.server.takeRequest().apply { + assertThat(path).isEqualTo("/connectrpc.eliza.v1.ElizaService/Say") + } + assertThat(response.code).isEqualTo(Code.INTERNAL_ERROR) + } - val host = mockWebServerRule.server.url("/") + @Test + fun `invalid compressed response data is handled correctly`() = runTest { + mockWebServerRule.server.enqueue( + MockResponse().apply { + addHeader("accept-encoding", "gzip") + addHeader("content-encoding", "gzip") + setBody("this isn't gzipped") + setResponseCode(200) + }, + ) + val response = createClient().say(sayRequest { sentence = "hello" }) + mockWebServerRule.server.takeRequest().apply { + assertThat(path).isEqualTo("/connectrpc.eliza.v1.ElizaService/Say") + } + assertThat(response.code).isEqualTo(Code.INTERNAL_ERROR) + } + @Test + fun `invalid protobuf response data is handled correctly`() = runTest { + mockWebServerRule.server.enqueue( + MockResponse().apply { + addHeader("accept-encoding", "gzip") + addHeader("content-type", "application/proto") + setBody("this isn't valid protobuf") + setResponseCode(200) + }, + ) + val response = createClient().say(sayRequest { sentence = "hello" }) + mockWebServerRule.server.takeRequest().apply { + assertThat(path).isEqualTo("/connectrpc.eliza.v1.ElizaService/Say") + } + assertThat(response.code).isEqualTo(Code.INTERNAL_ERROR) + } + + @Test + fun `invalid json response data is handled correctly`() = runTest { + mockWebServerRule.server.enqueue( + MockResponse().apply { + addHeader("accept-encoding", "gzip") + addHeader("content-type", "application/json") + setBody("{ invalid json") + setResponseCode(200) + }, + ) + val response = createClient(serializationStrategy = GoogleJavaJSONStrategy()).say(sayRequest { sentence = "hello" }) + mockWebServerRule.server.takeRequest().apply { + assertThat(path).isEqualTo("/connectrpc.eliza.v1.ElizaService/Say") + } + assertThat(response.code).isEqualTo(Code.INTERNAL_ERROR) + } + + private fun createClient(serializationStrategy: SerializationStrategy = GoogleJavaProtobufStrategy()): ElizaServiceClient { + val host = mockWebServerRule.server.url("/") val protocolClient = ProtocolClient( ConnectOkHttpClient( OkHttpClient.Builder() @@ -56,19 +119,11 @@ class MockWebServerTests { ), ProtocolClientConfig( host = host.toString(), - serializationStrategy = GoogleJavaProtobufStrategy(), + serializationStrategy = serializationStrategy, networkProtocol = NetworkProtocol.CONNECT, requestCompression = RequestCompression(0, GzipCompressionPool), - compressionPools = listOf(GzipCompressionPool), ), ) - - val response = ElizaServiceClient(protocolClient).say(sayRequest { sentence = "hello" }) - - mockWebServerRule.server.takeRequest().apply { - assertThat(path).isEqualTo("/connectrpc.eliza.v1.ElizaService/Say") - } - - assertThat(response.code).isEqualTo(Code.UNKNOWN) + return ElizaServiceClient(protocolClient) } }