Skip to content

Commit

Permalink
Prevent test dependencies from leaking into production (#2264)
Browse files Browse the repository at this point in the history
* Prevent test dependencies from leaking into production

* refactor & fix tests

* fix tests take two

* fix more tests

* Fix missed called to mergeDependencyFeatures

* Add test

* fix glacier compilation

* fix more tests

* fix one more test
  • Loading branch information
rcoh authored Feb 6, 2023
1 parent 7bf9251 commit c9275fb
Show file tree
Hide file tree
Showing 20 changed files with 323 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.client.Fluen
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientGenerics
import software.amazon.smithy.rust.codegen.client.smithy.generators.client.FluentClientSection
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.Feature
import software.amazon.smithy.rust.codegen.core.rustlang.GenericTypeArg
import software.amazon.smithy.rust.codegen.core.rustlang.RustGenerics
Expand Down Expand Up @@ -228,7 +227,7 @@ private class AwsFluentClientDocs(private val codegenContext: CodegenContext) :
private val serviceShape = codegenContext.serviceShape
private val crateName = codegenContext.moduleUseName()
private val codegenScope =
arrayOf("aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).copy(scope = DependencyScope.Dev).toType())
arrayOf("aws_config" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toDevDependency().toType())

// If no `aws-config` version is provided, assume that docs referencing `aws-config` cannot be given.
// Also, STS and SSO must NOT reference `aws-config` since that would create a circular dependency.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package software.amazon.smithy.rustsdk

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.DependencyScope
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeCrateLocation
Expand Down Expand Up @@ -63,7 +62,7 @@ object AwsRuntimeType {
fun awsCredentialTypes(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsCredentialTypes(runtimeConfig).toType()

fun awsCredentialTypesTestUtil(runtimeConfig: RuntimeConfig) =
AwsCargoDependency.awsCredentialTypes(runtimeConfig).copy(scope = DependencyScope.Dev).withFeature("test-util").toType()
AwsCargoDependency.awsCredentialTypes(runtimeConfig).toDevDependency().withFeature("test-util").toType()

fun awsEndpoint(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsEndpoint(runtimeConfig).toType()
fun awsHttp(runtimeConfig: RuntimeConfig) = AwsCargoDependency.awsHttp(runtimeConfig).toType()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import java.nio.file.Files
import java.nio.file.Paths
import kotlin.io.path.absolute
Expand Down Expand Up @@ -72,7 +73,7 @@ class IntegrationTestDependencies(
private val hasBenches: Boolean,
) : LibRsCustomization() {
override fun section(section: LibRsSection) = when (section) {
is LibRsSection.Body -> writable {
is LibRsSection.Body -> testDependenciesOnly {
if (hasTests) {
val smithyClient = CargoDependency.smithyClient(runtimeConfig)
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
Expand All @@ -81,7 +82,7 @@ class IntegrationTestDependencies(
addDependency(SerdeJson)
addDependency(Tokio)
addDependency(FuturesUtil)
addDependency(Tracing)
addDependency(Tracing.toDevDependency())
addDependency(TracingSubscriber)
}
if (hasBenches) {
Expand All @@ -91,6 +92,7 @@ class IntegrationTestDependencies(
serviceSpecific.section(section)(this)
}
}

else -> emptySection
}

Expand All @@ -114,8 +116,8 @@ class S3TestDependencies : LibRsCustomization() {
override fun section(section: LibRsSection): Writable =
writable {
addDependency(AsyncStd)
addDependency(BytesUtils)
addDependency(FastRand)
addDependency(BytesUtils.toDevDependency())
addDependency(FastRand.toDevDependency())
addDependency(HdrHistogram)
addDependency(Smol)
addDependency(TempFile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ private val UploadMultipartPart: ShapeId = ShapeId.from("com.amazonaws.glacier#U
private val Applies = setOf(UploadArchive, UploadMultipartPart)

class TreeHashHeader(private val runtimeConfig: RuntimeConfig) : OperationCustomization() {
private val glacierChecksums = RuntimeType.forInlineDependency(InlineAwsDependency.forRustFile("glacier_checksums"))
private val glacierChecksums = RuntimeType.forInlineDependency(
InlineAwsDependency.forRustFile(
"glacier_checksums",
additionalDependency = TreeHashDependencies.toTypedArray(),
),
)

override fun section(section: OperationSection): Writable {
return when (section) {
is OperationSection.MutateRequest -> writable {
TreeHashDependencies.forEach { dep ->
addDependency(dep)
}
rustTemplate(
"""
#{glacier_checksums}::add_checksum_treehash(
Expand All @@ -49,6 +52,7 @@ class TreeHashHeader(private val runtimeConfig: RuntimeConfig) : OperationCustom
"glacier_checksums" to glacierChecksums, "BuildError" to runtimeConfig.operationBuildError(),
)
}

else -> emptySection
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesG
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.PublicImportSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
Expand Down Expand Up @@ -146,8 +146,7 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test:
let _result = dbg!(#{invoke_operation});
#{assertion}
""",
"capture_request" to CargoDependency.smithyClient(runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
"conf" to config(testOperationInput),
"invoke_operation" to operationInvocation(testOperationInput),
"assertion" to writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest
Expand Down Expand Up @@ -96,8 +96,7 @@ class EndpointsCredentialsTest {
let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap();
assert!(auth_header.contains("/us-west-2/foobaz/aws4_request"), "{}", auth_header);
""",
"capture_request" to CargoDependency.smithyClient(context.runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(context.runtimeConfig),
"Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig)
.withFeature("test-util").toType().resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"),
Expand All @@ -120,8 +119,7 @@ class EndpointsCredentialsTest {
let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap();
assert!(auth_header.contains("/region-custom-auth/name-custom-auth/aws4_request"), "{}", auth_header);
""",
"capture_request" to CargoDependency.smithyClient(context.runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(context.runtimeConfig),
"Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig)
.withFeature("test-util").toType().resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.symbol
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand Down Expand Up @@ -59,7 +58,7 @@ val EndpointTests = RustModule.new(
documentation = "Generated endpoint tests",
parent = EndpointsModule,
inline = true,
).copy(rustMetadata = RustMetadata.TestModule)
).cfgTest()

// stdlib is isolated because it contains code generated names of stdlib functions–we want to ensure we avoid clashing
val EndpointsStdLib = RustModule.private("endpoint_lib", "Endpoints standard library functions")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.escape
Expand Down Expand Up @@ -48,8 +47,7 @@ internal class EndpointTestGenerator(
"Error" to types.resolveEndpointError,
"Document" to RuntimeType.document(runtimeConfig),
"HashMap" to RuntimeType.HashMap,
"capture_request" to CargoDependency.smithyClient(runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
)

private val instantiator = clientInstantiator(codegenContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand Down Expand Up @@ -91,14 +89,10 @@ class ProtocolTestGenerator(
if (allTests.isNotEmpty()) {
val operationName = operationSymbol.name
val testModuleName = "${operationName.toSnakeCase()}_request_test"
val moduleMeta = RustMetadata(
visibility = Visibility.PRIVATE,
additionalAttributes = listOf(
Attribute.CfgTest,
Attribute(allow("unreachable_code", "unused_variables")),
),
val additionalAttributes = listOf(
Attribute(allow("unreachable_code", "unused_variables")),
)
writer.withInlineModule(RustModule.LeafModule(testModuleName, moduleMeta, inline = true)) {
writer.withInlineModule(RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)) {
renderAllTestCases(allTests)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ data class CargoDependency(
return copy(features = features.toMutableSet().apply { add(feature) })
}

fun toDevDependency() = copy(scope = DependencyScope.Dev)

override fun version(): String = when (location) {
is CratesIo -> location.version
is Local -> "local"
Expand Down Expand Up @@ -220,7 +222,12 @@ data class CargoDependency(
val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev)
val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev)
val Tokio: CargoDependency =
CargoDependency("tokio", CratesIo("1.8.4"), DependencyScope.Dev, features = setOf("macros", "test-util", "rt-multi-thread"))
CargoDependency(
"tokio",
CratesIo("1.8.4"),
DependencyScope.Dev,
features = setOf("macros", "test-util", "rt-multi-thread"),
)
val TracingAppender: CargoDependency = CargoDependency(
"tracing-appender",
CratesIo("0.2.2"),
Expand All @@ -236,12 +243,16 @@ data class CargoDependency(
fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async")
fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums")
fun smithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-client")
fun smithyClientTestUtil(runtimeConfig: RuntimeConfig) =
smithyClient(runtimeConfig).toDevDependency().withFeature("test-util")

fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream")
fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http")
fun smithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-tower")
fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json")
fun smithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) =
runtimeConfig.smithyRuntimeCrate("smithy-protocol-test", scope = DependencyScope.Dev)

fun smithyQuery(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-query")
fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types")
fun smithyXml(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-xml")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ sealed class RustModule {
val documentation: String? = null,
val parent: RustModule = LibRs,
val inline: Boolean = false,
/* module is a cfg(test) module */
val tests: Boolean = false,
) : RustModule() {

init {
check(!name.contains("::")) {
"Module names CANNOT contain `::`—modules must be nested with parent (name was: `$name`)"
Expand All @@ -45,6 +48,12 @@ sealed class RustModule {
"Module `$name` cannot be a module name—it is a reserved word."
}
}

/** Convert a module into a module gated with `#[cfg(test)]` */
fun cfgTest(): LeafModule = this.copy(
rustMetadata = rustMetadata.copy(additionalAttributes = rustMetadata.additionalAttributes + Attribute.CfgTest),
tests = true,
)
}

companion object {
Expand Down Expand Up @@ -78,12 +87,36 @@ sealed class RustModule {
fun pubCrate(name: String, documentation: String? = null, parent: RustModule): LeafModule =
new(name, visibility = Visibility.PUBCRATE, documentation = documentation, inline = false, parent = parent)

fun inlineTests(
name: String = "test",
parent: RustModule = LibRs,
additionalAttributes: List<Attribute> = listOf(),
) = new(
name,
Visibility.PRIVATE,
inline = true,
additionalAttributes = additionalAttributes,
parent = parent,
).cfgTest()

/* Common modules used across client, server and tests */
val Config = public("config", documentation = "Configuration for the service.")
val Error = public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.")
val Model = public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.")
val Input = public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.")
val Output = public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.")
val Error = public(
"error",
documentation = "All error types that operations can return. Documentation on these types is copied from the model.",
)
val Model = public(
"model",
documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.",
)
val Input = public(
"input",
documentation = "Input structures for operations. Documentation on these types is copied from the model.",
)
val Output = public(
"output",
documentation = "Output structures for operations. Documentation on these types is copied from the model.",
)
val Types = public("types", documentation = "Data primitives referenced by other data types.")

/**
Expand Down
Loading

0 comments on commit c9275fb

Please sign in to comment.