diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 525487968f..93da475202 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join import software.amazon.smithy.rust.codegen.core.rustlang.rust import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope @@ -33,7 +34,7 @@ data class ConfigMethod( val docs: String, /** The parameters of the method. **/ val params: List, - /** In case the method is fallible, the error type it returns. **/ + /** In case the method is fallible, the concrete error type it returns. **/ val errorType: RuntimeType?, /** The code block inside the method. **/ val initializer: Initializer, @@ -104,15 +105,42 @@ data class Initializer( * } * * has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a - * `u64` variable. + * `u64` variable. Both are bindings that use concrete types. Types can also be generic: + * + * ```rust + * fn foo(bar: T) { } * ``` */ -data class Binding( - /** The name of the variable. */ - val name: String, - /** The type of the variable. */ - val ty: RuntimeType, -) +sealed class Binding { + data class Generic( + /** The name of the variable. The name of the type parameter will be the PascalCased variable name. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, + /** + * The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec` with `T` being a + * generic type parameter, then `genericTys` should be a singleton list containing `"T"`. + * */ + val genericTys: List + ): Binding() + + data class Concrete( + /** The name of the variable. */ + val name: String, + /** The type of the variable. */ + val ty: RuntimeType, + ): Binding() + + fun name() = when (this) { + is Concrete -> this.name + is Generic -> this.name + } + + fun ty() = when (this) { + is Concrete -> this.ty + is Generic -> this.ty + } +} class ServiceConfigGenerator( codegenContext: ServerCodegenContext, @@ -317,8 +345,10 @@ class ServiceConfigGenerator( private fun injectedMethods() = configMethods.map { writable { val paramBindings = it.params.map { binding -> - writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } + writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) } }.join("\n") + val paramBindingsGenericTys = it.params.filterIsInstance().flatMap { it.genericTys } + val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray()) // This produces a nested type like: "S>", where // - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack @@ -332,7 +362,7 @@ class ServiceConfigGenerator( rustTemplate( "#{StackType}<#{Ty}, #{Acc:W}>", "StackType" to stackType, - "Ty" to next.ty, + "Ty" to next.ty(), "Acc" to acc, ) } @@ -376,7 +406,7 @@ class ServiceConfigGenerator( docs(it.docs) rustBlockTemplate( """ - pub fn ${it.name}( + pub fn ${it.name}#{ParamBindingsGenericsWritable}( ##[allow(unused_mut)] mut self, #{ParamBindings:W} @@ -384,6 +414,7 @@ class ServiceConfigGenerator( """, "ReturnTy" to returnTy, "ParamBindings" to paramBindings, + "ParamBindingsGenericsWritable" to paramBindingsGenericsWritable, ) { rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) @@ -396,9 +427,9 @@ class ServiceConfigGenerator( } conditionalBlock("Ok(", ")", conditional = it.errorType != null) { val registrations = ( - it.initializer.layerBindings.map { ".layer(${it.name})" } + - it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + - it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" } + it.initializer.layerBindings.map { ".layer(${it.name()})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" } ).joinToString("") rust("self$registrations") } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt index c2c568b291..f1264b672f 100644 --- a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -40,8 +40,9 @@ internal class ServiceConfigGeneratorTest { name = "aws_auth", docs = "Docs", params = listOf( - Binding("auth_spec", RuntimeType.String), - Binding("authorizer", RuntimeType.U64), + Binding.Concrete("auth_spec", RuntimeType.String), + Binding.Concrete("authorizer", RuntimeType.U64), + Binding.Generic("generic_list", RuntimeType("::std::vec::Vec"), listOf("T")), ), errorType = RuntimeType.std.resolve("io::Error"), initializer = Initializer( @@ -51,8 +52,8 @@ internal class ServiceConfigGeneratorTest { if authorizer != 69 { return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); } - - if auth_spec.len() != 69 { + + if auth_spec.len() != 69 && generic_list.len() != 69 { return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); } let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; @@ -63,13 +64,13 @@ internal class ServiceConfigGeneratorTest { }, layerBindings = emptyList(), httpPluginBindings = listOf( - Binding( + Binding.Concrete( "authn_plugin", smithyHttpServer.resolve("plugin::IdentityPlugin"), ), ), modelPluginBindings = listOf( - Binding( + Binding.Concrete( "authz_plugin", smithyHttpServer.resolve("plugin::IdentityPlugin"), ), @@ -101,7 +102,7 @@ internal class ServiceConfigGeneratorTest { // One model plugin has been applied. PluginStack, > = SimpleServiceConfig::builder() - .aws_auth("a".repeat(69).to_owned(), 69) + .aws_auth("a".repeat(69).to_owned(), 69, vec![69]) .expect("failed to configure aws_auth") .build() .unwrap(); @@ -113,7 +114,7 @@ internal class ServiceConfigGeneratorTest { rust( """ let actual_err = SimpleServiceConfig::builder() - .aws_auth("a".to_owned(), 69) + .aws_auth("a".to_owned(), 69, vec![69]) .unwrap_err(); let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string(); assert_eq!(actual_err.to_string(), expected); @@ -125,7 +126,7 @@ internal class ServiceConfigGeneratorTest { rust( """ let actual_err = SimpleServiceConfig::builder() - .aws_auth("a".repeat(69).to_owned(), 6969) + .aws_auth("a".repeat(69).to_owned(), 6969, vec!["69"]) .unwrap_err(); let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string(); assert_eq!(actual_err.to_string(), expected); @@ -147,7 +148,7 @@ internal class ServiceConfigGeneratorTest { } @Test - fun `it should inject an method that applies three non-required layers`() { + fun `it should inject a method that applies three non-required layers`() { val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() val decorator = object : ServerCodegenDecorator { @@ -179,9 +180,9 @@ internal class ServiceConfigGeneratorTest { ) }, layerBindings = listOf( - Binding("layer1", identityLayer), - Binding("layer2", identityLayer), - Binding("layer3", identityLayer), + Binding.Concrete("layer1", identityLayer), + Binding.Concrete("layer2", identityLayer), + Binding.Concrete("layer3", identityLayer), ), httpPluginBindings = emptyList(), modelPluginBindings = emptyList(),