Skip to content

Commit

Permalink
Simplify base event stream test requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisanti committed Dec 20, 2022
1 parent 2659182 commit be77fa0
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,25 @@ import software.amazon.smithy.rust.codegen.client.testutil.clientTestRustSetting
import software.amazon.smithy.rust.codegen.client.testutil.testSymbolProvider
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.implBlock
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements

abstract class EventStreamBaseRequirements : EventStreamTestRequirements<ClientCodegenContext> {
override fun createCodegenContext(
model: Model,
symbolProvider: RustSymbolProvider,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): ClientCodegenContext = ClientCodegenContext(
model,
symbolProvider,
testSymbolProvider(model),
serviceShape,
protocolShapeId,
clientTestRustSettings(),
CombinedClientCodegenDecorator(emptyList()),
)

override fun createSymbolProvider(model: Model): RustSymbolProvider = testSymbolProvider(model)

override fun renderBuilderForShape(
writer: RustWriter,
codegenContext: ClientCodegenContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,11 @@ interface EventStreamTestRequirements<C : CodegenContext> {
/** Create a codegen context for the tests */
fun createCodegenContext(
model: Model,
symbolProvider: RustSymbolProvider,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): C

/** Create a symbol provider for the tests */
fun createSymbolProvider(model: Model): RustSymbolProvider

/** Render the event stream marshall/unmarshall code generator */
fun renderGenerator(
codegenContext: C,
Expand All @@ -86,11 +82,9 @@ object EventStreamTestTools {
variety: EventStreamTestVariety,
) {
val model = EventStreamNormalizer.transform(OperationNormalizer.transform(testCase.model))
val symbolProvider = requirements.createSymbolProvider(model)
val serviceShape = model.expectShape(ShapeId.from("test#TestService")) as ServiceShape
val codegenContext = requirements.createCodegenContext(
model,
symbolProvider,
serviceShape,
ShapeId.from(testCase.protocolShapeId),
codegenTarget,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,20 @@ private fun testServiceShapeFor(model: Model) =
fun serverTestSymbolProvider(model: Model, serviceShape: ServiceShape? = null) =
serverTestSymbolProviders(model, serviceShape).symbolProvider

fun serverTestSymbolProviders(model: Model, serviceShape: ServiceShape? = null) =
fun serverTestSymbolProviders(
model: Model,
serviceShape: ServiceShape? = null,
settings: ServerRustSettings? = null,
) =
ServerSymbolProviders.from(
model,
serviceShape ?: testServiceShapeFor(model),
ServerTestSymbolVisitorConfig,
serverTestRustSettings((serviceShape ?: testServiceShapeFor(model)).id).codegenConfig.publicConstrainedTypes,
(
settings ?: serverTestRustSettings(
(serviceShape ?: testServiceShapeFor(model)).id,
)
).codegenConfig.publicConstrainedTypes,
RustCodegenServerPlugin::baseSymbolProvider,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,25 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.testutil.EventStreamTestRequirements
import software.amazon.smithy.rust.codegen.server.smithy.RustCodegenServerPlugin
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenConfig
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerSymbolProviders
import software.amazon.smithy.rust.codegen.server.smithy.testutil.ServerTestSymbolVisitorConfig
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestRustSettings
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider

abstract class EventStreamBaseRequirements : EventStreamTestRequirements<ServerCodegenContext> {
abstract val publicConstrainedTypes: Boolean

override fun createCodegenContext(
model: Model,
symbolProvider: RustSymbolProvider,
serviceShape: ServiceShape,
protocolShapeId: ShapeId,
codegenTarget: CodegenTarget,
): ServerCodegenContext {
val settings = serverTestRustSettings()
val serverSymbolProviders = ServerSymbolProviders.from(
model,
serviceShape,
ServerTestSymbolVisitorConfig,
settings.codegenConfig.publicConstrainedTypes,
RustCodegenServerPlugin::baseSymbolProvider,
)
return ServerCodegenContext(
model,
symbolProvider,
serviceShape,
protocolShapeId,
settings,
serverSymbolProviders.unconstrainedShapeSymbolProvider,
serverSymbolProviders.constrainedShapeSymbolProvider,
serverSymbolProviders.constraintViolationSymbolProvider,
serverSymbolProviders.pubCrateConstrainedShapeSymbolProvider,
)
}

override fun createSymbolProvider(model: Model): RustSymbolProvider =
serverTestSymbolProvider(model)
): ServerCodegenContext = serverTestCodegenContext(
model, serviceShape,
serverTestRustSettings(
codegenConfig = ServerCodegenConfig(publicConstrainedTypes = publicConstrainedTypes),
),
protocolShapeId,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class EventStreamMarshallerGeneratorTest {
EventStreamTestTools.runTestCase(
testCase,
object : EventStreamBaseRequirements() {
override val publicConstrainedTypes: Boolean get() = true

override fun renderGenerator(
codegenContext: ServerCodegenContext,
project: TestEventStreamProject,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class EventStreamUnmarshallerGeneratorTest {
EventStreamTestTools.runTestCase(
testCase.eventStreamTestCase,
object : EventStreamBaseRequirements() {
override val publicConstrainedTypes: Boolean get() = testCase.publicConstrainedTypes

override fun renderGenerator(
codegenContext: ServerCodegenContext,
project: TestEventStreamProject,
Expand Down

0 comments on commit be77fa0

Please sign in to comment.