From 02ac6b5767e2d8643ef5d1bd4cbd17f4f907b100 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Tue, 25 May 2021 09:19:03 -0700 Subject: [PATCH 1/3] Rename JSON generators --- .../smithy/rust/codegen/smithy/protocols/AwsJson10.kt | 8 ++++---- .../smithy/rust/codegen/smithy/protocols/AwsRestJson.kt | 8 ++++---- ...JsonParserGenerator.kt => SerdeJsonParserGenerator.kt} | 2 +- ...alizerGenerator.kt => SerdeJsonSerializerGenerator.kt} | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) rename codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/{JsonParserGenerator.kt => SerdeJsonParserGenerator.kt} (98%) rename codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/{JsonSerializerGenerator.kt => SerdeJsonSerializerGenerator.kt} (97%) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt index c7408889a9..0a7de3333e 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsJson10.kt @@ -36,8 +36,8 @@ import software.amazon.smithy.rust.codegen.smithy.generators.error.errorSymbol import software.amazon.smithy.rust.codegen.smithy.generators.operationBuildError import software.amazon.smithy.rust.codegen.smithy.locatedIn import software.amazon.smithy.rust.codegen.smithy.meta -import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.JsonParserGenerator -import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonParserGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.smithy.traits.InputBodyTrait import software.amazon.smithy.rust.codegen.smithy.traits.OutputBodyTrait @@ -198,7 +198,7 @@ class BasicAwsJsonGenerator( } override fun RustWriter.body(self: String, operationShape: OperationShape): BodyMetadata { - val generator = JsonSerializerGenerator(protocolConfig) + val generator = SerdeJsonSerializerGenerator(protocolConfig) val serializer = generator.operationSerializer(operationShape) serializer?.also { sym -> rustTemplate( @@ -214,7 +214,7 @@ class BasicAwsJsonGenerator( val outputShape = operationIndex.getOutput(operationShape).get() val errorSymbol = operationShape.errorSymbol(symbolProvider) val jsonErrors = RuntimeType.awsJsonErrors(protocolConfig.runtimeConfig) - val generator = JsonParserGenerator(protocolConfig) + val generator = SerdeJsonParserGenerator(protocolConfig) fromResponseFun(implBlockWriter, operationShape) { rustBlock("if #T::is_error(&response)", jsonErrors) { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt index 85477e30f9..9444fd2769 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/AwsRestJson.kt @@ -18,8 +18,8 @@ import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolSupport -import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.JsonParserGenerator -import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.JsonSerializerGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonParserGenerator +import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.SerdeJsonSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.smithy.protocols.parsers.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer @@ -77,11 +77,11 @@ class AwsRestJsonFactory : ProtocolGeneratorFactory class RestJson(private val protocolConfig: ProtocolConfig) : Protocol { private val runtimeConfig = protocolConfig.runtimeConfig override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator { - return JsonParserGenerator(protocolConfig) + return SerdeJsonParserGenerator(protocolConfig) } override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator { - return JsonSerializerGenerator(protocolConfig) + return SerdeJsonSerializerGenerator(protocolConfig) } override fun parseGenericError(operationShape: OperationShape): RuntimeType { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonParserGenerator.kt similarity index 98% rename from codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonParserGenerator.kt rename to codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonParserGenerator.kt index 3db46dbc6a..c83cb24025 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonParserGenerator.kt @@ -21,7 +21,7 @@ import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.outputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase -class JsonParserGenerator(protocolConfig: ProtocolConfig) : StructuredDataParserGenerator { +class SerdeJsonParserGenerator(protocolConfig: ProtocolConfig) : StructuredDataParserGenerator { private val model = protocolConfig.model private val symbolProvider = protocolConfig.symbolProvider private val runtimeConfig = protocolConfig.runtimeConfig diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonSerializerGenerator.kt similarity index 97% rename from codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt rename to codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonSerializerGenerator.kt index f46ca46ba3..cc5853fd06 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/SerdeJsonSerializerGenerator.kt @@ -19,7 +19,7 @@ import software.amazon.smithy.rust.codegen.util.expectTrait import software.amazon.smithy.rust.codegen.util.inputShape import software.amazon.smithy.rust.codegen.util.toSnakeCase -class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator { +class SerdeJsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator { private val model = protocolConfig.model private val symbolProvider = protocolConfig.symbolProvider private val runtimeConfig = protocolConfig.runtimeConfig From 2b544d7826feb895eb39b25f4b3838dae8b62e76 Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Tue, 25 May 2021 09:21:01 -0700 Subject: [PATCH 2/3] Implement new JsonSerializerGenerator without document/operation support --- .../rust/codegen/rustlang/CargoDependency.kt | 2 + .../smithy/rust/codegen/rustlang/RustTypes.kt | 21 +- .../rust/codegen/rustlang/RustWriter.kt | 10 + .../rust/codegen/smithy/RuntimeTypes.kt | 12 +- .../codegen/smithy/protocols/XmlNameIndex.kt | 6 +- .../parsers/JsonSerializerGenerator.kt | 325 ++++++++++++++++++ .../smithy/traits/SyntheticInputTrait.kt | 8 +- .../transformers/OperationNormalizer.kt | 1 + .../parsers/JsonSerializerGeneratorTest.kt | 128 +++++++ 9 files changed, 490 insertions(+), 23 deletions(-) create mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt create mode 100644 codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt index 2d573abcda..0f4920919a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/CargoDependency.kt @@ -187,6 +187,8 @@ data class CargoDependency( "protocol-test-helpers", Local(runtimeConfig.relativePath), scope = DependencyScope.Dev ) + fun smithyJson(runtimeConfig: RuntimeConfig): CargoDependency = + CargoDependency("${runtimeConfig.cratePrefix}-json", Local(runtimeConfig.relativePath)) fun smithyXml(runtimeConfig: RuntimeConfig): CargoDependency = CargoDependency("${runtimeConfig.cratePrefix}-xml", Local(runtimeConfig.relativePath)) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt index 45a630be93..23ed78f8d2 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt @@ -128,22 +128,15 @@ fun RustType.render(fullyQualified: Boolean = true): String { * Option.contains(Instant) would return true. * Option.contains(Blob) would return false. */ -fun RustType.contains(t: T): Boolean { - if (t == this) { - return true - } - - return when (this) { - is RustType.Container -> this.member.contains(t) - else -> false - } +fun RustType.contains(t: T): Boolean = when (this) { + t -> true + is RustType.Container -> this.member.contains(t) + else -> false } -inline fun RustType.stripOuter(): RustType { - return when (this) { - is T -> this.member - else -> this - } +inline fun RustType.stripOuter(): RustType = when (this) { + is T -> this.member + else -> this } /** diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt index 3ae8fe9087..024cd33abd 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustWriter.kt @@ -73,6 +73,16 @@ fun T.rust( this.write(contents, *args) } +/** + * Convenience wrapper that tells Intellij that the contents of this block are Rust + */ +fun T.rustInline( + @Language("Rust", prefix = "macro_rules! foo { () => {{ ", suffix = "}}}") contents: String, + vararg args: Any +) { + this.writeInline(contents, *args) +} + /** * Sibling method to [rustBlock] that enables `#{variablename}` style templating */ diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt index 4fd2cf260a..29d9b411d9 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/RuntimeTypes.kt @@ -65,16 +65,18 @@ data class RuntimeType(val name: String?, val dependency: RustDependency?, val n namespace = "${runtimeConfig.cratePrefix}_types::retry" ) - val Default: RuntimeType = RuntimeType("Default", dependency = null, namespace = "std::default") - val From = RuntimeType("From", dependency = null, namespace = "std::convert") - val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert") val std = RuntimeType(null, dependency = null, namespace = "std") val stdfmt = std.member("fmt") - val StdError = RuntimeType("Error", dependency = null, namespace = "std::error") + + val AsRef = RuntimeType("AsRef", dependency = null, namespace = "std::convert") val ByteSlab = RuntimeType("Vec", dependency = null, namespace = "std::vec") + val Clone = std.member("clone::Clone") val Debug = stdfmt.member("Debug") + val Default: RuntimeType = RuntimeType("Default", dependency = null, namespace = "std::default") + val From = RuntimeType("From", dependency = null, namespace = "std::convert") val PartialEq = std.member("cmp::PartialEq") - val Clone = std.member("clone::Clone") + val StdError = RuntimeType("Error", dependency = null, namespace = "std::error") + val String = RuntimeType("String", dependency = null, namespace = "std::string") fun Instant(runtimeConfig: RuntimeConfig) = RuntimeType("Instant", CargoDependency.SmithyTypes(runtimeConfig), "${runtimeConfig.cratePrefix}_types") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlNameIndex.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlNameIndex.kt index 5e5447604b..d837d9a59a 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlNameIndex.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/XmlNameIndex.kt @@ -44,9 +44,9 @@ class XmlNameIndex(private val model: Model) : KnowledgeIndex { } fun operationInputShapeName(operationShape: OperationShape): String? { - val outputShape = operationShape.inputShape(model) - val rename = outputShape.getTrait()?.value - return rename ?: outputShape.expectTrait().originalId?.name + val inputShape = operationShape.inputShape(model) + val rename = inputShape.getTrait()?.value + return rename ?: inputShape.expectTrait().originalId?.name } fun memberName(member: MemberShape): String { diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt new file mode 100644 index 0000000000..a0a3aa8bc3 --- /dev/null +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt @@ -0,0 +1,325 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.rust.codegen.smithy.protocols.parsers + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.model.knowledge.HttpBinding +import software.amazon.smithy.model.knowledge.HttpBindingIndex +import software.amazon.smithy.model.shapes.BlobShape +import software.amazon.smithy.model.shapes.BooleanShape +import software.amazon.smithy.model.shapes.CollectionShape +import software.amazon.smithy.model.shapes.DocumentShape +import software.amazon.smithy.model.shapes.MapShape +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.shapes.TimestampShape +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EnumTrait +import software.amazon.smithy.model.traits.JsonNameTrait +import software.amazon.smithy.model.traits.TimestampFormatTrait.Format.EPOCH_SECONDS +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency +import software.amazon.smithy.rust.codegen.rustlang.RustType +import software.amazon.smithy.rust.codegen.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.rust +import software.amazon.smithy.rust.codegen.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate +import software.amazon.smithy.rust.codegen.rustlang.rustInline +import software.amazon.smithy.rust.codegen.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.smithy.RustSymbolProvider +import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig +import software.amazon.smithy.rust.codegen.smithy.isOptional +import software.amazon.smithy.rust.codegen.smithy.rustType +import software.amazon.smithy.rust.codegen.smithy.traits.SyntheticInputTrait +import software.amazon.smithy.rust.codegen.util.dq +import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.getTrait +import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.toPascalCase +import software.amazon.smithy.rust.codegen.util.toSnakeCase + +private data class SimpleContext( + /** Name of the JsonObjectWriter or JsonArrayWriter */ + val writerName: String, + val localName: String, + val shape: T, +) + +private data class StructContext( + /** Name of the JsonObjectWriter */ + val objectName: String, + val localName: String, + val shape: StructureShape, + val symbolProvider: RustSymbolProvider, +) { + fun member(member: MemberShape): MemberContext = + MemberContext(objectName, "$localName.${symbolProvider.toMemberName(member)}", member, structMember = true) +} + +private data class MemberContext( + /** Name of the JsonObjectWriter or JsonArrayWriter */ + val writerName: String, + val valueExpression: String, + val shape: MemberShape, + /** Whether we're working with a JsonObjectWriter (true) or JsonArrayWriter (false) */ + val structMember: Boolean, + private val givenKeyExpression: String? = null, +) { + val wireName: String = shape.getTrait()?.value ?: shape.memberName + val keyExpression: String = when (givenKeyExpression) { + null -> wireName.dq() + else -> givenKeyExpression + } + + /** Generates an expression that serializes the given [value] expression to the object/array */ + fun writeValue(w: RustWriter, jsonType: String, key: String, value: String) = when (structMember) { + true -> w.rust("$writerName.$jsonType($key, $value);") + else -> w.rust("$writerName.$jsonType($value);") + } + + /** Generates an expression that serializes the given [inner] expression to the object/array */ + fun writeInner(w: RustWriter, jsonType: String, key: String, inner: RustWriter.() -> Unit) { + w.rustInline("$writerName.$jsonType(") + if (structMember) { + w.writeInline("$key, ") + } + inner(w) + w.write(");") + } + + /** Generates a mutable declaration for serializing a new object */ + fun writeStartObject(w: RustWriter, decl: String, key: String) = when (structMember) { + true -> w.rust("let mut $decl = $writerName.start_object($key);") + else -> w.rust("let mut $decl = $writerName.start_object();") + } + + /** Generates a mutable declaration for serializing a new array */ + fun writeStartArray(w: RustWriter, decl: String, key: String) = when (structMember) { + true -> w.rust("let mut $decl = $writerName.start_array($key);") + else -> w.rust("let mut $decl = $writerName.start_array();") + } +} + +class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator { + private val model = protocolConfig.model + private val symbolProvider = protocolConfig.symbolProvider + private val runtimeConfig = protocolConfig.runtimeConfig + private val serializerError = RuntimeType.SerdeJson("error::Error") + private val smithyTypes = CargoDependency.SmithyTypes(runtimeConfig).asType() + private val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() + private val codegenScope = arrayOf( + "String" to RuntimeType.String, + "Error" to serializerError, + "SdkBody" to RuntimeType.sdkBody(runtimeConfig), + "JsonObjectWriter" to smithyJson.member("serialize::JsonObjectWriter"), + ) + private val httpIndex = HttpBindingIndex.of(model) + + override fun payloadSerializer(member: MemberShape): RuntimeType { + val target = model.expectShape(member.target, StructureShape::class.java) + val fnName = "serialize_payload_${target.id.name.toSnakeCase()}_${member.container.name.toSnakeCase()}" + return RuntimeType.forInlineFun(fnName, "operation_ser") { writer -> + writer.rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", + *codegenScope, + "target" to symbolProvider.toSymbol(target) + ) { + rust("let mut out = String::new();") + rustTemplate("let mut object = #{JsonObjectWriter}::new(&mut out);", *codegenScope) + serializeStructure(StructContext("object", "input", target, symbolProvider)) + rust("object.finish();") + rustTemplate("Ok(#{SdkBody}::from(out))", *codegenScope) + } + } + } + + override fun operationSerializer(operationShape: OperationShape): RuntimeType? { + val inputShape = operationShape.inputShape(model) + val inputShapeName = inputShape.expectTrait().originalId?.name + ?: throw CodegenException("operation must have a name if it has members") + val fnName = "serialize_operation_${inputShapeName.toSnakeCase()}" + return RuntimeType.forInlineFun(fnName, "operation_ser") { + it.rustBlockTemplate( + "pub fn $fnName(input: &#{target}) -> Result<#{SdkBody}, #{Error}>", + *codegenScope, "target" to symbolProvider.toSymbol(inputShape) + ) { + // TODO: Implement operation serialization + rust("unimplemented!()") + } + } + } + + override fun documentSerializer(): RuntimeType { + val fnName = "serialize_document" + return RuntimeType.forInlineFun(fnName, "operation_ser") { + it.rustTemplate( + // TODO: Implement document parsing + """ + pub fn $fnName(input: &#{Document}) -> Result<#{SdkBody}, #{Error}> { + unimplemented!(); + } + """, + "Document" to RuntimeType.Document(runtimeConfig), *codegenScope + ) + } + } + + private fun RustWriter.serializeStructure(context: StructContext) { + val fnName = "serialize_structure_${context.shape.id.name.toSnakeCase()}" + val structureSymbol = symbolProvider.toSymbol(context.shape) + val structureSerializer = RuntimeType.forInlineFun(fnName, "json_ser") { writer -> + writer.rustBlockTemplate( + "pub fn $fnName(${context.objectName}: &mut #{JsonObjectWriter}, input: &#{Shape})", + "Shape" to structureSymbol, + *codegenScope, + ) { + if (context.shape.members().isEmpty()) { + rust("let _ = input;") // Suppress an unused argument warning + } + for (member in context.shape.members()) { + serializeMember(context.member(member)) + } + } + } + rust("#T(${context.objectName.borrowMut()}, ${context.localName});", structureSerializer) + } + + private fun RustWriter.serializeMember(context: MemberContext) { + val target = model.expectShape(context.shape.target) + handleOptional(context) { inner -> + val key = inner.keyExpression + val value = inner.valueExpression.borrow() + when (target) { + is StringShape -> when (target.hasTrait()) { + true -> context.writeValue(this, "string", key, "$value.as_str()") + false -> context.writeValue(this, "string", key, value) + } + is BooleanShape -> context.writeValue(this, "boolean", key, value) + is NumberShape -> { + val numberType = when (symbolProvider.toSymbol(target).rustType()) { + is RustType.Float -> "Float" + is RustType.Integer -> "NegInt" + else -> throw IllegalStateException("unreachable") + } + context.writeInner(this, "number", key) { + rustInline("#T::$numberType(*${inner.valueExpression})", smithyTypes.member("Number")) + } + } + is BlobShape -> context.writeInner(this, "string_unchecked", key) { + rustInline("&#T($value)", RuntimeType.Base64Encode(runtimeConfig)) + } + is TimestampShape -> { + val timestampFormat = + httpIndex.determineTimestampFormat(context.shape, HttpBinding.Location.DOCUMENT, EPOCH_SECONDS) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + context.writeInner(this, "instant", key) { + rustInline("$value, #T", timestampFormatType) + } + } + is CollectionShape -> jsonArrayWriter(inner) { arrayName -> + serializeCollection(SimpleContext(arrayName, inner.valueExpression, target)) + } + is MapShape -> jsonObjectWriter(inner) { objectName -> + serializeMap(SimpleContext(objectName, inner.valueExpression, target)) + } + is StructureShape -> jsonObjectWriter(inner) { objectName -> + serializeStructure(StructContext(objectName, inner.valueExpression, target, symbolProvider)) + } + is UnionShape -> jsonObjectWriter(inner) { objectName -> + serializeUnion(SimpleContext(objectName, inner.valueExpression, target)) + } + is DocumentShape -> { + // TODO: Implement document shapes + } + else -> TODO(target.toString()) + } + } + } + + private fun RustWriter.jsonArrayWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) { + safeName("array").also { arrayName -> + context.writeStartArray(this, arrayName, context.keyExpression) + inner(arrayName) + rust("$arrayName.finish();") + } + } + + private fun RustWriter.jsonObjectWriter(context: MemberContext, inner: RustWriter.(String) -> Unit) { + safeName("object").also { objectName -> + context.writeStartObject(this, objectName, context.keyExpression) + inner(objectName) + rust("$objectName.finish();") + } + } + + private fun RustWriter.serializeCollection(context: SimpleContext) { + val itemName = safeName("item") + rustBlock("for $itemName in ${context.localName}") { + serializeMember(MemberContext(context.writerName, itemName, context.shape.member, structMember = false)) + } + } + + private fun RustWriter.serializeMap(context: SimpleContext) { + val keyName = safeName("key") + val valueName = safeName("value") + val valueShape = context.shape.value + rustBlock("for ($keyName, $valueName) in ${context.localName}") { + serializeMember( + MemberContext( + context.writerName, + valueName, + valueShape, + structMember = true, + givenKeyExpression = keyName + ) + ) + } + } + + private fun RustWriter.serializeUnion(context: SimpleContext) { + val fnName = "serialize_union_${context.shape.id.name.toSnakeCase()}" + val unionSymbol = symbolProvider.toSymbol(context.shape) + val unionSerializer = RuntimeType.forInlineFun(fnName, "json_ser") { writer -> + writer.rustBlockTemplate( + "pub fn $fnName(${context.writerName}: &mut #{JsonObjectWriter}, input: &#{Shape})", + "Shape" to unionSymbol, + *codegenScope, + ) { + rustBlock("match input") { + for (member in context.shape.members()) { + val variantName = member.memberName.toPascalCase() + withBlock("#T::$variantName(inner) => {", "},", unionSymbol) { + serializeMember(MemberContext(context.writerName, "inner", member, true)) + } + } + } + } + } + rust("#T(${context.writerName.borrowMut()}, ${context.localName});", unionSerializer) + } + + private fun RustWriter.handleOptional(context: MemberContext, inner: RustWriter.(MemberContext) -> Unit) { + if (symbolProvider.toSymbol(context.shape).isOptional()) { + safeName().also { localDecl -> + rustBlock("if let Some($localDecl) = ${context.valueExpression.borrow()}") { + inner(context.copy(valueExpression = localDecl)) + } + } + } else { + inner(context) + } + } +} + +private fun String.borrow(): String = "&$this" +private fun String.borrowMut(): String = "&mut $this" diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/SyntheticInputTrait.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/SyntheticInputTrait.kt index d98daebc4b..7493a32899 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/SyntheticInputTrait.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/SyntheticInputTrait.kt @@ -12,7 +12,12 @@ import software.amazon.smithy.model.traits.AnnotationTrait /** * Indicates that a shape is a synthetic input (see `OperationNormalizer.kt`) */ -class SyntheticInputTrait constructor(val operation: ShapeId, val originalId: ShapeId?, val body: ShapeId?) : +class SyntheticInputTrait( + val operation: ShapeId, + val originalId: ShapeId?, + // TODO: Remove synthetic body when cleaning up serde json generators + val body: ShapeId? +) : AnnotationTrait(ID, ObjectNode.fromStringMap(mapOf("body" to body.toString()))) { companion object { val ID = ShapeId.from("smithy.api.internal#syntheticInput") @@ -22,6 +27,7 @@ class SyntheticInputTrait constructor(val operation: ShapeId, val originalId: Sh /** * Indicates that a shape is a synthetic input body */ +// TODO: Remove synthetic body when cleaning up serde json generators class InputBodyTrait(objectNode: ObjectNode = ObjectNode.objectNode()) : AnnotationTrait(ID, objectNode) { companion object { val ID = ShapeId.from("smithy.api.internal#syntheticInputBody") diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt index 277eb76c51..4a87d83033 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/transformers/OperationNormalizer.kt @@ -114,6 +114,7 @@ class OperationNormalizer(private val model: Model) { // Rename safety: Operations cannot be renamed private fun OperationShape.syntheticInputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Input") private fun OperationShape.syntheticOutputId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}Output") + // TODO: Remove synthetic body when cleaning up serde json generators private fun OperationShape.syntheticInputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}InputBody") private fun OperationShape.syntheticOutputBodyId() = ShapeId.fromParts(this.id.namespace, "${this.id.name}OutputBody") diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt new file mode 100644 index 0000000000..2903583e71 --- /dev/null +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt @@ -0,0 +1,128 @@ +package software.amazon.smithy.rust.codegen.smithy.protocols.parsers + +import org.junit.jupiter.api.Test +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.StringShape +import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.rust.codegen.rustlang.RustModule +import software.amazon.smithy.rust.codegen.smithy.generators.EnumGenerator +import software.amazon.smithy.rust.codegen.smithy.generators.UnionGenerator +import software.amazon.smithy.rust.codegen.smithy.transformers.OperationNormalizer +import software.amazon.smithy.rust.codegen.smithy.transformers.RecursiveShapeBoxer +import software.amazon.smithy.rust.codegen.testutil.TestWorkspace +import software.amazon.smithy.rust.codegen.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.testutil.compileAndTest +import software.amazon.smithy.rust.codegen.testutil.renderWithModelBuilder +import software.amazon.smithy.rust.codegen.testutil.testProtocolConfig +import software.amazon.smithy.rust.codegen.testutil.testSymbolProvider +import software.amazon.smithy.rust.codegen.testutil.unitTest +import software.amazon.smithy.rust.codegen.util.expectTrait +import software.amazon.smithy.rust.codegen.util.inputShape +import software.amazon.smithy.rust.codegen.util.lookup + +class JsonSerializerGeneratorTest { + private val baseModel = """ + namespace test + use aws.protocols#restJson1 + + union Choice { + map: MyMap, + list: SomeList, + s: String, + enum: FooEnum, + date: Timestamp, + number: Double, + top: Top, + blob: Blob, + document: Document, + } + + @enum([{name: "FOO", value: "FOO"}]) + string FooEnum + + map MyMap { + key: String, + value: Choice, + } + + list SomeList { + member: Choice + } + + structure Top { + choice: Choice, + field: String, + extra: Long, + recursive: TopList + } + + list TopList { + member: Top + } + + structure OpInput { + @httpHeader("x-test") + someHeader: String, + @httpPayload + payload: Top + } + + @http(uri: "/top", method: "POST") + operation Op { + input: OpInput, + } + """.asSmithyModel() + + @Test + fun `generates valid serializers`() { + val model = RecursiveShapeBoxer.transform( + OperationNormalizer(baseModel).transformModel( + OperationNormalizer.NoBody, + OperationNormalizer.NoBody + ) + ) + val symbolProvider = testSymbolProvider(model) + val parserGenerator = JsonSerializerGenerator(testProtocolConfig(model)) + val payloadGenerator = parserGenerator.payloadSerializer(model.lookup("test#OpInput\$payload")) + val operationGenerator = parserGenerator.operationSerializer(model.lookup("test#Op")) + val documentGenerator = parserGenerator.documentSerializer() + + val project = TestWorkspace.testProject(testSymbolProvider(model)) + project.lib { writer -> + writer.unitTest( + """ + use model::Top; + + // Generate the operation/document serializers even if they're not directly tested + // ${writer.format(operationGenerator!!)} + // ${writer.format(documentGenerator)} + + let inp = crate::input::OpInput::builder().payload( + Top::builder() + .field("hello!") + .extra(45) + .recursive(Top::builder().extra(55).build()) + .build() + ).build().unwrap(); + let serialized = ${writer.format(payloadGenerator)}(&inp.payload.unwrap()).unwrap(); + let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap(); + assert_eq!(output, r#"{"field":"hello!","extra":45,"recursive":[{"extra":55}]}"#); + """ + ) + } + project.withModule(RustModule.default("model", public = true)) { + model.lookup("test#Top").renderWithModelBuilder(model, symbolProvider, it) + UnionGenerator(model, symbolProvider, it, model.lookup("test#Choice")).render() + val enum = model.lookup("test#FooEnum") + EnumGenerator(model, symbolProvider, it, enum, enum.expectTrait()).render() + } + + project.withModule(RustModule.default("input", public = true)) { + model.lookup("test#Op").inputShape(model).renderWithModelBuilder(model, symbolProvider, it) + } + println("file:///${project.baseDir}/src/operation_ser.rs") + println("file:///${project.baseDir}/src/json_ser.rs") + println("file:///${project.baseDir}/src/lib.rs") + project.compileAndTest() + } +} From f13e198bd7236b45968fadb1c56771e569020b2b Mon Sep 17 00:00:00 2001 From: John DiSanti Date: Tue, 25 May 2021 11:12:52 -0700 Subject: [PATCH 3/3] CR feedback --- .../parsers/JsonSerializerGenerator.kt | 95 ++++++++++--------- .../parsers/JsonSerializerGeneratorTest.kt | 6 +- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt index a0a3aa8bc3..e44fb2f0f7 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGenerator.kt @@ -63,53 +63,67 @@ private data class StructContext( val symbolProvider: RustSymbolProvider, ) { fun member(member: MemberShape): MemberContext = - MemberContext(objectName, "$localName.${symbolProvider.toMemberName(member)}", member, structMember = true) + MemberContext(objectName, MemberDestination.Object(), "$localName.${symbolProvider.toMemberName(member)}", member) +} + +private sealed class MemberDestination { + // Add unused parameter so that Kotlin generates equals/hashCode for us + data class Array(private val unused: Int = 0) : MemberDestination() + data class Object(val keyNameOverride: String? = null) : MemberDestination() } private data class MemberContext( /** Name of the JsonObjectWriter or JsonArrayWriter */ val writerName: String, + val destination: MemberDestination, val valueExpression: String, val shape: MemberShape, - /** Whether we're working with a JsonObjectWriter (true) or JsonArrayWriter (false) */ - val structMember: Boolean, - private val givenKeyExpression: String? = null, ) { - val wireName: String = shape.getTrait()?.value ?: shape.memberName - val keyExpression: String = when (givenKeyExpression) { - null -> wireName.dq() - else -> givenKeyExpression + val keyExpression: String = when (destination) { + is MemberDestination.Object -> + destination.keyNameOverride ?: (shape.getTrait()?.value ?: shape.memberName).dq() + is MemberDestination.Array -> "" } /** Generates an expression that serializes the given [value] expression to the object/array */ - fun writeValue(w: RustWriter, jsonType: String, key: String, value: String) = when (structMember) { - true -> w.rust("$writerName.$jsonType($key, $value);") - else -> w.rust("$writerName.$jsonType($value);") + fun writeValue(w: RustWriter, writerFn: JsonWriterFn, key: String, value: String) = when (destination) { + is MemberDestination.Object -> w.rust("$writerName.$writerFn($key, $value);") + is MemberDestination.Array -> w.rust("$writerName.$writerFn($value);") } /** Generates an expression that serializes the given [inner] expression to the object/array */ - fun writeInner(w: RustWriter, jsonType: String, key: String, inner: RustWriter.() -> Unit) { - w.rustInline("$writerName.$jsonType(") - if (structMember) { - w.writeInline("$key, ") + fun writeInner(w: RustWriter, writerFn: JsonWriterFn, key: String, inner: RustWriter.() -> Unit) { + w.withBlock("$writerName.$writerFn(", ");") { + if (destination is MemberDestination.Object) { + w.writeInline("$key, ") + } + inner(w) } - inner(w) - w.write(");") } /** Generates a mutable declaration for serializing a new object */ - fun writeStartObject(w: RustWriter, decl: String, key: String) = when (structMember) { - true -> w.rust("let mut $decl = $writerName.start_object($key);") - else -> w.rust("let mut $decl = $writerName.start_object();") + fun writeStartObject(w: RustWriter, decl: String, key: String) = when (destination) { + is MemberDestination.Object -> w.rust("let mut $decl = $writerName.start_object($key);") + is MemberDestination.Array -> w.rust("let mut $decl = $writerName.start_object();") } /** Generates a mutable declaration for serializing a new array */ - fun writeStartArray(w: RustWriter, decl: String, key: String) = when (structMember) { - true -> w.rust("let mut $decl = $writerName.start_array($key);") - else -> w.rust("let mut $decl = $writerName.start_array();") + fun writeStartArray(w: RustWriter, decl: String, key: String) = when (destination) { + is MemberDestination.Object -> w.rust("let mut $decl = $writerName.start_array($key);") + is MemberDestination.Array -> w.rust("let mut $decl = $writerName.start_array();") } } +private enum class JsonWriterFn { + BOOLEAN, + INSTANT, + NUMBER, + STRING, + STRING_UNCHECKED; + + override fun toString(): String = name.toLowerCase() +} + class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSerializerGenerator { private val model = protocolConfig.model private val symbolProvider = protocolConfig.symbolProvider @@ -191,38 +205,38 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe } } } - rust("#T(${context.objectName.borrowMut()}, ${context.localName});", structureSerializer) + rust("#T(&mut ${context.objectName}, ${context.localName});", structureSerializer) } private fun RustWriter.serializeMember(context: MemberContext) { val target = model.expectShape(context.shape.target) handleOptional(context) { inner -> val key = inner.keyExpression - val value = inner.valueExpression.borrow() + val value = "&${inner.valueExpression}" when (target) { is StringShape -> when (target.hasTrait()) { - true -> context.writeValue(this, "string", key, "$value.as_str()") - false -> context.writeValue(this, "string", key, value) + true -> context.writeValue(this, JsonWriterFn.STRING, key, "$value.as_str()") + false -> context.writeValue(this, JsonWriterFn.STRING, key, value) } - is BooleanShape -> context.writeValue(this, "boolean", key, value) + is BooleanShape -> context.writeValue(this, JsonWriterFn.BOOLEAN, key, value) is NumberShape -> { val numberType = when (symbolProvider.toSymbol(target).rustType()) { is RustType.Float -> "Float" is RustType.Integer -> "NegInt" else -> throw IllegalStateException("unreachable") } - context.writeInner(this, "number", key) { + context.writeInner(this, JsonWriterFn.NUMBER, key) { rustInline("#T::$numberType(*${inner.valueExpression})", smithyTypes.member("Number")) } } - is BlobShape -> context.writeInner(this, "string_unchecked", key) { + is BlobShape -> context.writeInner(this, JsonWriterFn.STRING_UNCHECKED, key) { rustInline("&#T($value)", RuntimeType.Base64Encode(runtimeConfig)) } is TimestampShape -> { val timestampFormat = httpIndex.determineTimestampFormat(context.shape, HttpBinding.Location.DOCUMENT, EPOCH_SECONDS) val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) - context.writeInner(this, "instant", key) { + context.writeInner(this, JsonWriterFn.INSTANT, key) { rustInline("$value, #T", timestampFormatType) } } @@ -265,7 +279,7 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe private fun RustWriter.serializeCollection(context: SimpleContext) { val itemName = safeName("item") rustBlock("for $itemName in ${context.localName}") { - serializeMember(MemberContext(context.writerName, itemName, context.shape.member, structMember = false)) + serializeMember(MemberContext(context.writerName, MemberDestination.Array(), itemName, context.shape.member)) } } @@ -275,13 +289,7 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe val valueShape = context.shape.value rustBlock("for ($keyName, $valueName) in ${context.localName}") { serializeMember( - MemberContext( - context.writerName, - valueName, - valueShape, - structMember = true, - givenKeyExpression = keyName - ) + MemberContext(context.writerName, MemberDestination.Object(keyNameOverride = keyName), valueName, valueShape) ) } } @@ -299,19 +307,19 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe for (member in context.shape.members()) { val variantName = member.memberName.toPascalCase() withBlock("#T::$variantName(inner) => {", "},", unionSymbol) { - serializeMember(MemberContext(context.writerName, "inner", member, true)) + serializeMember(MemberContext(context.writerName, MemberDestination.Object(), "inner", member)) } } } } } - rust("#T(${context.writerName.borrowMut()}, ${context.localName});", unionSerializer) + rust("#T(&mut ${context.writerName}, ${context.localName});", unionSerializer) } private fun RustWriter.handleOptional(context: MemberContext, inner: RustWriter.(MemberContext) -> Unit) { if (symbolProvider.toSymbol(context.shape).isOptional()) { safeName().also { localDecl -> - rustBlock("if let Some($localDecl) = ${context.valueExpression.borrow()}") { + rustBlock("if let Some($localDecl) = &${context.valueExpression}") { inner(context.copy(valueExpression = localDecl)) } } @@ -320,6 +328,3 @@ class JsonSerializerGenerator(protocolConfig: ProtocolConfig) : StructuredDataSe } } } - -private fun String.borrow(): String = "&$this" -private fun String.borrowMut(): String = "&mut $this" diff --git a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt index 2903583e71..55a2bfa2e8 100644 --- a/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt +++ b/codegen/src/test/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parsers/JsonSerializerGeneratorTest.kt @@ -53,6 +53,7 @@ class JsonSerializerGeneratorTest { choice: Choice, field: String, extra: Long, + @jsonName("rec") recursive: TopList } @@ -106,7 +107,7 @@ class JsonSerializerGeneratorTest { ).build().unwrap(); let serialized = ${writer.format(payloadGenerator)}(&inp.payload.unwrap()).unwrap(); let output = std::str::from_utf8(serialized.bytes().unwrap()).unwrap(); - assert_eq!(output, r#"{"field":"hello!","extra":45,"recursive":[{"extra":55}]}"#); + assert_eq!(output, r#"{"field":"hello!","extra":45,"rec":[{"extra":55}]}"#); """ ) } @@ -120,9 +121,10 @@ class JsonSerializerGeneratorTest { project.withModule(RustModule.default("input", public = true)) { model.lookup("test#Op").inputShape(model).renderWithModelBuilder(model, symbolProvider, it) } - println("file:///${project.baseDir}/src/operation_ser.rs") println("file:///${project.baseDir}/src/json_ser.rs") println("file:///${project.baseDir}/src/lib.rs") + println("file:///${project.baseDir}/src/model.rs") + println("file:///${project.baseDir}/src/operation_ser.rs") project.compileAndTest() } }