Skip to content

Commit

Permalink
Avoid explicitly emitting Unit type within Union
Browse files Browse the repository at this point in the history
This commit addresses #1546. Previously, the Unit type in a Union was
rendered as an enum variant whose inner data was crate::model::Unit.
The way such a variant appears in Rust code feels a bit odd as it
usually does not carry inner data for `()`. We now render a Union
member of type Unit to an enum variant without inner data.
  • Loading branch information
ysaito1001 committed Nov 15, 2022
1 parent 15f2fbb commit bdf7858
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.MemberShape
Expand Down Expand Up @@ -49,16 +50,18 @@ class UnionGenerator(
private val sortedMembers: List<MemberShape> = shape.allMembers.values.sortedBy { symbolProvider.toMemberName(it) }

fun render() {
renderUnion()
}

private fun renderUnion() {
writer.documentShape(shape, model)
writer.deprecatedShape(shape)

val unionSymbol = symbolProvider.toSymbol(shape)
val containerMeta = unionSymbol.expectRustMetadata()
containerMeta.render(writer)

renderUnion(unionSymbol)
renderImplBlock(unionSymbol)
}

private fun renderUnion(unionSymbol: Symbol) {
writer.rustBlock("enum ${unionSymbol.name}") {
sortedMembers.forEach { member ->
val memberSymbol = symbolProvider.toSymbol(member)
Expand All @@ -67,7 +70,7 @@ class UnionGenerator(
documentShape(member, model, note = note)
deprecatedShape(member)
memberSymbol.expectRustMetadata().renderAttributes(this)
write("${symbolProvider.toMemberName(member)}(#T),", symbolProvider.toSymbol(member))
writer.renderVariant(symbolProvider, member, memberSymbol)
}
if (renderUnknownVariant) {
docs("""The `Unknown` variant represents cases where new union variant was received. Consider upgrading the SDK to the latest available version.""")
Expand All @@ -82,6 +85,9 @@ class UnionGenerator(
rust("Unknown,")
}
}
}

private fun renderImplBlock(unionSymbol: Symbol) {
writer.rustBlock("impl ${unionSymbol.name}") {
sortedMembers.forEach { member ->
val memberSymbol = symbolProvider.toSymbol(member)
Expand All @@ -91,11 +97,7 @@ class UnionGenerator(
if (sortedMembers.size == 1) {
Attribute.Custom("allow(irrefutable_let_patterns)").render(this)
}
rust("/// Tries to convert the enum instance into [`$variantName`](#T::$variantName), extracting the inner #D.", unionSymbol, memberSymbol)
rust("/// Returns `Err(&Self)` if it can't be converted.")
rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<&#T, &Self>", memberSymbol) {
rust("if let ${unionSymbol.name}::$variantName(val) = &self { Ok(val) } else { Err(self) }")
}
writer.renderAsVariant(member, variantName, funcNamePart, unionSymbol, memberSymbol)
rust("/// Returns true if this is a [`$variantName`](#T::$variantName).", unionSymbol)
rustBlock("pub fn is_$funcNamePart(&self) -> bool") {
rust("self.as_$funcNamePart().is_ok()")
Expand All @@ -114,7 +116,44 @@ class UnionGenerator(
const val UnknownVariantName = "Unknown"
}
}

fun unknownVariantError(union: String) =
"Cannot serialize `$union::${UnionGenerator.UnknownVariantName}` for the request. " +
"The `Unknown` variant is intended for responses only. " +
"It occurs when an outdated client is used after a new enum variant was added on the server side."

private fun RustWriter.renderVariant(symbolProvider: SymbolProvider, member: MemberShape, memberSymbol: Symbol) {
if (member.target.name == "Unit") {
write("${symbolProvider.toMemberName(member)},")
} else {
write("${symbolProvider.toMemberName(member)}(#T),", memberSymbol)
}
}

private fun RustWriter.renderAsVariant(
member: MemberShape,
variantName: String,
funcNamePart: String,
unionSymbol: Symbol,
memberSymbol: Symbol,
) {
if (member.target.name == "Unit") {
rust(
"/// Tries to convert the enum instance into [`$variantName`], extracting the inner `()`.",
)
rust("/// Returns `Err(&Self)` if it can't be converted.")
rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<(), &Self>") {
rust("if let ${unionSymbol.name}::$variantName = &self { Ok(()) } else { Err(self) }")
}
} else {
rust(
"/// Tries to convert the enum instance into [`$variantName`](#T::$variantName), extracting the inner #D.",
unionSymbol,
memberSymbol,
)
rust("/// Returns `Err(&Self)` if it can't be converted.")
rustBlock("pub fn as_$funcNamePart(&self) -> std::result::Result<&#T, &Self>", memberSymbol) {
rust("if let ${unionSymbol.name}::$variantName(val) = &self { Ok(val) } else { Err(self) }")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ class UnionGeneratorTest {
writer.compileAndTest()
}

@Test
fun `unit types should not appear in generated enum`() {
val writer = generateUnion("union MyUnion { a: Unit, b: String }", unknownVariant = true)
writer.compileAndTest(
"""
let a = MyUnion::A;
assert!(a.as_a() == Ok(()));
""",
)
}

private fun generateUnion(modelSmithy: String, unionName: String = "MyUnion", unknownVariant: Boolean = true): RustWriter {
val model = "namespace test\n$modelSmithy".asSmithyModel()
val provider: SymbolProvider = testSymbolProvider(model)
Expand Down

0 comments on commit bdf7858

Please sign in to comment.