Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow server decorators to inject methods on config #3111

Merged
merged 8 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class TestWriterDelegator(
}

/**
* Generate a newtest module
* Generate a new test module
*
* This should only be used in test code—the generated module name will be something like `tests_123`
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ object ServerCargoDependency {
val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13"))
val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
val ThisError: CargoDependency = CargoDependency("thiserror", CratesIo("1.0"))
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev)
val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Unconstraine
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isBuilderFallible
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator
Expand Down Expand Up @@ -591,11 +592,15 @@ open class ServerCodegenVisitor(
logger.info("[rust-server-codegen] Generating a service $shape")
val serverProtocol = protocolGeneratorFactory.protocol(codegenContext) as ServerProtocol

val configMethods = codegenDecorator.configMethods(codegenContext)
val isConfigBuilderFallible = configMethods.isBuilderFallible()

// Generate root.
rustCrate.lib {
ServerRootGenerator(
serverProtocol,
codegenContext,
isConfigBuilderFallible,
).render(this)
}

Expand All @@ -612,9 +617,10 @@ open class ServerCodegenVisitor(
ServerServiceGenerator(
codegenContext,
serverProtocol,
isConfigBuilderFallible,
).render(this)

ServiceConfigGenerator(codegenContext).render(this)
ServiceConfigGenerator(codegenContext, configMethods).render(this)

ScopeMacroGenerator(codegenContext).render(this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult
import software.amazon.smithy.rust.codegen.server.smithy.generators.ConfigMethod
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator
import java.util.logging.Logger
Expand All @@ -41,6 +42,12 @@ interface ServerCodegenDecorator : CoreCodegenDecorator<ServerCodegenContext, Se
* Therefore, ensure that all the structure shapes returned by this method are not in the service's closure.
*/
fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> = emptyList()

/**
* Configuration methods that should be injected into the `${serviceName}Config` struct to allow users to configure
* pre-applied layers and plugins.
*/
fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> = emptyList()
}

/**
Expand Down Expand Up @@ -74,10 +81,11 @@ class CombinedServerCodegenDecorator(decorators: List<ServerCodegenDecorator>) :
decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated)
}

override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> {
return orderedDecorators.map { decorator -> decorator.postprocessGenerateAdditionalStructures(operationShape) }
.flatten()
}
override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List<StructureShape> =
orderedDecorators.flatMap { it.postprocessGenerateAdditionalStructures(operationShape) }

override fun configMethods(codegenContext: ServerCodegenContext): List<ConfigMethod> =
orderedDecorators.flatMap { it.configMethods(codegenContext) }

companion object {
fun fromClasspath(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output
open class ServerRootGenerator(
val protocol: ServerProtocol,
private val codegenContext: ServerCodegenContext,
private val isConfigBuilderFallible: Boolean,
drganjoo marked this conversation as resolved.
Show resolved Hide resolved
) {
private val index = TopDownIndex.of(codegenContext.model)
private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet(
Expand All @@ -57,6 +58,8 @@ open class ServerRootGenerator(
}
.join("//!\n")

val unwrapConfigBuilder = if (isConfigBuilderFallible) ".expect(\"config failed to build\")" else ""

writer.rustTemplate(
"""
//! A fast and customizable Rust implementation of the $serviceName Smithy service.
Expand All @@ -75,7 +78,10 @@ open class ServerRootGenerator(
//! ## async fn dummy() {
//! use $crateName::{$serviceName, ${serviceName}Config};
//!
//! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked();
//! ## let app = $serviceName::builder(
//! ## ${serviceName}Config::builder()
//! ## .build()$unwrapConfigBuilder
//! ## ).build_unchecked();
//! let server = app.into_make_service();
//! let bind: SocketAddr = "127.0.0.1:6969".parse()
//! .expect("unable to parse the server bind address and port");
Expand All @@ -92,7 +98,10 @@ open class ServerRootGenerator(
//! use $crateName::$serviceName;
//!
//! ## async fn dummy() {
//! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked();
//! ## let app = $serviceName::builder(
//! ## ${serviceName}Config::builder()
//! ## .build()$unwrapConfigBuilder
//! ## ).build_unchecked();
//! let handler = LambdaHandler::new(app);
//! lambda_http::run(handler).await.unwrap();
//! ## }
Expand All @@ -118,7 +127,7 @@ open class ServerRootGenerator(
//! let http_plugins = HttpPlugins::new()
//! .push(LoggingPlugin)
//! .push(MetricsPlugin);
//! let config = ${serviceName}Config::builder().build();
//! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
//! let builder: $builderName<Body, _, _, _> = $serviceName::builder(config);
//! ```
//!
Expand Down Expand Up @@ -183,13 +192,13 @@ open class ServerRootGenerator(
//!
//! ## Example
//!
//! ```rust
//! ```rust,no_run
//! ## use std::net::SocketAddr;
//! use $crateName::{$serviceName, ${serviceName}Config};
//!
//! ##[#{Tokio}::main]
//! pub async fn main() {
//! let config = ${serviceName}Config::builder().build();
//! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
//! let app = $serviceName::builder(config)
${builderFieldNames.values.joinToString("\n") { "//! .$it($it)" }}
//! .build()
Expand Down Expand Up @@ -236,6 +245,23 @@ open class ServerRootGenerator(
fun render(rustWriter: RustWriter) {
documentation(rustWriter)

rustWriter.rust("pub use crate::service::{$serviceName, ${serviceName}Config, ${serviceName}ConfigBuilder, ${serviceName}Builder, MissingOperationsError};")
// Only export config builder error if fallible.
val configErrorReExport = if (isConfigBuilderFallible) {
"${serviceName}ConfigError,"
} else {
""
}
rustWriter.rust(
"""
pub use crate::service::{
$serviceName,
${serviceName}Config,
${serviceName}ConfigBuilder,
$configErrorReExport
${serviceName}Builder,
MissingOperationsError
};
"""
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output
class ServerServiceGenerator(
private val codegenContext: ServerCodegenContext,
private val protocol: ServerProtocol,
private val isConfigBuilderFallible: Boolean,
) {
private val runtimeConfig = codegenContext.runtimeConfig
private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType()
Expand Down Expand Up @@ -107,6 +108,11 @@ class ServerServiceGenerator(
val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///")
val handler = docHandler.docSignature()
val handlerFixed = docHandler.docFixedSignature()
val unwrapConfigBuilder = if (isConfigBuilderFallible) {
".expect(\"config failed to build\")"
drganjoo marked this conversation as resolved.
Show resolved Hide resolved
} else {
""
}
rustTemplate(
"""
/// Sets the [`$structName`](crate::operation_shape::$structName) operation.
Expand All @@ -123,7 +129,7 @@ class ServerServiceGenerator(
///
#{Handler:W}
///
/// let config = ${serviceName}Config::builder().build();
/// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
/// let app = $serviceName::builder(config)
/// .$fieldName(handler)
/// /* Set other handlers */
Expand Down Expand Up @@ -186,7 +192,7 @@ class ServerServiceGenerator(
///
#{HandlerFixed:W}
///
/// let config = ${serviceName}Config::builder().build();
/// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder;
/// let svc = #{Tower}::util::service_fn(handler);
/// let app = $serviceName::builder(config)
/// .${fieldName}_service(svc)
Expand Down
Loading
Loading