diff --git a/README.md b/README.md index 7da8bef5..c7baff31 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,13 @@ Comprehensive documentation for everything, including [interceptors][interceptors], [streaming][streaming], and [error handling][error-handling] is available on the [connect.build website][getting-started]. +## Generation Options + +| **Option** | **Type** | **Default** | **Repeatable** | **Details** | +|----------------------------|:--------:|:-----------:|:--------------:|-------------------------------------------------| +| `generateCallbackMethods` | Boolean | `false` | No | Generate callback signatures for unary methods. | +| `generateCoroutineMethods` | Boolean | `true` | No | Generate suspend signatures for unary methods. | + ## Example Apps Example apps are available in [`/examples`](./examples). First, run `make generate` to generate diff --git a/crosstests/buf.gen.yaml b/crosstests/buf.gen.yaml index 02063d57..6c2ab266 100644 --- a/crosstests/buf.gen.yaml +++ b/crosstests/buf.gen.yaml @@ -6,6 +6,9 @@ plugins: - name: connect-kotlin out: google-java/src/main/kotlin/generated path: ./protoc-gen-connect-kotlin/protoc-gen-connect-kotlin + opt: + - generateCallbackMethods=true + - generateCoroutineMethods=true - name: java out: google-java/src/main/java/generated - name: kotlin diff --git a/crosstests/common/src/main/kotlin/build/buf/connect/crosstest/ssl/TestSuite.kt b/crosstests/common/src/main/kotlin/build/buf/connect/crosstest/ssl/TestSuite.kt index 9d189aa3..06636683 100644 --- a/crosstests/common/src/main/kotlin/build/buf/connect/crosstest/ssl/TestSuite.kt +++ b/crosstests/common/src/main/kotlin/build/buf/connect/crosstest/ssl/TestSuite.kt @@ -38,3 +38,15 @@ interface TestSuite { suspend fun failUnary() suspend fun failServerStreaming() } + +interface UnaryCallbackTestSuite { + suspend fun test(tag: String) + suspend fun emptyUnary() + suspend fun largeUnary() + suspend fun customMetadata() + suspend fun statusCodeAndMessage() + suspend fun specialStatus() + suspend fun unimplementedMethod() + suspend fun unimplementedService() + suspend fun failUnary() +} diff --git a/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/Main.kt b/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/Main.kt index 9724e3ee..40848cc4 100644 --- a/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/Main.kt +++ b/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/Main.kt @@ -87,10 +87,11 @@ class Main { compressionPools = listOf(GzipCompressionPool) ) ) - tests(tag, connectClient, shortTimeoutClient) + coroutineTests(tag, connectClient, shortTimeoutClient) + callbackTests(tag, connectClient) } - private suspend fun tests( + private suspend fun coroutineTests( tag: String, protocolClient: ProtocolClient, shortTimeoutClient: ProtocolClient @@ -114,5 +115,22 @@ class Main { testServiceClientSuite.test(tag) } + + private suspend fun callbackTests( + tag: String, + protocolClient: ProtocolClient + ) { + val testServiceClientSuite = TestServiceClientCallbackSuite(protocolClient) + testServiceClientSuite.emptyUnary() + testServiceClientSuite.largeUnary() + testServiceClientSuite.customMetadata() + testServiceClientSuite.statusCodeAndMessage() + testServiceClientSuite.specialStatus() + testServiceClientSuite.unimplementedMethod() + testServiceClientSuite.unimplementedService() + testServiceClientSuite.failUnary() + + testServiceClientSuite.test(tag) + } } } diff --git a/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientCallbackSuite.kt b/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientCallbackSuite.kt new file mode 100644 index 00000000..f55ab9f6 --- /dev/null +++ b/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientCallbackSuite.kt @@ -0,0 +1,222 @@ +// Copyright 2022-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package build.buf.connect.crosstest + +import build.buf.connect.Code +import build.buf.connect.crosstest.ssl.UnaryCallbackTestSuite +import build.buf.connect.impl.ProtocolClient +import com.google.protobuf.ByteString +import com.grpc.testing.ErrorDetail +import com.grpc.testing.TestServiceClient +import com.grpc.testing.UnimplementedServiceClient +import com.grpc.testing.echoStatus +import com.grpc.testing.empty +import com.grpc.testing.errorDetail +import com.grpc.testing.payload +import com.grpc.testing.simpleRequest +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.fail +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import kotlin.system.measureTimeMillis + +class TestServiceClientCallbackSuite( + client: ProtocolClient +) : UnaryCallbackTestSuite { + private val testServiceConnectClient = TestServiceClient(client) + private val unimplementedServiceClient = UnimplementedServiceClient(client) + + private val tests = mutableListOf Unit>>() + + override suspend fun test(tag: String) { + println() + tests.forEachIndexed { index, (testName, test) -> + print("[$tag] Executing test case ${index + 1}/${tests.size}: $testName") + val millis = measureTimeMillis { + test() + } + println(" [$millis ms]") + } + } + + private fun register(testName: String, test: suspend () -> Unit) { + tests.add(testName to test) + } + + override suspend fun emptyUnary() = register("empty_unary") { + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.emptyCall(empty {}) { response -> + response.failure { + fail("expected error to be null") + } + response.success { success -> + assertThat(success.message).isEqualTo(empty {}) + countDownLatch.countDown() + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun largeUnary() = register("large_unary") { + val size = 314159 + val message = simpleRequest { + responseSize = size + payload = payload { + body = ByteString.copyFrom(ByteArray(size)) + } + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall(message) { response -> + response.failure { + fail("expected error to be null") + } + response.success { success -> + assertThat(success.message.payload?.body?.toByteArray()?.size).isEqualTo(size) + countDownLatch.countDown() + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun customMetadata() = register("custom_metadata") { + val size = 314159 + val leadingKey = "x-grpc-test-echo-initial" + val leadingValue = "test_initial_metadata_value" + val trailingKey = "x-grpc-test-echo-trailing-bin" + val trailingValue = byteArrayOf(0xab.toByte(), 0xab.toByte(), 0xab.toByte()) + val headers = + mapOf( + leadingKey to listOf(leadingValue), + trailingKey to listOf(trailingValue.b64Encode()) + ) + val message = simpleRequest { + responseSize = size + payload = payload { body = ByteString.copyFrom(ByteArray(size)) } + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall(message, headers) { response -> + assertThat(response.code).isEqualTo(Code.OK) + assertThat(response.headers[leadingKey]).containsExactly(leadingValue) + assertThat(response.trailers[trailingKey]).containsExactly(trailingValue.b64Encode()) + response.failure { + fail("expected error to be null") + } + response.success { success -> + assertThat(success.message.payload!!.body!!.size()).isEqualTo(size) + countDownLatch.countDown() + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun statusCodeAndMessage() = register("status_code_and_message") { + val message = simpleRequest { + responseStatus = echoStatus { + code = Code.UNKNOWN.value + message = "test status message" + } + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall(message) { response -> + assertThat(response.code).isEqualTo(Code.UNKNOWN) + response.failure { errorResponse -> + assertThat(errorResponse.error).isNotNull() + assertThat(errorResponse.code).isEqualTo(Code.UNKNOWN) + assertThat(errorResponse.error.message).isEqualTo("test status message") + countDownLatch.countDown() + } + response.success { + fail("unexpected success") + } + } + + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun specialStatus() = register("special_status") { + val statusMessage = + "\\t\\ntest with whitespace\\r\\nand Unicode BMP ☺ and non-BMP \uD83D\uDE08\\t\\n" + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unaryCall( + simpleRequest { + responseStatus = echoStatus { + code = 2 + message = statusMessage + } + } + ) { response -> + response.failure { errorResponse -> + val error = errorResponse.error + assertThat(error.code).isEqualTo(Code.UNKNOWN) + assertThat(response.code).isEqualTo(Code.UNKNOWN) + assertThat(error.message).isEqualTo(statusMessage) + countDownLatch.countDown() + } + response.success { + fail("unexpected success") + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun unimplementedMethod() = register("unimplemented_method") { + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.unimplementedCall(empty {}) { response -> + assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED) + countDownLatch.countDown() + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun unimplementedService() = register("unimplemented_service") { + val countDownLatch = CountDownLatch(1) + unimplementedServiceClient.unimplementedCall(empty {}) { response -> + assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED) + countDownLatch.countDown() + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } + + override suspend fun failUnary() = register("fail_unary") { + val expectedErrorDetail = errorDetail { + reason = "soirée 🎉" + domain = "connect-crosstest" + } + val countDownLatch = CountDownLatch(1) + testServiceConnectClient.failUnaryCall(simpleRequest {}) { response -> + assertThat(response.code).isEqualTo(Code.RESOURCE_EXHAUSTED) + response.failure { errorResponse -> + val error = errorResponse.error + assertThat(error.code).isEqualTo(Code.RESOURCE_EXHAUSTED) + assertThat(error.message).isEqualTo("soirée 🎉") + val connectErrorDetails = error.unpackedDetails(ErrorDetail::class) + assertThat(connectErrorDetails).containsExactly(expectedErrorDetail) + countDownLatch.countDown() + } + response.success { + fail("unexpected success") + } + } + countDownLatch.await(500, TimeUnit.MILLISECONDS) + assertThat(countDownLatch.count).isZero() + } +} diff --git a/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt b/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt index cb976c5b..2bee7ae0 100644 --- a/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt +++ b/crosstests/google-java/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt @@ -428,6 +428,6 @@ class TestServiceClientSuite( } } -private fun ByteArray.b64Encode(): String { +internal fun ByteArray.b64Encode(): String { return this.toByteString().base64() } diff --git a/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/Main.kt b/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/Main.kt index 85febc7e..c6fe3ce0 100644 --- a/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/Main.kt +++ b/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/Main.kt @@ -87,10 +87,10 @@ class Main { compressionPools = listOf(GzipCompressionPool) ) ) - tests(tag, connectClient, shortTimeoutClient) + suspendTests(tag, connectClient, shortTimeoutClient) } - private suspend fun tests( + private suspend fun suspendTests( tag: String, protocolClient: ProtocolClient, shortTimeoutClient: ProtocolClient diff --git a/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt b/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt index cb976c5b..2bee7ae0 100644 --- a/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt +++ b/crosstests/google-javalite/src/main/kotlin/build/buf/connect/crosstest/TestServiceClientSuite.kt @@ -428,6 +428,6 @@ class TestServiceClientSuite( } } -private fun ByteArray.b64Encode(): String { +internal fun ByteArray.b64Encode(): String { return this.toByteString().base64() } diff --git a/protoc-gen-connect-kotlin/buf.gen.yaml b/protoc-gen-connect-kotlin/buf.gen.yaml index fae9725d..fd60ff0b 100644 --- a/protoc-gen-connect-kotlin/buf.gen.yaml +++ b/protoc-gen-connect-kotlin/buf.gen.yaml @@ -3,5 +3,9 @@ plugins: - name: connect-kotlin out: src/test/java/ path: ./protoc-gen-connect-kotlin/protoc-gen-connect-kotlin + opt: + - generateCallbackMethods=true + - generateCoroutineMethods=true - name: java out: src/test/java/ + diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/Generator.kt b/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/Generator.kt index 3c10e6fe..91f32ab1 100644 --- a/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/Generator.kt +++ b/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/Generator.kt @@ -21,9 +21,11 @@ import build.buf.connect.ProtocolClientInterface import build.buf.connect.ResponseMessage import build.buf.connect.ServerOnlyStreamInterface import build.buf.protocgen.connect.internal.CodeGenerator +import build.buf.protocgen.connect.internal.Configuration import build.buf.protocgen.connect.internal.Plugin import build.buf.protocgen.connect.internal.getClassName import build.buf.protocgen.connect.internal.getFileJavaPackage +import build.buf.protocgen.connect.internal.parse import com.google.protobuf.Descriptors import com.google.protobuf.compiler.PluginProtos import com.squareup.kotlinpoet.ClassName @@ -31,11 +33,13 @@ import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier +import com.squareup.kotlinpoet.LambdaTypeName import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.asClassName +import com.squareup.kotlinpoet.asTypeName /* * These are constants since build.buf.connect.Headers and build.buf.connect.http.Cancelable @@ -48,9 +52,11 @@ import com.squareup.kotlinpoet.asClassName * move off of type aliases, this can be changed without user API breakage. */ private val HEADERS_CLASS_NAME = ClassName("build.buf.connect", "Headers") +private val CANCELABLE_CLASS_NAME = ClassName("build.buf.connect.http", "Cancelable") class Generator : CodeGenerator { private lateinit var descriptorSource: Plugin.DescriptorSource + private lateinit var configuration: Configuration override fun generate( request: PluginProtos.CodeGeneratorRequest, @@ -58,7 +64,7 @@ class Generator : CodeGenerator { response: Plugin.Response ) { this.descriptorSource = descriptorSource - + configuration = parse(request.parameter) for (fileName in request.fileToGenerateList) { val file = descriptorSource.findFileByName(fileName) ?: throw RuntimeException("no descriptor sources found.") if (file.services.isEmpty()) { @@ -92,7 +98,7 @@ class Generator : CodeGenerator { .addFileComment("\n") .addFileComment("Source: ${file.name}\n") // Set the file package for the generated methods. - .addType(serviceClientImplementation(file.`package`, packageName, service)) + .addType(serviceClientImplementation(packageName, service)) .build() fileSpecs.put(serviceClientImplementationClassName(packageName, service), implementationFileSpec) } @@ -153,21 +159,37 @@ class Generator : CodeGenerator { .build() functions.add(clientStreamingFunction) } else { - val unarySuspendFunction = FunSpec.builder(method.name.lowerCamelCase()) - .addModifiers(KModifier.ABSTRACT) - .addModifiers(KModifier.SUSPEND) - .addParameter("request", inputClassName) - .addParameter(headerParameterSpec) - .returns(ResponseMessage::class.asClassName().parameterizedBy(outputClassName)) - .build() - functions.add(unarySuspendFunction) + if (configuration.generateCoroutineMethods) { + val unarySuspendFunction = FunSpec.builder(method.name.lowerCamelCase()) + .addModifiers(KModifier.ABSTRACT) + .addModifiers(KModifier.SUSPEND) + .addParameter("request", inputClassName) + .addParameter(headerParameterSpec) + .returns(ResponseMessage::class.asClassName().parameterizedBy(outputClassName)) + .build() + functions.add(unarySuspendFunction) + } + + if (configuration.generateCallbackMethods) { + val callbackType = LambdaTypeName.get( + parameters = listOf(ParameterSpec("", ResponseMessage::class.asTypeName().parameterizedBy(outputClassName))), + returnType = Unit::class.java.asTypeName() + ) + val unaryCallbackFunction = FunSpec.builder(method.name.lowerCamelCase()) + .addModifiers(KModifier.ABSTRACT) + .addParameter("request", inputClassName) + .addParameter(headerParameterSpec) + .addParameter("onResult", callbackType) + .returns(CANCELABLE_CLASS_NAME) + .build() + functions.add(unaryCallbackFunction) + } } } return functions } private fun serviceClientImplementation( - packageName: String, javaPackageName: String, service: Descriptors.ServiceDescriptor ): TypeSpec { @@ -184,31 +206,25 @@ class Generator : CodeGenerator { .initializer("client") .build() ) - val functionSpecs = implementationMethods( - packageName, - service.name, - service.methods - ) + val functionSpecs = implementationMethods(service.methods) return classBuilder .addFunctions(functionSpecs) .build() } private fun implementationMethods( - packageName: String, - serviceName: String, methods: List ): List { val functions = mutableListOf() for (method in methods) { val inputClassName = classNameFromType(method.inputType) val outputClassName = classNameFromType(method.outputType) - val methodCallBlock = CodeBlock.builder() + val methodSpecCallBlock = CodeBlock.builder() .addStatement("MethodSpec(") - .addStatement("\"$packageName.$serviceName/${method.name}\",") + .addStatement("\"${method.service.fullName}/${method.name}\",") .indent() .addStatement("$inputClassName::class,") - .addStatement("$outputClassName::class,") + .addStatement("$outputClassName::class") .unindent() .addStatement("),") .build() @@ -230,7 +246,7 @@ class Generator : CodeGenerator { .addStatement("client.stream(") .indent() .addStatement("headers,") - .add(methodCallBlock) + .add(methodSpecCallBlock) .unindent() .addStatement(")") .build() @@ -251,7 +267,7 @@ class Generator : CodeGenerator { .addStatement("client.serverStream(") .indent() .addStatement("headers,") - .add(methodCallBlock) + .add(methodSpecCallBlock) .unindent() .addStatement(")") .build() @@ -272,7 +288,7 @@ class Generator : CodeGenerator { .addStatement("client.clientStream(") .indent() .addStatement("headers,") - .add(methodCallBlock) + .add(methodSpecCallBlock) .unindent() .addStatement(")") .build() @@ -280,26 +296,55 @@ class Generator : CodeGenerator { .build() functions.add(clientStreamingFunction) } else { - val unarySuspendFunction = FunSpec.builder(method.name.lowerCamelCase()) - .addModifiers(KModifier.SUSPEND) - .addModifiers(KModifier.OVERRIDE) - .addParameter("request", inputClassName) - .addParameter("headers", HEADERS_CLASS_NAME) - .returns(ResponseMessage::class.asClassName().parameterizedBy(outputClassName)) - .addStatement( - "return %L", - CodeBlock.builder() - .addStatement("client.unary(") - .indent() - .addStatement("request,") - .addStatement("headers,") - .add(methodCallBlock) - .unindent() - .addStatement(")") - .build() + if (configuration.generateCoroutineMethods) { + val unarySuspendFunction = FunSpec.builder(method.name.lowerCamelCase()) + .addModifiers(KModifier.SUSPEND) + .addModifiers(KModifier.OVERRIDE) + .addParameter("request", inputClassName) + .addParameter("headers", HEADERS_CLASS_NAME) + .returns(ResponseMessage::class.asClassName().parameterizedBy(outputClassName)) + .addStatement( + "return %L", + CodeBlock.builder() + .addStatement("client.unary(") + .indent() + .addStatement("request,") + .addStatement("headers,") + .add(methodSpecCallBlock) + .unindent() + .addStatement(")") + .build() + ) + .build() + functions.add(unarySuspendFunction) + } + if (configuration.generateCallbackMethods) { + val callbackType = LambdaTypeName.get( + parameters = listOf(ParameterSpec("", ResponseMessage::class.asTypeName().parameterizedBy(outputClassName))), + returnType = Unit::class.java.asTypeName() ) - .build() - functions.add(unarySuspendFunction) + val unaryCallbackFunction = FunSpec.builder(method.name.lowerCamelCase()) + .addModifiers(KModifier.OVERRIDE) + .addParameter("request", inputClassName) + .addParameter("headers", HEADERS_CLASS_NAME) + .addParameter("onResult", callbackType) + .returns(CANCELABLE_CLASS_NAME) + .addStatement( + "return %L", + CodeBlock.builder() + .addStatement("client.unary(") + .indent() + .addStatement("request,") + .addStatement("headers,") + .add(methodSpecCallBlock) + .addStatement("onResult") + .unindent() + .addStatement(")") + .build() + ) + .build() + functions.add(unaryCallbackFunction) + } } } return functions diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/Parameters.kt b/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/Parameters.kt new file mode 100644 index 00000000..aabb770a --- /dev/null +++ b/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/Parameters.kt @@ -0,0 +1,44 @@ +// Copyright 2022-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package build.buf.protocgen.connect.internal + +internal const val CALLBACK_SIGNATURE = "generateCallbackMethods" +internal const val COROUTINE_SIGNATURE = "generateCoroutineMethods" + +/** + * The protoc plugin configuration class representation. + */ +internal data class Configuration( + // Enable or disable callback signature generation. + val generateCallbackMethods: Boolean, + // Enable or disable coroutine signature generation. + val generateCoroutineMethods: Boolean +) + +/** + * Parse options passed as a string. + * + * Key values are parsed with `parseGeneratorParameter()`. + * The key values are expected to be in camel casing but + * will internally translate from snake casing to camel + * casing. + */ +internal fun parse(input: String): Configuration { + val parameters = parseGeneratorParameter(input) + return Configuration( + generateCallbackMethods = parameters[CALLBACK_SIGNATURE]?.toBoolean() ?: false, + generateCoroutineMethods = parameters[COROUTINE_SIGNATURE]?.toBoolean() ?: true // Defaulted to true. + ) +} diff --git a/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/ProtoHelpers.kt b/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/ProtoHelpers.kt index 9aadc5ec..8e0ffb2a 100644 --- a/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/ProtoHelpers.kt +++ b/protoc-gen-connect-kotlin/src/main/kotlin/build/buf/protocgen/connect/internal/ProtoHelpers.kt @@ -47,7 +47,7 @@ internal fun parseGeneratorParameter( if (text.isEmpty()) { return emptyMap() } - val ret: MutableMap = HashMap() + val result: MutableMap = HashMap() val parts = text.split(",".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() for (part in parts) { if (part.isEmpty()) { @@ -63,9 +63,10 @@ internal fun parseGeneratorParameter( key = part.substring(0, equalsPos) value = part.substring(equalsPos + 1) } - ret[key] = value + val normalizedKey = underscoresToCamelCaseImpl(key, false) + result[normalizedKey] = value } - return ret + return result } /** diff --git a/protoc-gen-connect-kotlin/src/test/kotlin/PluginGenerationTest.kt b/protoc-gen-connect-kotlin/src/test/kotlin/PluginGenerationTest.kt index bd34dd7e..081e2b47 100644 --- a/protoc-gen-connect-kotlin/src/test/kotlin/PluginGenerationTest.kt +++ b/protoc-gen-connect-kotlin/src/test/kotlin/PluginGenerationTest.kt @@ -17,6 +17,7 @@ import buf.javamultiplefiles.disabled.v1.DisabledEmptyServiceClient import buf.javamultiplefiles.disabled.v1.DisabledInnerMessageServiceClient import buf.javamultiplefiles.disabled.v1.DisabledServiceClient import buf.javamultiplefiles.enabled.v1.EnabledEmptyRPCRequest +import buf.javamultiplefiles.enabled.v1.EnabledEmptyRPCResponse import buf.javamultiplefiles.enabled.v1.EnabledEmptyServiceClient import buf.javamultiplefiles.enabled.v1.EnabledInnerMessageServiceClient import buf.javamultiplefiles.enabled.v1.EnabledServiceClient @@ -81,4 +82,29 @@ class PluginGenerationTest { assertThat(UnspecifiedServiceClient::class.java).isNotNull assertThat(UnspecifiedInnerMessageServiceClient::class.java).isNotNull } + + @Test + fun callbackSignature() { + val unspecifiedEmptyServiceClient = UnspecifiedEmptyServiceClient(mock { }) + val request = UnspecifiedEmptyOuterClass.UnspecifiedEmptyRPCRequest.getDefaultInstance() + unspecifiedEmptyServiceClient.unspecifiedEmptyRPC(request) { response -> + response.success { success -> + assertThat(success.message).isOfAnyClassIn(UnspecifiedEmptyOuterClass.UnspecifiedEmptyRPCResponse::class.java) + } + } + val disabledEmptyServiceClient = DisabledEmptyServiceClient(mock { }) + disabledEmptyServiceClient.disabledEmptyRPC(DisabledEmptyOuterClass.DisabledEmptyRPCRequest.getDefaultInstance()) { response -> + response.success { success -> + success.message + assertThat(success.message).isOfAnyClassIn(DisabledEmptyOuterClass.DisabledEmptyRPCResponse::class.java) + } + } + val enabledEmptyServiceClient = EnabledEmptyServiceClient(mock { }) + enabledEmptyServiceClient.enabledEmptyRPC(EnabledEmptyRPCRequest.getDefaultInstance()) { response -> + response.success { success -> + success.message + assertThat(success.message).isOfAnyClassIn(EnabledEmptyRPCResponse::class.java) + } + } + } }