Skip to content

Commit

Permalink
Make ServerOperationRegistryGenerator protocol-agnostic (#1525)
Browse files Browse the repository at this point in the history
`ServerOperationRegistryGenerator` is the only server generator that is
currently not protocol-agnostic, in the sense that it is the only
generator that contains protocol-specific logic, as opposed to
delegating to the `Protocol` interface like other generators do.

With this change, we should be good to implement a
`RustCodegenDecorator` loaded from the classpath that implements support
for any protocol, provided the decorator provides a class implementing
the `Protocol` interface.

This commit also contains some style changes that are making `ktlint`
fail. It seems like our `ktlint` config was recently broken and some
style violations slipped through the cracks in previous commits that
touched these files.
  • Loading branch information
david-perez authored Jul 5, 2022
1 parent 12b4943 commit e751ed6
Show file tree
Hide file tree
Showing 19 changed files with 288 additions and 184 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import software.amazon.smithy.rust.codegen.smithy.BaseSymbolMetadataProvider
import software.amazon.smithy.rust.codegen.smithy.EventStreamSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.smithy.StreamingShapeMetadataProvider
import software.amazon.smithy.rust.codegen.smithy.StreamingShapeSymbolProvider
import software.amazon.smithy.rust.codegen.smithy.SymbolVisitor
import software.amazon.smithy.rust.codegen.smithy.SymbolVisitorConfig
import software.amazon.smithy.rust.codegen.smithy.customize.CombinedCodegenDecorator
import software.amazon.smithy.rust.codegen.smithy.StreamingShapeSymbolProvider
import java.util.logging.Level
import java.util.logging.Logger

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ class PythonServerCodegenVisitor(
rustCrate,
protocolGenerator,
protocolGeneratorFactory.support(),
protocolGeneratorFactory.protocol(codegenContext).httpBindingResolver,
codegenContext,
protocolGeneratorFactory.protocol(codegenContext),
codegenContext
)
.render()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol

/**
* PythonServerServiceGenerator
*
* Service generator is the main codegeneration entry point for Smithy services. Individual structures and unions are
* Service generator is the main code generation entry point for Smithy services. Individual structures and unions are
* generated in codegen visitor, but this class handles all protocol-specific code generation (i.e. operations).
*/
class PythonServerServiceGenerator(
private val rustCrate: RustCrate,
protocolGenerator: ProtocolGenerator,
protocolSupport: ProtocolSupport,
httpBindingResolver: HttpBindingResolver,
protocol: Protocol,
private val context: CoreCodegenContext,
) : ServerServiceGenerator(rustCrate, protocolGenerator, protocolSupport, httpBindingResolver, context) {
) : ServerServiceGenerator(rustCrate, protocolGenerator, protocolSupport, protocol, context) {

override fun renderCombinedErrors(writer: RustWriter, operation: OperationShape) {
PythonServerCombinedErrorGenerator(context.model, context.symbolProvider, operation).render(writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class RustCodegenServerPlugin : SmithyBuildPlugin {
CombinedCodegenDecorator.fromClasspath(context, ServerRequiredCustomizations())

// ServerCodegenVisitor is the main driver of code generation that traverses the model and generates code
logger.info("Loaded plugin to generate pure Rust bindings for the server SSDK")
logger.info("Loaded plugin to generate pure Rust bindings for the server SDK")
ServerCodegenVisitor(context, codegenDecorator).execute()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ open class ServerCodegenVisitor(
rustCrate,
protocolGenerator,
protocolGeneratorFactory.support(),
protocolGeneratorFactory.protocol(codegenContext).httpBindingResolver,
codegenContext,
protocolGeneratorFactory.protocol(codegenContext),
codegenContext
)
.render()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ open class ServerOperationHandlerGenerator(
private val model = coreCodegenContext.model
private val protocol = coreCodegenContext.protocol
private val symbolProvider = coreCodegenContext.symbolProvider
private val operationNames = operations.map { symbolProvider.toSymbol(it).name }
private val runtimeConfig = coreCodegenContext.runtimeConfig
private val codegenScope = arrayOf(
"AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
Expand All @@ -52,7 +51,7 @@ open class ServerOperationHandlerGenerator(
renderHandlerImplementations(writer, true)
}

/*
/**
* Renders the implementation of the `Handler` trait for all operations.
* Handlers are implemented for `FnOnce` function types whose signatures take in state or not.
*/
Expand Down Expand Up @@ -126,7 +125,7 @@ open class ServerOperationHandlerGenerator(
}
}

/*
/**
* Generates the trait bounds of the `Handler` trait implementation, depending on:
* - the presence of state; and
* - whether the operation is fallible or not.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.DocumentationTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
Expand All @@ -32,7 +28,7 @@ import software.amazon.smithy.rust.codegen.smithy.Inputs
import software.amazon.smithy.rust.codegen.smithy.Outputs
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.outputShape
Expand All @@ -51,12 +47,11 @@ import software.amazon.smithy.rust.codegen.util.toSnakeCase
*/
class ServerOperationRegistryGenerator(
coreCodegenContext: CoreCodegenContext,
private val httpBindingResolver: HttpBindingResolver,
private val protocol: Protocol,
private val operations: List<OperationShape>,
) {
private val crateName = coreCodegenContext.settings.moduleName
private val model = coreCodegenContext.model
private val protocol = coreCodegenContext.protocol
private val symbolProvider = coreCodegenContext.symbolProvider
private val serviceName = coreCodegenContext.serviceShape.toShapeId().name
private val operationNames = operations.map { symbolProvider.toSymbol(it).name.toSnakeCase() }
Expand Down Expand Up @@ -89,14 +84,20 @@ class ServerOperationRegistryGenerator(
}

private fun renderOperationRegistryRustDocs(writer: RustWriter) {
val inputOutputErrorsImport = if (operations.any { it.errors.isNotEmpty() }) {
"/// use $crateName::{${Inputs.namespace}, ${Outputs.namespace}, ${Errors.namespace}};"
} else {
"/// use $crateName::{${Inputs.namespace}, ${Outputs.namespace}};"
}

writer.rustTemplate(
"""
##[allow(clippy::tabs_in_doc_comments)]
/// The `${operationRegistryName}` is the place where you can register
/// The `$operationRegistryName` is the place where you can register
/// your service's operation implementations.
///
/// Use [`${operationRegistryBuilderName}`] to construct the
/// `${operationRegistryName}`. For each of the [operations] modeled in
/// Use [`$operationRegistryBuilderName`] to construct the
/// `$operationRegistryName`. For each of the [operations] modeled in
/// your Smithy service, you need to provide an implementation in the
/// form of a Rust async function or closure that takes in the
/// operation's input as their first parameter, and returns the
Expand All @@ -120,17 +121,13 @@ class ServerOperationRegistryGenerator(
///
/// ```rust
/// use std::net::SocketAddr;
${ if (operations.any { it.errors.isNotEmpty() }) {
"/// use ${crateName}::{${Inputs.namespace}, ${Outputs.namespace}, ${Errors.namespace}};"
} else {
"/// use ${crateName}::{${Inputs.namespace}, ${Outputs.namespace}};"
} }
/// use ${crateName}::operation_registry::${operationRegistryBuilderName};
$inputOutputErrorsImport
/// use $crateName::operation_registry::$operationRegistryBuilderName;
/// use #{Router};
///
/// ##[#{Tokio}::main]
/// pub async fn main() {
/// let app: Router = ${operationRegistryBuilderName}::default()
/// let app: Router = $operationRegistryBuilderName::default()
${operationNames.map { ".$it($it)" }.joinToString("\n") { it.prependIndent("/// ") }}
/// .build()
/// .expect("unable to build operation registry")
Expand Down Expand Up @@ -206,10 +203,10 @@ ${operationImplementationStubs(operations)}
Attribute.Derives(setOf(RuntimeType.Debug)).render(writer)
writer.rustTemplate(
"""
pub enum ${operationRegistryErrorName}{
pub enum $operationRegistryErrorName {
UninitializedField(&'static str)
}
impl #{Display} for ${operationRegistryErrorName}{
impl #{Display} for $operationRegistryErrorName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::UninitializedField(v) => write!(f, "{}", v),
Expand Down Expand Up @@ -263,14 +260,14 @@ ${operationImplementationStubs(operations)}
)
}

rustBlock("pub fn build(self) -> Result<$operationRegistryNameWithArguments, ${operationRegistryErrorName}>") {
rustBlock("pub fn build(self) -> Result<$operationRegistryNameWithArguments, $operationRegistryErrorName>") {
withBlock("Ok( $operationRegistryName {", "})") {
for (operationName in operationNames) {
rust(
"""
$operationName: match self.$operationName {
Some(v) => v,
None => return Err(${operationRegistryErrorName}::UninitializedField("$operationName")),
None => return Err($operationRegistryErrorName::UninitializedField("$operationName")),
},
"""
)
Expand Down Expand Up @@ -320,7 +317,11 @@ ${operationImplementationStubs(operations)}
)
}

withBlockTemplate("#{Router}::${runtimeRouterConstructor()}(vec![", "])", *codegenScope) {
withBlockTemplate(
"#{Router}::${protocol.serverRouterRuntimeConstructor()}(vec![",
"])",
*codegenScope
) {
requestSpecsVarNames.zip(operationNames).forEach { (requestSpecVarName, operationName) ->
rustTemplate(
"(#{Tower}::util::BoxCloneService::new(#{ServerOperationHandler}::operation(registry.$operationName)), $requestSpecVarName),",
Expand All @@ -337,112 +338,18 @@ ${operationImplementationStubs(operations)}
*/
private fun phantomMembers() = operationNames.mapIndexed { i, _ -> "In$i" }.joinToString(separator = ",\n")

/**
* Finds the runtime function to construct a new `Router` based on the Protocol.
*/
private fun runtimeRouterConstructor(): String =
when (protocol) {
RestJson1Trait.ID -> "new_rest_json_router"
RestXmlTrait.ID -> "new_rest_xml_router"
AwsJson1_0Trait.ID -> "new_aws_json_10_router"
AwsJson1_1Trait.ID -> "new_aws_json_11_router"
else -> TODO("Protocol $protocol not supported yet")
}

/**
* Returns a writable for the `RequestSpec` for an operation based on the service's protocol.
*/
private fun OperationShape.requestSpec(): Writable =
when (protocol) {
RestJson1Trait.ID, RestXmlTrait.ID -> restRequestSpec()
AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> awsJsonOperationName()
else -> TODO("Protocol $protocol not supported yet")
}

/**
* Returns the operation name as required by the awsJson1.x protocols.
*/
private fun OperationShape.awsJsonOperationName(): Writable {
val operationName = symbolProvider.toSymbol(this).name
return writable {
rust("""String::from("$serviceName.$operationName")""")
}
}

/**
* Generates a restJson1 or restXml specific `RequestSpec`.
*/
private fun OperationShape.restRequestSpec(): Writable {
val httpTrait = httpBindingResolver.httpTrait(this)
val extraCodegenScope =
arrayOf("RequestSpec", "UriSpec", "PathAndQuerySpec", "PathSpec", "QuerySpec", "PathSegment", "QuerySegment").map {
it to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType().member("routing::request_spec::$it")
}.toTypedArray()

// TODO(https://github.com/awslabs/smithy-rs/issues/950): Support the `endpoint` trait.
val pathSegmentsVec = writable {
withBlock("vec![", "]") {
for (segment in httpTrait.uri.segments) {
val variant = when {
segment.isGreedyLabel -> "Greedy"
segment.isLabel -> "Label"
else -> """Literal(String::from("${segment.content}"))"""
}
rustTemplate(
"#{PathSegment}::$variant,",
*extraCodegenScope
)
}
}
}

val querySegmentsVec = writable {
withBlock("vec![", "]") {
for (queryLiteral in httpTrait.uri.queryLiterals) {
val variant = if (queryLiteral.value == "") {
"""Key(String::from("${queryLiteral.key}"))"""
} else {
"""KeyValue(String::from("${queryLiteral.key}"), String::from("${queryLiteral.value}"))"""
}
rustTemplate("#{QuerySegment}::$variant,", *extraCodegenScope)
}
}
}

return writable {
rustTemplate(
"""
#{RequestSpec}::new(
#{Method}::${httpTrait.method},
#{UriSpec}::new(
#{PathAndQuerySpec}::new(
#{PathSpec}::from_vector_unchecked(#{PathSegmentsVec:W}),
#{QuerySpec}::from_vector_unchecked(#{QuerySegmentsVec:W})
)
),
)
""",
*codegenScope,
*extraCodegenScope,
"PathSegmentsVec" to pathSegmentsVec,
"QuerySegmentsVec" to querySegmentsVec,
"Method" to CargoDependency.Http.asType().member("Method"),
)
}
}

private fun operationImplementationStubs(operations: List<OperationShape>): String =
operations.joinToString("\n///\n") {
val operationDocumentation = it.getTrait<DocumentationTrait>()?.value
val ret = if (!operationDocumentation.isNullOrBlank()) {
operationDocumentation.replace("#", "##").prependIndent("/// /// ") + "\n"
} else ""
ret +
"""
"""
/// ${it.signature()} {
/// todo!()
/// }
""".trimIndent()
""".trimIndent()
}

/**
Expand All @@ -465,4 +372,14 @@ ${operationImplementationStubs(operations)}
val operationName = symbolProvider.toSymbol(this).name.toSnakeCase()
return "async fn $operationName(input: $inputT) -> $outputT"
}

/**
* Returns a writable for the `RequestSpec` for an operation based on the service's protocol.
*/
private fun OperationShape.requestSpec(): Writable = protocol.serverRouterRequestSpec(
this,
symbolProvider.toSymbol(this).name,
serviceName,
ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType().member("routing::request_spec")
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ import software.amazon.smithy.rust.codegen.smithy.CoreCodegenContext
import software.amazon.smithy.rust.codegen.smithy.RustCrate
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolGenerator
import software.amazon.smithy.rust.codegen.smithy.generators.protocol.ProtocolSupport
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol

/**
* ServerServiceGenerator
*
* Service generator is the main codegeneration entry point for Smithy services. Individual structures and unions are
* Service generator is the main code generation entry point for Smithy services. Individual structures and unions are
* generated in codegen visitor, but this class handles all protocol-specific code generation (i.e. operations).
*/
open class ServerServiceGenerator(
private val rustCrate: RustCrate,
private val protocolGenerator: ProtocolGenerator,
private val protocolSupport: ProtocolSupport,
private val httpBindingResolver: HttpBindingResolver,
private val protocol: Protocol,
private val coreCodegenContext: CoreCodegenContext,
) {
private val index = TopDownIndex.of(coreCodegenContext.model)
Expand Down Expand Up @@ -84,6 +84,6 @@ open class ServerServiceGenerator(

// Render operations registry.
private fun renderOperationRegistry(writer: RustWriter, operations: List<OperationShape>) {
ServerOperationRegistryGenerator(coreCodegenContext, httpBindingResolver, operations).render(writer)
ServerOperationRegistryGenerator(coreCodegenContext, protocol, operations).render(writer)
}
}
Loading

0 comments on commit e751ed6

Please sign in to comment.