Skip to content

Commit

Permalink
Allow injecting methods with generic type parameters in the config ob…
Browse files Browse the repository at this point in the history
…ject

This is a follow-up to #3111. Currently, the injected methods are
limited to taking in concrete types. This PR allows for these methods to
take in generic type parameters as well.

```rust
impl<L, H, M> SimpleServiceConfigBuilder<L, H, M> {
    pub fn aws_auth<C>(config: C) {
        ...
    }
}
```
  • Loading branch information
david-perez committed Nov 30, 2023
1 parent 6420816 commit efba260
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTypeParameters
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
Expand All @@ -33,7 +34,7 @@ data class ConfigMethod(
val docs: String,
/** The parameters of the method. **/
val params: List<Binding>,
/** In case the method is fallible, the error type it returns. **/
/** In case the method is fallible, the concrete error type it returns. **/
val errorType: RuntimeType?,
/** The code block inside the method. **/
val initializer: Initializer,
Expand Down Expand Up @@ -104,15 +105,42 @@ data class Initializer(
* }
*
* has two variable bindings. The `bar` name is bound to a `String` variable and the `baz` name is bound to a
* `u64` variable.
* `u64` variable. Both are bindings that use concrete types. Types can also be generic:
*
* ```rust
* fn foo<T>(bar: T) { }
* ```
*/
data class Binding(
/** The name of the variable. */
val name: String,
/** The type of the variable. */
val ty: RuntimeType,
)
sealed class Binding {
data class Generic(
/** The name of the variable. The name of the type parameter will be the PascalCased variable name. */
val name: String,
/** The type of the variable. */
val ty: RuntimeType,
/**
* The generic type parameters contained in `ty`. For example, if `ty` renders to `Vec<T>` with `T` being a
* generic type parameter, then `genericTys` should be a singleton list containing `"T"`.
* */
val genericTys: List<String>
): Binding()

data class Concrete(
/** The name of the variable. */
val name: String,
/** The type of the variable. */
val ty: RuntimeType,
): Binding()

fun name() = when (this) {
is Concrete -> this.name
is Generic -> this.name
}

fun ty() = when (this) {
is Concrete -> this.ty
is Generic -> this.ty
}
}

class ServiceConfigGenerator(
codegenContext: ServerCodegenContext,
Expand Down Expand Up @@ -317,8 +345,10 @@ class ServiceConfigGenerator(
private fun injectedMethods() = configMethods.map {
writable {
val paramBindings = it.params.map { binding ->
writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) }
writable { rustTemplate("${binding.name()}: #{BindingTy},", "BindingTy" to binding.ty()) }
}.join("\n")
val paramBindingsGenericTys = it.params.filterIsInstance<Binding.Generic>().flatMap { it.genericTys }
val paramBindingsGenericsWritable = rustTypeParameters(*paramBindingsGenericTys.toTypedArray())

// This produces a nested type like: "S<B, S<A, T>>", where
// - "S" denotes a "stack type" with two generic type parameters: the first is the "inner" part of the stack
Expand All @@ -332,7 +362,7 @@ class ServiceConfigGenerator(
rustTemplate(
"#{StackType}<#{Ty}, #{Acc:W}>",
"StackType" to stackType,
"Ty" to next.ty,
"Ty" to next.ty(),
"Acc" to acc,
)
}
Expand Down Expand Up @@ -376,14 +406,15 @@ class ServiceConfigGenerator(
docs(it.docs)
rustBlockTemplate(
"""
pub fn ${it.name}(
pub fn ${it.name}#{ParamBindingsGenericsWritable}(
##[allow(unused_mut)]
mut self,
#{ParamBindings:W}
) -> #{ReturnTy:W}
""",
"ReturnTy" to returnTy,
"ParamBindings" to paramBindings,
"ParamBindingsGenericsWritable" to paramBindingsGenericsWritable,
) {
rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code)

Expand All @@ -396,9 +427,9 @@ class ServiceConfigGenerator(
}
conditionalBlock("Ok(", ")", conditional = it.errorType != null) {
val registrations = (
it.initializer.layerBindings.map { ".layer(${it.name})" } +
it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } +
it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" }
it.initializer.layerBindings.map { ".layer(${it.name()})" } +
it.initializer.httpPluginBindings.map { ".http_plugin(${it.name()})" } +
it.initializer.modelPluginBindings.map { ".model_plugin(${it.name()})" }
).joinToString("")
rust("self$registrations")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ internal class ServiceConfigGeneratorTest {
name = "aws_auth",
docs = "Docs",
params = listOf(
Binding("auth_spec", RuntimeType.String),
Binding("authorizer", RuntimeType.U64),
Binding.Concrete("auth_spec", RuntimeType.String),
Binding.Concrete("authorizer", RuntimeType.U64),
Binding.Generic("generic_list", RuntimeType("::std::vec::Vec<T>"), listOf("T")),
),
errorType = RuntimeType.std.resolve("io::Error"),
initializer = Initializer(
Expand All @@ -51,8 +52,8 @@ internal class ServiceConfigGeneratorTest {
if authorizer != 69 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1"));
}
if auth_spec.len() != 69 {
if auth_spec.len() != 69 && generic_list.len() != 69 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2"));
}
let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin;
Expand All @@ -63,13 +64,13 @@ internal class ServiceConfigGeneratorTest {
},
layerBindings = emptyList(),
httpPluginBindings = listOf(
Binding(
Binding.Concrete(
"authn_plugin",
smithyHttpServer.resolve("plugin::IdentityPlugin"),
),
),
modelPluginBindings = listOf(
Binding(
Binding.Concrete(
"authz_plugin",
smithyHttpServer.resolve("plugin::IdentityPlugin"),
),
Expand Down Expand Up @@ -101,7 +102,7 @@ internal class ServiceConfigGeneratorTest {
// One model plugin has been applied.
PluginStack<IdentityPlugin, IdentityPlugin>,
> = SimpleServiceConfig::builder()
.aws_auth("a".repeat(69).to_owned(), 69)
.aws_auth("a".repeat(69).to_owned(), 69, vec![69])
.expect("failed to configure aws_auth")
.build()
.unwrap();
Expand All @@ -113,7 +114,7 @@ internal class ServiceConfigGeneratorTest {
rust(
"""
let actual_err = SimpleServiceConfig::builder()
.aws_auth("a".to_owned(), 69)
.aws_auth("a".to_owned(), 69, vec![69])
.unwrap_err();
let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string();
assert_eq!(actual_err.to_string(), expected);
Expand All @@ -125,7 +126,7 @@ internal class ServiceConfigGeneratorTest {
rust(
"""
let actual_err = SimpleServiceConfig::builder()
.aws_auth("a".repeat(69).to_owned(), 6969)
.aws_auth("a".repeat(69).to_owned(), 6969, vec!["69"])
.unwrap_err();
let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string();
assert_eq!(actual_err.to_string(), expected);
Expand All @@ -147,7 +148,7 @@ internal class ServiceConfigGeneratorTest {
}

@Test
fun `it should inject an method that applies three non-required layers`() {
fun `it should inject a method that applies three non-required layers`() {
val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel()

val decorator = object : ServerCodegenDecorator {
Expand Down Expand Up @@ -179,9 +180,9 @@ internal class ServiceConfigGeneratorTest {
)
},
layerBindings = listOf(
Binding("layer1", identityLayer),
Binding("layer2", identityLayer),
Binding("layer3", identityLayer),
Binding.Concrete("layer1", identityLayer),
Binding.Concrete("layer2", identityLayer),
Binding.Concrete("layer3", identityLayer),
),
httpPluginBindings = emptyList(),
modelPluginBindings = emptyList(),
Expand Down

0 comments on commit efba260

Please sign in to comment.