From f1a726c1d7b01b27ddceb33687970f4f514f9d2c Mon Sep 17 00:00:00 2001 From: Russell Cohen Date: Fri, 30 Jul 2021 11:25:10 -0400 Subject: [PATCH] Smithy 1.9/1.10 Upgrade (#618) * smithy 1.9.1 upgrade & primitive encode/decode This upgrades to Smithy 1.10, but the major change is a complete overhaul of how primitives are formatted and parsed. Primitive serialization was migrated and unified into Smithy Types with the end requirement of dealing with special float serialization semantics. * Switch to Smithy Core S3 Customization Trait Smithy 1.9.1 brings S3UnwrappedXmlOutput as a vended trait. This commit pulls in the new model & uses that trait. * Fix clippy warnings * Fix doc links * fix kotlin formatting * Fix s3 customization to use the operation shape * Ensure that numbers in string don't parse as numbers * remove unused itoa * Apply suggestions from code review Co-authored-by: John DiSanti * Fix tests, CR feedback * rename parse to parse_smithy_primitive * Fix some more clippy errors * Update changelog Co-authored-by: John DiSanti --- CHANGELOG.md | 1 + .../rustsdk/customize/s3/S3Decorator.kt | 19 -- aws/sdk/aws-models/accessanalyzer.json | 2 +- aws/sdk/aws-models/amp.json | 2 +- aws/sdk/aws-models/appmesh.json | 2 +- aws/sdk/aws-models/braket.json | 2 +- aws/sdk/aws-models/codeguruprofiler.json | 2 +- aws/sdk/aws-models/groundstation.json | 2 +- aws/sdk/aws-models/location.json | 2 +- aws/sdk/aws-models/mgn.json | 2 +- aws/sdk/aws-models/mwaa.json | 2 +- aws/sdk/aws-models/proton.json | 2 +- aws/sdk/aws-models/rds-data.json | 2 +- aws/sdk/aws-models/redshift-data.json | 2 +- aws/sdk/aws-models/s3.json | 1 + aws/sdk/aws-models/ssm-incidents.json | 2 +- aws/sdk/examples/batch/Cargo.toml | 2 +- aws/sdk/examples/ebs/Cargo.toml | 2 +- aws/sdk/examples/kinesis/Cargo.toml | 2 +- aws/sdk/examples/medialive/Cargo.toml | 2 +- aws/sdk/examples/mediapackage/Cargo.toml | 2 +- aws/sdk/examples/polly/Cargo.toml | 1 - aws/sdk/examples/qldb/Cargo.toml | 2 - aws/sdk/examples/qldb/README.md | 2 +- aws/sdk/examples/rds/Cargo.toml | 2 +- aws/sdk/examples/rdsdata/Cargo.toml | 2 +- aws/sdk/examples/sagemaker/Cargo.toml | 2 +- aws/sdk/examples/secretsmanager/Cargo.toml | 1 - aws/sdk/examples/snowball/Cargo.toml | 2 +- aws/sdk/examples/sqs/Cargo.toml | 2 +- .../smithy/rust/codegen/rustlang/RustTypes.kt | 11 + .../generators/HttpProtocolTestGenerator.kt | 24 +- .../codegen/smithy/generators/Instantiator.kt | 13 +- .../http/RequestBindingGenerator.kt | 82 +++--- .../http/ResponseBindingGenerator.kt | 14 +- .../parse/XmlBindingTraitParserGenerator.kt | 16 +- .../XmlBindingTraitSerializerGenerator.kt | 17 +- .../traits/S3UnwrappedXmlOutputTrait.kt | 34 --- .../amazon/smithy/rust/codegen/util/Smithy.kt | 9 + gradle.properties | 2 +- rust-runtime/inlineable/Cargo.toml | 2 +- rust-runtime/protocol-test-helpers/src/lib.rs | 51 +++- rust-runtime/smithy-http/src/header.rs | 149 +++++++--- rust-runtime/smithy-http/src/label.rs | 5 - rust-runtime/smithy-http/src/query.rs | 5 - rust-runtime/smithy-json/Cargo.toml | 2 - rust-runtime/smithy-json/src/deserialize.rs | 22 +- .../smithy-json/src/deserialize/token.rs | 77 ++++- rust-runtime/smithy-json/src/serialize.rs | 27 +- rust-runtime/smithy-query/Cargo.toml | 2 - rust-runtime/smithy-query/src/lib.rs | 21 +- rust-runtime/smithy-types/Cargo.toml | 2 + rust-runtime/smithy-types/src/lib.rs | 1 + rust-runtime/smithy-types/src/primitive.rs | 276 ++++++++++++++++++ 54 files changed, 708 insertions(+), 229 deletions(-) delete mode 100644 codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt create mode 100644 rust-runtime/smithy-types/src/primitive.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 2113b6e79f..d19d844d70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ **New This Week** - :bug: Correctly encode HTTP Checksums using base64 instead of hex. Fixes aws-sdk-rust#164. (#615) - (When complete) Add profile file provider for region (#594, #xyz) +- Overhaul serialization/deserialization of numeric/boolean types. This resolves issues around serialization of NaN/Infinity and should also reduce the number of allocations required during serialization. (#618) ## v0.18.1 (July 27th 2021) * Remove timestreamwrite and timestreamquery from the generated services (#613) diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt index 6db5a99be3..a3cb79dd0c 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt @@ -6,12 +6,8 @@ package software.amazon.smithy.rustsdk.customize.s3 import software.amazon.smithy.aws.traits.protocols.RestXmlTrait -import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape -import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.transform.ModelTransformer import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.Writable import software.amazon.smithy.rust.codegen.rustlang.asType @@ -28,7 +24,6 @@ import software.amazon.smithy.rust.codegen.smithy.letIf import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml import software.amazon.smithy.rust.codegen.smithy.protocols.RestXmlFactory -import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait import software.amazon.smithy.rustsdk.AwsRuntimeType /** @@ -59,20 +54,6 @@ class S3Decorator : RustCodegenDecorator { it + S3PubUse() } } - - override fun transformModel(service: ServiceShape, model: Model): Model { - return model.letIf(applies(service.id)) { - ModelTransformer.create().mapShapes(model) { shape -> - // Apply the S3UnwrappedXmlOutput customization to GetBucketLocation (more - // details on the S3UnwrappedXmlOutputTrait) - if (shape is StructureShape && shape.id == ShapeId.from("com.amazonaws.s3#GetBucketLocationOutput")) { - shape.toBuilder().addTrait(S3UnwrappedXmlOutputTrait()).build() - } else { - shape - } - } - } - } } class S3(protocolConfig: ProtocolConfig) : RestXml(protocolConfig) { diff --git a/aws/sdk/aws-models/accessanalyzer.json b/aws/sdk/aws-models/accessanalyzer.json index f2a87af9d3..b261071289 100644 --- a/aws/sdk/aws-models/accessanalyzer.json +++ b/aws/sdk/aws-models/accessanalyzer.json @@ -4823,4 +4823,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/amp.json b/aws/sdk/aws-models/amp.json index cd66167b68..84043fdc5d 100644 --- a/aws/sdk/aws-models/amp.json +++ b/aws/sdk/aws-models/amp.json @@ -935,4 +935,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/appmesh.json b/aws/sdk/aws-models/appmesh.json index 9b2500c5a0..9273d629e0 100644 --- a/aws/sdk/aws-models/appmesh.json +++ b/aws/sdk/aws-models/appmesh.json @@ -8671,4 +8671,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/braket.json b/aws/sdk/aws-models/braket.json index 209a8dadc6..c89e12506b 100644 --- a/aws/sdk/aws-models/braket.json +++ b/aws/sdk/aws-models/braket.json @@ -1404,4 +1404,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/codeguruprofiler.json b/aws/sdk/aws-models/codeguruprofiler.json index 68b2066e3c..e54bb4e6d1 100644 --- a/aws/sdk/aws-models/codeguruprofiler.json +++ b/aws/sdk/aws-models/codeguruprofiler.json @@ -3174,4 +3174,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/groundstation.json b/aws/sdk/aws-models/groundstation.json index 9fe4f04b6e..9af96283a8 100644 --- a/aws/sdk/aws-models/groundstation.json +++ b/aws/sdk/aws-models/groundstation.json @@ -3544,4 +3544,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/location.json b/aws/sdk/aws-models/location.json index 3dd786826e..1d7718404d 100644 --- a/aws/sdk/aws-models/location.json +++ b/aws/sdk/aws-models/location.json @@ -5937,4 +5937,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/mgn.json b/aws/sdk/aws-models/mgn.json index b0512d4832..10eb1f4f41 100644 --- a/aws/sdk/aws-models/mgn.json +++ b/aws/sdk/aws-models/mgn.json @@ -3944,4 +3944,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/mwaa.json b/aws/sdk/aws-models/mwaa.json index f3e312eddc..ad18f1105b 100644 --- a/aws/sdk/aws-models/mwaa.json +++ b/aws/sdk/aws-models/mwaa.json @@ -1938,4 +1938,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/proton.json b/aws/sdk/aws-models/proton.json index 3a569269dc..03616adff2 100644 --- a/aws/sdk/aws-models/proton.json +++ b/aws/sdk/aws-models/proton.json @@ -6482,4 +6482,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/rds-data.json b/aws/sdk/aws-models/rds-data.json index 488e7ed124..2e3e75b220 100644 --- a/aws/sdk/aws-models/rds-data.json +++ b/aws/sdk/aws-models/rds-data.json @@ -1336,4 +1336,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/redshift-data.json b/aws/sdk/aws-models/redshift-data.json index 9b4600e23b..6e009f4c6c 100644 --- a/aws/sdk/aws-models/redshift-data.json +++ b/aws/sdk/aws-models/redshift-data.json @@ -1423,4 +1423,4 @@ "type": "boolean" } } -} \ No newline at end of file +} diff --git a/aws/sdk/aws-models/s3.json b/aws/sdk/aws-models/s3.json index 15dc938600..b5293082c6 100644 --- a/aws/sdk/aws-models/s3.json +++ b/aws/sdk/aws-models/s3.json @@ -4073,6 +4073,7 @@ "target": "com.amazonaws.s3#GetBucketLocationOutput" }, "traits": { + "aws.customizations#s3UnwrappedXmlOutput": {}, "smithy.api#documentation": "

Returns the Region the bucket resides in. You set the bucket's Region using the\n LocationConstraint request parameter in a CreateBucket\n request. For more information, see CreateBucket.

\n\n

To use this implementation of the operation, you must be the bucket owner.

\n\n

The following operations are related to GetBucketLocation:

\n ", "smithy.api#http": { "method": "GET", diff --git a/aws/sdk/aws-models/ssm-incidents.json b/aws/sdk/aws-models/ssm-incidents.json index e2ab5381f2..20f4f25d93 100644 --- a/aws/sdk/aws-models/ssm-incidents.json +++ b/aws/sdk/aws-models/ssm-incidents.json @@ -4041,4 +4041,4 @@ } } } -} \ No newline at end of file +} diff --git a/aws/sdk/examples/batch/Cargo.toml b/aws/sdk/examples/batch/Cargo.toml index 4be66d845d..7998b9d8d7 100644 --- a/aws/sdk/examples/batch/Cargo.toml +++ b/aws/sdk/examples/batch/Cargo.toml @@ -11,4 +11,4 @@ batch = { package = "aws-sdk-batch", path = "../../build/aws-sdk/batch" } aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } -tracing-subscriber = "0.2.18" \ No newline at end of file +tracing-subscriber = "0.2.18" diff --git a/aws/sdk/examples/ebs/Cargo.toml b/aws/sdk/examples/ebs/Cargo.toml index f33fe0c87a..75e5c29d0f 100644 --- a/aws/sdk/examples/ebs/Cargo.toml +++ b/aws/sdk/examples/ebs/Cargo.toml @@ -14,4 +14,4 @@ tokio = { version = "1", features = ["full"]} base64 = "0.13.0" sha2 = "0.9.5" structopt = { version = "0.3", default-features = false } -tracing-subscriber = "0.2.19" \ No newline at end of file +tracing-subscriber = "0.2.19" diff --git a/aws/sdk/examples/kinesis/Cargo.toml b/aws/sdk/examples/kinesis/Cargo.toml index 35c5884944..57ca5fa25c 100644 --- a/aws/sdk/examples/kinesis/Cargo.toml +++ b/aws/sdk/examples/kinesis/Cargo.toml @@ -11,4 +11,4 @@ kinesis = { package = "aws-sdk-kinesis", path = "../../build/aws-sdk/kinesis" } aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } -tracing-subscriber = { version = "0.2.16", features = ["fmt"] } \ No newline at end of file +tracing-subscriber = { version = "0.2.16", features = ["fmt"] } diff --git a/aws/sdk/examples/medialive/Cargo.toml b/aws/sdk/examples/medialive/Cargo.toml index 22c535b6ad..552cd313d0 100644 --- a/aws/sdk/examples/medialive/Cargo.toml +++ b/aws/sdk/examples/medialive/Cargo.toml @@ -11,4 +11,4 @@ medialive = { package = "aws-sdk-medialive", path = "../../build/aws-sdk/mediali aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } -tracing-subscriber = { version = "0.2.16", features = ["fmt"] } \ No newline at end of file +tracing-subscriber = { version = "0.2.16", features = ["fmt"] } diff --git a/aws/sdk/examples/mediapackage/Cargo.toml b/aws/sdk/examples/mediapackage/Cargo.toml index dc486ffe7d..fca8c6e851 100644 --- a/aws/sdk/examples/mediapackage/Cargo.toml +++ b/aws/sdk/examples/mediapackage/Cargo.toml @@ -11,4 +11,4 @@ mediapackage = { package = "aws-sdk-mediapackage", path = "../../build/aws-sdk/m aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } -tracing-subscriber = { version = "0.2.16", features = ["fmt"] } \ No newline at end of file +tracing-subscriber = { version = "0.2.16", features = ["fmt"] } diff --git a/aws/sdk/examples/polly/Cargo.toml b/aws/sdk/examples/polly/Cargo.toml index d4991e4537..ffdc05b492 100644 --- a/aws/sdk/examples/polly/Cargo.toml +++ b/aws/sdk/examples/polly/Cargo.toml @@ -12,4 +12,3 @@ aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } tracing-subscriber = { version = "0.2.16", features = ["fmt"] } - diff --git a/aws/sdk/examples/qldb/Cargo.toml b/aws/sdk/examples/qldb/Cargo.toml index 8fd299aa61..e1f4e18ac6 100644 --- a/aws/sdk/examples/qldb/Cargo.toml +++ b/aws/sdk/examples/qldb/Cargo.toml @@ -13,5 +13,3 @@ aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } tracing-subscriber = { version = "0.2.16", features = ["fmt"] } - - diff --git a/aws/sdk/examples/qldb/README.md b/aws/sdk/examples/qldb/README.md index ffafc8a37a..49b61fb6a4 100644 --- a/aws/sdk/examples/qldb/README.md +++ b/aws/sdk/examples/qldb/README.md @@ -51,4 +51,4 @@ where: If the environment variable is not set, defaults to **us-west-2**. - __-v__ enables displaying additional information. -## +## diff --git a/aws/sdk/examples/rds/Cargo.toml b/aws/sdk/examples/rds/Cargo.toml index 0b42125c94..88fa8475dd 100644 --- a/aws/sdk/examples/rds/Cargo.toml +++ b/aws/sdk/examples/rds/Cargo.toml @@ -11,4 +11,4 @@ rds = {package = "aws-sdk-rds", path = "../../build/aws-sdk/rds"} aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = {version = "1", features = ["full"]} structopt = { version = "0.3", default-features = false } -tracing-subscriber = { version = "0.2.16", features = ["fmt"] } \ No newline at end of file +tracing-subscriber = { version = "0.2.16", features = ["fmt"] } diff --git a/aws/sdk/examples/rdsdata/Cargo.toml b/aws/sdk/examples/rdsdata/Cargo.toml index c6ff30216d..d9cf9d33f5 100644 --- a/aws/sdk/examples/rdsdata/Cargo.toml +++ b/aws/sdk/examples/rdsdata/Cargo.toml @@ -11,4 +11,4 @@ rdsdata = {package = "aws-sdk-rdsdata", path = "../../build/aws-sdk/rdsdata"} aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = {version = "1", features = ["full"]} structopt = { version = "0.3", default-features = false } -tracing-subscriber = { version = "0.2.16", features = ["fmt"] } \ No newline at end of file +tracing-subscriber = { version = "0.2.16", features = ["fmt"] } diff --git a/aws/sdk/examples/sagemaker/Cargo.toml b/aws/sdk/examples/sagemaker/Cargo.toml index c4090b603e..ef74a222d3 100644 --- a/aws/sdk/examples/sagemaker/Cargo.toml +++ b/aws/sdk/examples/sagemaker/Cargo.toml @@ -15,4 +15,4 @@ tokio = { version = "1", features = ["full"] } env_logger = "0.8.2" chrono = "0.4.19" structopt = { version = "0.3", default-features = false } -tracing-subscriber = "0.2.18" \ No newline at end of file +tracing-subscriber = "0.2.18" diff --git a/aws/sdk/examples/secretsmanager/Cargo.toml b/aws/sdk/examples/secretsmanager/Cargo.toml index cb1b2b715f..9aac062109 100644 --- a/aws/sdk/examples/secretsmanager/Cargo.toml +++ b/aws/sdk/examples/secretsmanager/Cargo.toml @@ -14,4 +14,3 @@ tokio = { version = "1", features = ["full"]} structopt = { version = "0.3", default-features = false } tracing-subscriber = { version = "0.2.16", features = ["fmt"] } - diff --git a/aws/sdk/examples/snowball/Cargo.toml b/aws/sdk/examples/snowball/Cargo.toml index a0044c1e1c..3c862a76a9 100644 --- a/aws/sdk/examples/snowball/Cargo.toml +++ b/aws/sdk/examples/snowball/Cargo.toml @@ -11,4 +11,4 @@ aws-sdk-snowball = { path = "../../build/aws-sdk/snowball" } aws-types = { path = "../../build/aws-sdk/aws-types" } tokio = { version = "1", features = ["full"] } structopt = { version = "0.3", default-features = false } -tracing-subscriber = "0.2.18" \ No newline at end of file +tracing-subscriber = "0.2.18" diff --git a/aws/sdk/examples/sqs/Cargo.toml b/aws/sdk/examples/sqs/Cargo.toml index 9fadf19c30..5196651b5b 100644 --- a/aws/sdk/examples/sqs/Cargo.toml +++ b/aws/sdk/examples/sqs/Cargo.toml @@ -9,4 +9,4 @@ edition = "2018" [dependencies] sqs = { package = "aws-sdk-sqs", path = "../../build/aws-sdk/sqs" } tokio = { version = "1", features = ["full"] } -tracing-subscriber = "0.2.18" \ No newline at end of file +tracing-subscriber = "0.2.18" diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt index fafd16332d..27f2e73516 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt @@ -8,6 +8,17 @@ package software.amazon.smithy.rust.codegen.rustlang import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.util.dq +/** + * Dereference [input] + * + * Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&' + */ +fun autoDeref(input: String) = if (input.startsWith("&")) { + input.removePrefix("&") +} else { + "*$input" +} + /** * A hierarchy of types handled by Smithy codegen */ diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt index 22c329a535..109feb8b82 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt @@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.smithy.generators import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.knowledge.OperationIndex +import software.amazon.smithy.model.shapes.DoubleShape +import software.amazon.smithy.model.shapes.FloatShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait @@ -306,7 +308,21 @@ class HttpProtocolTestGenerator( );""" ) } else { - rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""") + when (protocolConfig.model.expectShape(member.target)) { + is DoubleShape, is FloatShape -> { + addUseImports( + RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "FloatEquals").toSymbol() + ) + rust( + """ + assert!(parsed.$memberName.float_equals(&expected_output.$memberName), + "Unexpected value for `$memberName` {:?} vs. {:?}", expected_output.$memberName, parsed.$memberName); + """ + ) + } + else -> + rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""") + } } } } @@ -428,7 +444,11 @@ class HttpProtocolTestGenerator( private val RestXml = "aws.protocoltests.restxml#RestXml" private val AwsQuery = "aws.protocoltests.query#AwsQuery" private val Ec2Query = "aws.protocoltests.ec2#AwsEc2" - private val ExpectFail = setOf() + private val ExpectFail = setOf( + FailingTest( + service = RestJson, id = "RestJsonHostWithPath", action = Action.Request + ) + ) private val RunOnly: Set? = null // These tests are not even attempted to be generated, either because they will not compile diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt index 72fade3271..541732e8f0 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt @@ -114,7 +114,18 @@ class Instantiator( // Simple Shapes is StringShape -> renderString(writer, shape, arg as StringNode) - is NumberShape -> writer.write(arg.asNumberNode().get()) + is NumberShape -> when (arg) { + is StringNode -> { + val numberSymbol = symbolProvider.toSymbol(shape) + // support Smithy custom values, such as Infinity + writer.rust( + """<#T as #T>::parse_smithy_primitive(${arg.value.dq()}).expect("invalid string for number")""", + numberSymbol, + CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse") + ) + } + is NumberNode -> writer.write(arg.value) + } is BooleanShape -> writer.write(arg.asBooleanNode().get().toString()) is DocumentShape -> writer.rustBlock("") { val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType() diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt index fbc33f1ae0..da9d690bae 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.smithy.generators.http +import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -19,8 +20,10 @@ import software.amazon.smithy.model.traits.HttpTrait import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.rust.codegen.rustlang.Attribute +import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustWriter -import software.amazon.smithy.rust.codegen.rustlang.assignment +import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.autoDeref import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustTemplate @@ -36,6 +39,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.isPrimitive fun HttpTrait.uriFormatString(): String { return uri.rustFormatString("/", "/") @@ -71,6 +75,7 @@ class RequestBindingGenerator( ) { private val index = HttpBindingIndex.of(model) private val buildError = runtimeConfig.operationBuildError() + private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder") constructor( protocolConfig: ProtocolConfig, @@ -193,6 +198,9 @@ class RequestBindingGenerator( ifSet(memberType, memberSymbol, "&self.$memberName") { field -> listForEach(memberType, field) { innerField, targetId -> val innerMemberType = model.expectShape(targetId) + if (innerMemberType.isPrimitive()) { + rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder) + } val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField) val safeName = safeName("formatted") write("let $safeName = $formatted;") @@ -241,10 +249,10 @@ class RequestBindingGenerator( target.isListShape || target.isMemberShape -> { throw IllegalArgumentException("lists should be handled at a higher level") } - else -> { - val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_default")) - "$func(&$targetName)" + target.isPrimitive() -> { + "encoder.encode()" } + else -> throw CodegenException("unexpected shape: $target") } } @@ -263,12 +271,13 @@ class RequestBindingGenerator( } val combinedArgs = listOf(formatString, *args.toTypedArray()) writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null) - writer.rustBlock("fn uri_base(&self, output: &mut String) -> Result<(), #T>", runtimeConfig.operationBuildError()) { + writer.rustBlock( + "fn uri_base(&self, output: &mut String) -> Result<(), #T>", + runtimeConfig.operationBuildError() + ) { httpTrait.uri.labels.map { label -> val member = inputShape.expectMember(label.content) - assignment(local(member)) { - serializeLabel(member, label) - } + serializeLabel(member, label, local(member)) } rust("""write!(output, ${combinedArgs.joinToString(", ")}).expect("formatting should succeed");""") rust("Ok(())") @@ -374,13 +383,12 @@ class RequestBindingGenerator( throw IllegalArgumentException("lists should be handled at a higher level") } else -> { - val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_default")) - "$func(&$targetName)" + "${writer.format(Encoder)}::from(${autoDeref(targetName)}).encode()" } } } - private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment) { + private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment, outputVar: String) { val target = model.expectShape(member.target) val symbol = symbolProvider.toSymbol(member) val buildError = { @@ -390,37 +398,37 @@ class RequestBindingGenerator( "cannot be empty or unset" ) } - rustBlock("") { - rust("let input = &self.${symbolProvider.toMemberName(member)};") - if (symbol.isOptional()) { - rust("let input = input.as_ref().ok_or(${buildError()})?;") + val input = safeName("input") + rust("let $input = &self.${symbolProvider.toMemberName(member)};") + if (symbol.isOptional()) { + rust("let $input = $input.as_ref().ok_or(${buildError()})?;") + } + when { + target.isStringShape -> { + val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string")) + rust("let $outputVar = $func($input, ${label.isGreedyLabel});") } - when { - target.isStringShape -> { - val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string")) - rust("let formatted = $func(input, ${label.isGreedyLabel});") - } - target.isTimestampShape -> { - val timestampFormat = - index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat) - val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) - val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp")) - rust("let formatted = $func(&input, ${format(timestampFormatType)});") - } - else -> { - val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_default")) - rust("let formatted = $func(input);") - } + target.isTimestampShape -> { + val timestampFormat = + index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat) + val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat) + val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp")) + rust("let $outputVar = $func(&$input, ${format(timestampFormatType)});") } - rust( - """ - if formatted.is_empty() { + else -> { + rust( + "let mut ${outputVar}_encoder = #T::from(${autoDeref(input)}); let $outputVar = ${outputVar}_encoder.encode();", + Encoder + ) + } + } + rust( + """ + if $outputVar.is_empty() { return Err(${buildError()}) } - formatted """ - ) - } + ) } /** End URI generation **/ } diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt index 14c0b465b0..2d70ea3264 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.smithy.rustType import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.hasTrait +import software.amazon.smithy.rust.codegen.util.isPrimitive import software.amazon.smithy.rust.codegen.util.isStreaming import software.amazon.smithy.rust.codegen.util.toSnakeCase @@ -238,17 +239,22 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera headerUtil, timestampFormatType ) + } else if (coreShape.isPrimitive()) { + rust( + "let $parsedValue = #T::read_many_primitive::<${coreType.render(fullyQualified = true)}>(headers)?;", + headerUtil + ) } else { rust( - "let $parsedValue: Vec<${coreType.render(true)}> = #T::read_many(headers)?;", + "let $parsedValue: Vec<${coreType.render(fullyQualified = true)}> = #T::read_many_from_str(headers)?;", headerUtil ) if (coreShape.hasTrait()) { rustTemplate( """let $parsedValue: std::result::Result, _> = $parsedValue .iter().map(|s| - #{base_64_decode}(s).map_err(|_|#{header}::ParseError) - .and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError)) + #{base_64_decode}(s).map_err(|_|#{header}::ParseError::new_with_message("failed to decode base64")) + .and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError::new_with_message("base64 encoded data was not valid utf-8"))) ).collect();""", "base_64_decode" to RuntimeType.Base64Decode(runtimeConfig), "header" to headerUtil @@ -281,7 +287,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera else -> rustTemplate( """ if $parsedValue.len() > 1 { - Err(#{header_util}::ParseError) + Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len()))) } else { let mut $parsedValue = $parsedValue; Ok($parsedValue.pop()) diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt index 1a2cb050f0..0ab8d68ca7 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.rust.codegen.smithy.protocols.parse +import software.amazon.smithy.aws.traits.customizations.S3UnwrappedXmlOutputTrait import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.knowledge.HttpBinding import software.amazon.smithy.model.knowledge.HttpBindingIndex @@ -34,6 +35,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.rustlang.rustTemplate import software.amazon.smithy.rust.codegen.rustlang.withBlock +import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate import software.amazon.smithy.rust.codegen.smithy.RuntimeType import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator @@ -44,7 +46,6 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName -import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait import software.amazon.smithy.rust.codegen.util.dq import software.amazon.smithy.rust.codegen.util.expectMember import software.amazon.smithy.rust.codegen.util.hasTrait @@ -101,7 +102,8 @@ class XmlBindingTraitParserGenerator( "XmlError" to xmlError, "next_start_element" to smithyXml.member("decode::next_start_element"), "try_data" to smithyXml.member("decode::try_data"), - "ScopedDecoder" to scopedDecoder + "ScopedDecoder" to scopedDecoder, + "smithy_types" to CargoDependency.SmithyTypes(runtimeConfig).asType() ) private val model = protocolConfig.model private val index = HttpBindingIndex.of(model) @@ -192,7 +194,7 @@ class XmlBindingTraitParserGenerator( *codegenScope ) val context = OperationWrapperContext(operationShape, shapeName, xmlError) - if (outputShape.hasTrait()) { + if (operationShape.hasTrait()) { unwrappedResponseParser("builder", "decoder", "start_el", outputShape.members()) } else { writeOperationWrapper(context) { tagName -> @@ -561,8 +563,12 @@ class XmlBindingTraitParserGenerator( is StringShape -> parseStringInner(shape, provider) is NumberShape, is BooleanShape -> { rustBlock("") { - rust("use std::str::FromStr;") - withBlock("#T::from_str(", ")", symbolProvider.toSymbol(shape)) { + withBlockTemplate( + "<#{shape} as #{smithy_types}::primitive::Parse>::parse_smithy_primitive(", + ")", + *codegenScope, + "shape" to symbolProvider.toSymbol(shape) + ) { provider() } rustTemplate( diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt index f2b4ce53ee..469c8e824f 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt @@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.rustlang.RustType import software.amazon.smithy.rust.codegen.rustlang.RustWriter import software.amazon.smithy.rust.codegen.rustlang.asType +import software.amazon.smithy.rust.codegen.rustlang.autoDeref import software.amazon.smithy.rust.codegen.rustlang.render import software.amazon.smithy.rust.codegen.rustlang.rust import software.amazon.smithy.rust.codegen.rustlang.rustBlock @@ -195,17 +196,6 @@ class XmlBindingTraitSerializerGenerator( rust("scope.finish();") } - /** - * Dereference [input] - * - * Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&' - */ - private fun autoDeref(input: String) = if (input.startsWith("&")) { - input.removePrefix("&") - } else { - "*$input" - } - private fun RustWriter.serializeRawMember(member: MemberShape, input: String) { when (val shape = model.expectShape(member.target)) { is StringShape -> if (shape.hasTrait()) { @@ -213,8 +203,9 @@ class XmlBindingTraitSerializerGenerator( } else { rust("$input.as_ref()") } - is NumberShape -> rust("$input.to_string().as_ref()") - is BooleanShape -> rust("""if ${autoDeref(input)} { "true" } else { "false" }""") + is BooleanShape, is NumberShape -> { + rust("#T::from(${autoDeref(input)}).encode()", CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")) + } is BlobShape -> rust("#T($input.as_ref()).as_ref()", RuntimeType.Base64Encode(runtimeConfig)) is TimestampShape -> { val timestampFormat = diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt deleted file mode 100644 index 7d7eb41b25..0000000000 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -package software.amazon.smithy.rust.codegen.smithy.traits - -import software.amazon.smithy.model.node.Node -import software.amazon.smithy.model.shapes.ShapeId -import software.amazon.smithy.model.traits.AnnotationTrait - -/** - * S3's GetBucketLocation response shape can't be represented with Smithy's restXml protocol - * without customization. We add this trait to the S3 model at codegen time so that a different - * code path is taken in the XML deserialization codegen to generate code that parses the S3 - * response shape correctly. - * - * From what the S3 model states, the generated parser would expect: - * ``` - * - * us-west-2 - * - * ``` - * - * But S3 actually responds with: - * ``` - * us-west-2 - * ``` - */ -class S3UnwrappedXmlOutputTrait : AnnotationTrait(ID, Node.objectNode()) { - companion object { - val ID = ShapeId.from("smithy.api.internal#s3UnwrappedXmlOutputTrait") - } -} diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt index ea5c9af3fa..4391e6e7aa 100644 --- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt +++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt @@ -7,7 +7,9 @@ package software.amazon.smithy.rust.codegen.util import software.amazon.smithy.codegen.core.CodegenException import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.NumberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.ShapeId @@ -65,3 +67,10 @@ inline fun Shape.expectTrait(): T = expectTrait(T::class.jav /** Kotlin sugar for getTrait() check. e.g. shape.getTrait() instead of shape.getTrait(EnumTrait::class.java) */ inline fun Shape.getTrait(): T? = getTrait(T::class.java).orNull() + +fun Shape.isPrimitive(): Boolean { + return when (this) { + is NumberShape, is BooleanShape -> true + else -> false + } +} diff --git a/gradle.properties b/gradle.properties index abc15cd091..061d0c3249 100644 --- a/gradle.properties +++ b/gradle.properties @@ -6,7 +6,7 @@ kotlin.code.style=official # codegen -smithyVersion=1.8.0 +smithyVersion=1.10.0 # kotlin kotlinVersion=1.4.21 diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml index 07aac64313..5464f6769f 100644 --- a/rust-runtime/inlineable/Cargo.toml +++ b/rust-runtime/inlineable/Cargo.toml @@ -20,4 +20,4 @@ are to allow this crate to be compilable and testable in isolation, no client co [dev-dependencies] proptest = "1" -regex = "1" \ No newline at end of file +regex = "1" diff --git a/rust-runtime/protocol-test-helpers/src/lib.rs b/rust-runtime/protocol-test-helpers/src/lib.rs index 1852c01f7f..a690ebe6ab 100644 --- a/rust-runtime/protocol-test-helpers/src/lib.rs +++ b/rust-runtime/protocol-test-helpers/src/lib.rs @@ -15,6 +15,39 @@ use std::fmt::{self, Debug}; use thiserror::Error; use urlencoded::try_url_encoded_form_equivalent; +/// Helper trait for tests for float comparisons +/// +/// This trait differs in float's default `PartialEq` implementation by considering all `NaN` values to +/// be equal. +pub trait FloatEquals { + fn float_equals(&self, other: &Self) -> bool; +} + +impl FloatEquals for f64 { + fn float_equals(&self, other: &Self) -> bool { + (self.is_nan() && other.is_nan()) || self.eq(other) + } +} + +impl FloatEquals for f32 { + fn float_equals(&self, other: &Self) -> bool { + (self.is_nan() && other.is_nan()) || self.eq(other) + } +} + +impl FloatEquals for Option +where + T: FloatEquals, +{ + fn float_equals(&self, other: &Self) -> bool { + match (self, other) { + (Some(this), Some(other)) => this.float_equals(other), + (None, None) => true, + _else => false, + } + } +} + #[derive(Debug, PartialEq, Eq, Error)] pub enum ProtocolTestFailure { #[error("missing query param: expected `{expected}`, found {found:?}")] @@ -326,7 +359,7 @@ fn try_json_eq(actual: &str, expected: &str) -> Result<(), ProtocolTestFailure> mod tests { use crate::{ forbid_headers, forbid_query_params, require_headers, require_query_params, validate_body, - validate_headers, validate_query_string, MediaType, ProtocolTestFailure, + validate_headers, validate_query_string, FloatEquals, MediaType, ProtocolTestFailure, }; use http::Request; @@ -472,4 +505,20 @@ mod tests { validate_body(&expected, expected, MediaType::from("something/else")) .expect("inputs matched exactly") } + + #[test] + fn test_float_equals() { + let a = f64::NAN; + let b = f64::NAN; + assert_ne!(a, b); + assert!(a.float_equals(&b)); + assert!(!a.float_equals(&5_f64)); + + assert!(5.0.float_equals(&5.0)); + assert!(!5.0.float_equals(&5.1)); + + assert!(f64::INFINITY.float_equals(&f64::INFINITY)); + assert!(!f64::INFINITY.float_equals(&f64::NEG_INFINITY)); + assert!(f64::NEG_INFINITY.float_equals(&f64::NEG_INFINITY)); + } } diff --git a/rust-runtime/smithy-http/src/header.rs b/rust-runtime/smithy-http/src/header.rs index eb44c926da..64c56ac495 100644 --- a/rust-runtime/smithy-http/src/header.rs +++ b/rust-runtime/smithy-http/src/header.rs @@ -8,18 +8,40 @@ use http::header::{HeaderName, ValueIter}; use http::HeaderValue; use smithy_types::instant::Format; +use smithy_types::primitive::Parse; use smithy_types::Instant; +use std::borrow::Cow; use std::error::Error; use std::fmt; use std::fmt::{Display, Formatter}; use std::str::FromStr; -#[derive(Debug)] -pub struct ParseError; +#[derive(Debug, Eq, PartialEq)] +#[non_exhaustive] +pub struct ParseError { + message: Option>, +} + +impl ParseError { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { message: None } + } + + pub fn new_with_message(message: impl Into>) -> Self { + Self { + message: Some(message.into()), + } + } +} impl Display for ParseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Output failed to parse in headers") + write!(f, "Output failed to parse in headers")?; + if let Some(message) = &self.message { + write!(f, ". {}", message)?; + } + Ok(()) } } @@ -35,9 +57,13 @@ pub fn many_dates( ) -> Result, ParseError> { let mut out = vec![]; for header in values { - let mut header = header.to_str().map_err(|_| ParseError)?; + let mut header = header + .to_str() + .map_err(|_| ParseError::new_with_message("header was not valid utf-8 string"))?; while !header.is_empty() { - let (v, next) = Instant::read(header, format, ',').map_err(|_| ParseError)?; + let (v, next) = Instant::read(header, format, ',').map_err(|err| { + ParseError::new_with_message(format!("header could not be parsed as date: {}", err)) + })?; out.push(v); header = next; } @@ -56,16 +82,36 @@ pub fn headers_for_prefix<'a>( .map(move |h| (&h.as_str()[key.len()..], h)) } +pub fn read_many_from_str( + values: ValueIter, +) -> Result, ParseError> { + read_many(values, |v: &str| { + v.parse() + .map_err(|_err| ParseError::new_with_message("failed during FromString conversion")) + }) +} + +pub fn read_many_primitive(values: ValueIter) -> Result, ParseError> { + read_many(values, |v: &str| { + T::parse_smithy_primitive(v).map_err(|primitive| { + ParseError::new_with_message(format!( + "failed reading a list of primitives: {}", + primitive + )) + }) + }) +} + /// Read many comma / header delimited values from HTTP headers for `FromStr` types -pub fn read_many(values: ValueIter) -> Result, ParseError> -where - T: FromStr, -{ +fn read_many( + values: ValueIter, + f: impl Fn(&str) -> Result, +) -> Result, ParseError> { let mut out = vec![]; for header in values { let mut header = header.as_bytes(); while !header.is_empty() { - let (v, next) = read_one::(&header)?; + let (v, next) = read_one(&header, &f)?; out.push(v); header = next; } @@ -83,10 +129,15 @@ pub fn one_or_none( Some(v) => v, None => return Ok(None), }; - let value = std::str::from_utf8(first.as_bytes()).map_err(|_| ParseError)?; + let value = std::str::from_utf8(first.as_bytes()) + .map_err(|_| ParseError::new_with_message("invalid utf-8"))?; match values.next() { - None => T::from_str(value.trim()).map_err(|_| ParseError).map(Some), - Some(_) => Err(ParseError), + None => T::from_str(value.trim()) + .map_err(|_| ParseError::new()) + .map(Some), + Some(_) => Err(ParseError::new_with_message( + "expected a single value but found multiple", + )), } } @@ -107,13 +158,14 @@ pub fn set_header_if_absent( } /// Read one comma delimited value for `FromStr` types -fn read_one(s: &[u8]) -> Result<(T, &[u8]), ParseError> -where - T: FromStr, -{ +fn read_one<'a, T>( + s: &'a [u8], + f: &impl Fn(&str) -> Result, +) -> Result<(T, &'a [u8]), ParseError> { let (head, rest) = split_at_delim(s); - let head = std::str::from_utf8(head).map_err(|_| ParseError)?; - Ok((T::from_str(head.trim()).map_err(|_| ParseError)?, rest)) + let head = std::str::from_utf8(head) + .map_err(|_| ParseError::new_with_message("header was not valid utf8"))?; + Ok((f(head.trim())?, rest)) } fn split_at_delim(s: &[u8]) -> (&[u8], &[u8]) { @@ -128,13 +180,15 @@ fn then_delim(s: &[u8]) -> Result<&[u8], ParseError> { } else if s.starts_with(b",") { Ok(&s[1..]) } else { - Err(ParseError) + Err(ParseError::new_with_message("expected delimiter `,`")) } } #[cfg(test)] mod test { - use crate::header::{headers_for_prefix, read_many, set_header_if_absent, ParseError}; + use crate::header::{ + headers_for_prefix, read_many_primitive, set_header_if_absent, ParseError, + }; use std::collections::HashMap; #[test] @@ -153,6 +207,27 @@ mod test { ); } + #[test] + fn parse_floats() { + let test_request = http::Request::builder() + .header("X-Float-Multi", "0.0,Infinity,-Infinity,5555.5") + .header("X-Float-Error", "notafloat") + .body(()) + .unwrap(); + assert_eq!( + read_many_primitive::(test_request.headers().get_all("X-Float-Multi").iter()) + .expect("valid"), + vec![0.0, f32::INFINITY, f32::NEG_INFINITY, 5555.5] + ); + assert_eq!( + read_many_primitive::(test_request.headers().get_all("X-Float-Error").iter()) + .expect_err("invalid"), + ParseError::new_with_message( + "failed reading a list of primitives: failed to parse input as f32" + ) + ) + } + #[test] fn read_many_bools() { let test_request = http::Request::builder() @@ -164,47 +239,50 @@ mod test { .body(()) .unwrap(); assert_eq!( - read_many::(test_request.headers().get_all("X-Bool-Multi").iter()) + read_many_primitive::(test_request.headers().get_all("X-Bool-Multi").iter()) .expect("valid"), vec![true, false, true] ); assert_eq!( - read_many::(test_request.headers().get_all("X-Bool").iter()).unwrap(), + read_many_primitive::(test_request.headers().get_all("X-Bool").iter()).unwrap(), vec![true] ); assert_eq!( - read_many::(test_request.headers().get_all("X-Bool-Single").iter()).unwrap(), + read_many_primitive::(test_request.headers().get_all("X-Bool-Single").iter()) + .unwrap(), vec![true, false, true, true] ); - read_many::(test_request.headers().get_all("X-Bool-Invalid").iter()) + read_many_primitive::(test_request.headers().get_all("X-Bool-Invalid").iter()) .expect_err("invalid"); } #[test] - fn read_many_u16() { + fn check_read_many_i16() { let test_request = http::Request::builder() .header("X-Multi", "123,456") .header("X-Multi", "789") .header("X-Num", "777") .header("X-Num-Invalid", "12ef3") - .header("X-Num-Single", "1,2,3,4,5") + .header("X-Num-Single", "1,2,3,-4,5") .body(()) .unwrap(); assert_eq!( - read_many::(test_request.headers().get_all("X-Multi").iter()).expect("valid"), + read_many_primitive::(test_request.headers().get_all("X-Multi").iter()) + .expect("valid"), vec![123, 456, 789] ); assert_eq!( - read_many::(test_request.headers().get_all("X-Num").iter()).unwrap(), + read_many_primitive::(test_request.headers().get_all("X-Num").iter()).unwrap(), vec![777] ); assert_eq!( - read_many::(test_request.headers().get_all("X-Num-Single").iter()).unwrap(), - vec![1, 2, 3, 4, 5] + read_many_primitive::(test_request.headers().get_all("X-Num-Single").iter()) + .unwrap(), + vec![1, 2, 3, -4, 5] ); - read_many::(test_request.headers().get_all("X-Num-Invalid").iter()) + read_many_primitive::(test_request.headers().get_all("X-Num-Invalid").iter()) .expect_err("invalid"); } @@ -217,15 +295,14 @@ mod test { .header("X-Prefix-C", "777") .body(()) .unwrap(); - let resp: Result>, ParseError> = + let resp: Result>, ParseError> = headers_for_prefix(test_request.headers(), "X-Prefix-") .map(|(key, header_name)| { let values = test_request.headers().get_all(header_name); - read_many(values.iter()).map(|v| (key.to_string(), v)) + read_many_primitive(values.iter()).map(|v| (key.to_string(), v)) }) .collect(); let resp = resp.expect("valid"); - println!("{:?}", resp); - assert_eq!(resp.get("a"), Some(&vec![123_u16, 456_u16])); + assert_eq!(resp.get("a"), Some(&vec![123_i16, 456_i16])); } } diff --git a/rust-runtime/smithy-http/src/label.rs b/rust-runtime/smithy-http/src/label.rs index 48b61755c4..4a8a5f0584 100644 --- a/rust-runtime/smithy-http/src/label.rs +++ b/rust-runtime/smithy-http/src/label.rs @@ -9,14 +9,9 @@ use crate::urlencode::BASE_SET; use percent_encoding::AsciiSet; use smithy_types::Instant; -use std::fmt::Debug; const GREEDY: &AsciiSet = &BASE_SET.remove(b'/'); -pub fn fmt_default(t: T) -> String { - format!("{:?}", t) -} - pub fn fmt_string>(t: T, greedy: bool) -> String { let uri_set = if greedy { GREEDY } else { BASE_SET }; percent_encoding::utf8_percent_encode(t.as_ref(), &uri_set).to_string() diff --git a/rust-runtime/smithy-http/src/query.rs b/rust-runtime/smithy-http/src/query.rs index 821376177e..59dc9cc73e 100644 --- a/rust-runtime/smithy-http/src/query.rs +++ b/rust-runtime/smithy-http/src/query.rs @@ -8,11 +8,6 @@ use percent_encoding::utf8_percent_encode; /// Formatting values into the query string as specified in /// [httpQuery](https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#httpquery-trait) use smithy_types::Instant; -use std::fmt::Debug; - -pub fn fmt_default(t: T) -> String { - format!("{:?}", t) -} pub fn fmt_string>(t: T) -> String { utf8_percent_encode(t.as_ref(), BASE_SET).to_string() diff --git a/rust-runtime/smithy-json/Cargo.toml b/rust-runtime/smithy-json/Cargo.toml index 5f0ca2d3a8..243bc1dc2e 100644 --- a/rust-runtime/smithy-json/Cargo.toml +++ b/rust-runtime/smithy-json/Cargo.toml @@ -5,8 +5,6 @@ authors = ["AWS Rust SDK Team ", "John DiSanti JsonTokenIterator<'a> { offset, value: if floating { Number::Float( - f64::from_str(&number_str).map_err(|_| self.error_at(start, InvalidNumber))?, + f64::from_str(&number_str) + .map_err(|_| self.error_at(start, InvalidNumber)) + .and_then(|f| { + must_be_finite(f).map_err(|_| self.error_at(start, InvalidNumber)) + })?, ) } else if negative { // If the negative value overflows, then stuff it into an f64 @@ -484,6 +488,22 @@ impl<'a> Iterator for JsonTokenIterator<'a> { } } +fn must_be_finite(f: f64) -> Result { + if f.is_finite() { + Ok(f) + } else { + Err(()) + } +} + +fn must_not_be_finite(f: f64) -> Result { + if !f.is_finite() { + Ok(f) + } else { + Err(()) + } +} + #[cfg(test)] mod tests { use crate::deserialize::token::test::{ diff --git a/rust-runtime/smithy-json/src/deserialize/token.rs b/rust-runtime/smithy-json/src/deserialize/token.rs index 626d4b3456..c50bebb00f 100644 --- a/rust-runtime/smithy-json/src/deserialize/token.rs +++ b/rust-runtime/smithy-json/src/deserialize/token.rs @@ -9,7 +9,9 @@ use smithy_types::instant::Format; use smithy_types::{base64, Blob, Document, Instant, Number}; use std::borrow::Cow; +use crate::deserialize::must_not_be_finite; pub use crate::escape::Error as EscapeError; +use smithy_types::primitive::Parse; use std::collections::HashMap; use std::iter::Peekable; @@ -151,9 +153,45 @@ macro_rules! expect_value_or_null_fn { } expect_value_or_null_fn!(expect_bool_or_null, ValueBool, bool, "Expects a [Token::ValueBool] or [Token::ValueNull], and returns the bool value if it's not null."); -expect_value_or_null_fn!(expect_number_or_null, ValueNumber, Number, "Expects a [Token::ValueNumber] or [Token::ValueNull], and returns the [Number] value if it's not null."); expect_value_or_null_fn!(expect_string_or_null, ValueString, EscapedStr, "Expects a [Token::ValueString] or [Token::ValueNull], and returns the [EscapedStr] value if it's not null."); +/// Expects a [Token::ValueString], [Token::ValueNumber] or [Token::ValueNull]. +/// +/// If the value is a string, it MUST be `Infinity`, `-Infinity` or `Nan`. +/// If the value is a number, it is returned directly +pub fn expect_number_or_null( + token: Option, Error>>, +) -> Result, Error> { + match token.transpose()? { + Some(Token::ValueNull { .. }) => Ok(None), + Some(Token::ValueNumber { value, .. }) => Ok(Some(value)), + Some(Token::ValueString { value, offset }) => match value.to_unescaped() { + Err(err) => Err(Error::new( + ErrorReason::Custom(format!("expected a valid string, escape was invalid: {}", err).into()), Some(offset.0)) + ), + Ok(v) => f64::parse_smithy_primitive(v.as_ref()) + // disregard the exact error + .map_err(|_|()) + // only infinite / NaN can be used as strings + .and_then(must_not_be_finite) + .map(|float| Some(smithy_types::Number::Float(float))) + // convert to a helpful error + .map_err(|_| { + Error::new( + ErrorReason::Custom(Cow::Owned(format!( + "only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `{}`", + v + ))), + Some(offset.0), + ) + }), + }, + _ => Err(Error::custom( + "expected ValueString, ValueNumber, or ValueNull", + )), + } +} + /// Expects a [Token::ValueString] or [Token::ValueNull]. If the value is a string, it interprets it as a base64 encoded [Blob] value. pub fn expect_blob_or_null(token: Option, Error>>) -> Result, Error> { Ok(match expect_string_or_null(token)? { @@ -386,6 +424,15 @@ pub mod test { )) } + #[test] + fn test_non_finite_floats() { + let mut tokens = json_token_iter(b"inf"); + tokens + .next() + .expect("there is a token") + .expect_err("but it is invalid, ensure that Rust float boundary cases don't parse"); + } + #[test] fn mismatched_braces() { // The skip_value function doesn't need to explicitly handle these cases since @@ -466,9 +513,27 @@ pub mod test { expect_number_or_null(value_number(0, Number::PosInt(5))) ); assert_eq!( - Err(Error::custom("expected ValueNumber or ValueNull")), + Err(Error::custom( + "expected ValueString, ValueNumber, or ValueNull" + )), expect_number_or_null(value_bool(0, true)) ); + assert_eq!( + Ok(Some(Number::Float(f64::INFINITY))), + expect_number_or_null(value_string(0, "Infinity")) + ); + assert_eq!( + Err(Error::new(ErrorReason::Custom("only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `123`".into()), Some(0))), + expect_number_or_null(value_string(0, "123")) + ); + match expect_number_or_null(value_string(0, "NaN")) { + Ok(Some(Number::Float(v))) if v.is_nan() => { + // ok + } + not_ok => { + panic!("expected nan, found: {:?}", not_ok) + } + } } #[test] @@ -505,8 +570,14 @@ pub mod test { Ok(Some(Instant::from_f64(1445412480.0))), expect_timestamp_or_null(value_string(0, "2015-10-21T07:28:00Z"), Format::DateTime) ); + let err = Error::new( + ErrorReason::Custom( + "only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `wrong`".into(), + ), + Some(0), + ); assert_eq!( - Err(Error::custom("expected ValueNumber or ValueNull")), + Err(err), expect_timestamp_or_null(value_string(0, "wrong"), Format::EpochSeconds) ); assert_eq!( diff --git a/rust-runtime/smithy-json/src/serialize.rs b/rust-runtime/smithy-json/src/serialize.rs index 4eefce52e6..bf356e31f6 100644 --- a/rust-runtime/smithy-json/src/serialize.rs +++ b/rust-runtime/smithy-json/src/serialize.rs @@ -5,6 +5,7 @@ use crate::escape::escape_string; use smithy_types::instant::Format; +use smithy_types::primitive::Encoder; use smithy_types::{Document, Instant, Number}; use std::borrow::Cow; @@ -76,19 +77,18 @@ impl<'a> JsonValueWriter<'a> { match value { Number::PosInt(value) => { // itoa::Buffer is a fixed-size stack allocation, so this is cheap - self.output.push_str(itoa::Buffer::new().format(value)); + self.output.push_str(Encoder::from(value).encode()); } Number::NegInt(value) => { - self.output.push_str(itoa::Buffer::new().format(value)); + self.output.push_str(Encoder::from(value).encode()); } Number::Float(value) => { - // If the value is NaN, Infinity, or -Infinity - if value.is_nan() || value.is_infinite() { - self.output.push_str("null"); + let mut encoder: Encoder = value.into(); + // Nan / infinite values actually get written in quotes as a string value + if value.is_infinite() || value.is_nan() { + self.string_unchecked(encoder.encode()) } else { - // ryu::Buffer is a fixed-size stack allocation, so this is cheap - self.output - .push_str(ryu::Buffer::new().format_finite(value)); + self.output.push_str(encoder.encode()) } } } @@ -394,18 +394,15 @@ mod tests { assert_eq!("10000000000.0", format_test_number(Number::Float(1e10))); assert_eq!("-1.2", format_test_number(Number::Float(-1.2))); - // JSON doesn't support NaN, Infinity, or -Infinity, so we're matching + // Smithy has specific behavior for infinity & NaN // the behavior of the serde_json crate in these cases. + assert_eq!("\"NaN\"", format_test_number(Number::Float(f64::NAN))); assert_eq!( - serde_json::to_string(&f64::NAN).unwrap(), - format_test_number(Number::Float(f64::NAN)) - ); - assert_eq!( - serde_json::to_string(&f64::INFINITY).unwrap(), + "\"Infinity\"", format_test_number(Number::Float(f64::INFINITY)) ); assert_eq!( - serde_json::to_string(&f64::NEG_INFINITY).unwrap(), + "\"-Infinity\"", format_test_number(Number::Float(f64::NEG_INFINITY)) ); } diff --git a/rust-runtime/smithy-query/Cargo.toml b/rust-runtime/smithy-query/Cargo.toml index 35810b3f1a..256f49428b 100644 --- a/rust-runtime/smithy-query/Cargo.toml +++ b/rust-runtime/smithy-query/Cargo.toml @@ -5,7 +5,5 @@ authors = ["AWS Rust SDK Team ", "John DiSanti QueryValueWriter<'a> { match value { Number::PosInt(value) => { // itoa::Buffer is a fixed-size stack allocation, so this is cheap - self.string(itoa::Buffer::new().format(value)); + self.string(Encoder::from(value).encode()); } Number::NegInt(value) => { - self.string(itoa::Buffer::new().format(value)); - } - Number::Float(value) => { - // If the value is NaN, Infinity, or -Infinity - if value.is_nan() || value.is_infinite() { - self.string(""); - } else { - // ryu::Buffer is a fixed-size stack allocation, so this is cheap - self.string(ryu::Buffer::new().format_finite(value)); - } + self.string(Encoder::from(value).encode()); } + Number::Float(value) => self.string(Encoder::from(value).encode()), } } @@ -378,9 +371,9 @@ mod tests { &Version=1.0\ &PosInt=5\ &NegInt=-5\ - &Infinity=\ - &NegInfinity=\ - &NaN=\ + &Infinity=Infinity\ + &NegInfinity=-Infinity\ + &NaN=NaN\ &Floating=5.2\ ", out diff --git a/rust-runtime/smithy-types/Cargo.toml b/rust-runtime/smithy-types/Cargo.toml index 636bb28d18..0c08c47137 100644 --- a/rust-runtime/smithy-types/Cargo.toml +++ b/rust-runtime/smithy-types/Cargo.toml @@ -11,6 +11,8 @@ default = ["chrono-conversions"] [dependencies] chrono = { version = "0.4", default-features = false, features = [] } +ryu = "1.0.5" +itoa = "0.4.0" [dev-dependencies] base64 = "0.13.0" diff --git a/rust-runtime/smithy-types/src/lib.rs b/rust-runtime/smithy-types/src/lib.rs index 733d1fc2fe..1bcfd42c1d 100644 --- a/rust-runtime/smithy-types/src/lib.rs +++ b/rust-runtime/smithy-types/src/lib.rs @@ -5,6 +5,7 @@ pub mod base64; pub mod instant; +pub mod primitive; pub mod retry; use std::collections::HashMap; diff --git a/rust-runtime/smithy-types/src/primitive.rs b/rust-runtime/smithy-types/src/primitive.rs new file mode 100644 index 0000000000..65aef4d51e --- /dev/null +++ b/rust-runtime/smithy-types/src/primitive.rs @@ -0,0 +1,276 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +//! Utilities for formatting and parsing primitives +//! +//! Smithy protocols have specific behavior for serializing +//! & deserializing floats, specifically: +//! - NaN should be serialized as `NaN` +//! - Positive infinity should be serialized as `Infinity` +//! - Negative infinity should be serialized as `-Infinity` +//! +//! This module defines the [`Parse`](Parse) trait which +//! enables parsing primitive values (numbers & booleans) that follow +//! these rules and [`Encoder`](Encoder), a struct that enables +//! allocation-free serialization. +//! +//! # Examples +//! ## Parsing +//! ```rust +//! use smithy_types::primitive::Parse; +//! let parsed = f64::parse_smithy_primitive("123.4").expect("valid float"); +//! ``` +//! +//! ## Encoding +//! ``` +//! use smithy_types::primitive::Encoder; +//! assert_eq!("123.4", Encoder::from(123.4).encode()); +//! assert_eq!("Infinity", Encoder::from(f64::INFINITY).encode()); +//! assert_eq!("true", Encoder::from(true).encode()); +//! ``` +use crate::primitive::private::Sealed; +use std::error::Error; +use std::fmt::{Display, Formatter}; +use std::str::FromStr; + +/// An error during primitive parsing +#[non_exhaustive] +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct PrimitiveParseError(&'static str); +impl Display for PrimitiveParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "failed to parse input as {}", self.0) + } +} +impl Error for PrimitiveParseError {} + +/// Sealed trait for custom parsing of primitive types +pub trait Parse: Sealed { + fn parse_smithy_primitive(input: &str) -> Result + where + Self: Sized; +} + +mod private { + pub trait Sealed {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for u64 {} + impl Sealed for bool {} +} + +macro_rules! parse_from_str { + ($t: ty) => { + impl Parse for $t { + fn parse_smithy_primitive(input: &str) -> Result { + FromStr::from_str(input).map_err(|_| PrimitiveParseError(stringify!($t))) + } + } + }; +} + +parse_from_str!(bool); +parse_from_str!(i8); +parse_from_str!(i16); +parse_from_str!(i32); +parse_from_str!(i64); + +impl Parse for f32 { + fn parse_smithy_primitive(input: &str) -> Result { + float::parse_f32(input).map_err(|_| PrimitiveParseError("f32")) + } +} + +impl Parse for f64 { + fn parse_smithy_primitive(input: &str) -> Result { + float::parse_f64(input).map_err(|_| PrimitiveParseError("f64")) + } +} + +/// Primitive Type Encoder +/// +/// This type implements `From` for all Smithy primitive types. +#[non_exhaustive] +pub enum Encoder { + #[non_exhaustive] + Bool(bool), + #[non_exhaustive] + I8(i8, itoa::Buffer), + #[non_exhaustive] + I16(i16, itoa::Buffer), + #[non_exhaustive] + I32(i32, itoa::Buffer), + #[non_exhaustive] + I64(i64, itoa::Buffer), + #[non_exhaustive] + U64(u64, itoa::Buffer), + #[non_exhaustive] + F32(f32, ryu::Buffer), + #[non_exhaustive] + F64(f64, ryu::Buffer), +} + +impl Encoder { + pub fn encode(&mut self) -> &str { + match self { + Encoder::Bool(true) => "true", + Encoder::Bool(false) => "false", + Encoder::I8(v, buf) => buf.format(*v), + Encoder::I16(v, buf) => buf.format(*v), + Encoder::I32(v, buf) => buf.format(*v), + Encoder::I64(v, buf) => buf.format(*v), + Encoder::U64(v, buf) => buf.format(*v), + Encoder::F32(v, buf) => { + if v.is_nan() { + float::NAN + } else if *v == f32::INFINITY { + float::INFINITY + } else if *v == f32::NEG_INFINITY { + float::NEG_INFINITY + } else { + buf.format_finite(*v) + } + } + Encoder::F64(v, buf) => { + if v.is_nan() { + float::NAN + } else if *v == f64::INFINITY { + float::INFINITY + } else if *v == f64::NEG_INFINITY { + float::NEG_INFINITY + } else { + buf.format_finite(*v) + } + } + } + } +} + +impl From for Encoder { + fn from(input: bool) -> Self { + Self::Bool(input) + } +} + +impl From for Encoder { + fn from(input: i8) -> Self { + Self::I8(input, itoa::Buffer::new()) + } +} + +impl From for Encoder { + fn from(input: i16) -> Self { + Self::I16(input, itoa::Buffer::new()) + } +} + +impl From for Encoder { + fn from(input: i32) -> Self { + Self::I32(input, itoa::Buffer::new()) + } +} + +impl From for Encoder { + fn from(input: i64) -> Self { + Self::I64(input, itoa::Buffer::new()) + } +} + +impl From for Encoder { + fn from(input: u64) -> Self { + Self::U64(input, itoa::Buffer::new()) + } +} + +impl From for Encoder { + fn from(input: f32) -> Self { + Self::F32(input, ryu::Buffer::new()) + } +} + +impl From for Encoder { + fn from(input: f64) -> Self { + Self::F64(input, ryu::Buffer::new()) + } +} + +mod float { + use std::num::ParseFloatError; + pub const INFINITY: &str = "Infinity"; + pub const NEG_INFINITY: &str = "-Infinity"; + pub const NAN: &str = "NaN"; + + pub fn parse_f32(data: &str) -> Result { + match data { + INFINITY => Ok(f32::INFINITY), + NEG_INFINITY => Ok(f32::NEG_INFINITY), + NAN => Ok(f32::NAN), + other => other.parse::(), + } + } + + pub fn parse_f64(data: &str) -> Result { + match data { + INFINITY => Ok(f64::INFINITY), + NEG_INFINITY => Ok(f64::NEG_INFINITY), + NAN => Ok(f64::NAN), + other => other.parse::(), + } + } +} + +#[cfg(test)] +mod test { + use crate::primitive::{Encoder, Parse}; + + #[test] + fn bool_format() { + assert_eq!(Encoder::from(true).encode(), "true"); + assert_eq!(Encoder::from(false).encode(), "false"); + let err = bool::parse_smithy_primitive("not a boolean").expect_err("should fail"); + assert_eq!(err.0, "bool"); + assert_eq!(bool::parse_smithy_primitive("true"), Ok(true)); + assert_eq!(bool::parse_smithy_primitive("false"), Ok(false)); + } + + #[test] + fn float_format() { + assert_eq!(Encoder::from(55_f64).encode(), "55.0"); + assert_eq!(Encoder::from(f64::INFINITY).encode(), "Infinity"); + assert_eq!(Encoder::from(f32::INFINITY).encode(), "Infinity"); + assert_eq!(Encoder::from(f32::NEG_INFINITY).encode(), "-Infinity"); + assert_eq!(Encoder::from(f64::NEG_INFINITY).encode(), "-Infinity"); + assert_eq!(Encoder::from(f32::NAN).encode(), "NaN"); + assert_eq!(Encoder::from(f64::NAN).encode(), "NaN"); + } + + #[test] + fn float_parse() { + assert_eq!(f64::parse_smithy_primitive("1234.5"), Ok(1234.5)); + assert!(f64::parse_smithy_primitive("NaN").unwrap().is_nan()); + assert_eq!( + f64::parse_smithy_primitive("Infinity").unwrap(), + f64::INFINITY + ); + assert_eq!( + f64::parse_smithy_primitive("-Infinity").unwrap(), + f64::NEG_INFINITY + ); + assert_eq!(f32::parse_smithy_primitive("1234.5"), Ok(1234.5)); + assert!(f32::parse_smithy_primitive("NaN").unwrap().is_nan()); + assert_eq!( + f32::parse_smithy_primitive("Infinity").unwrap(), + f32::INFINITY + ); + assert_eq!( + f32::parse_smithy_primitive("-Infinity").unwrap(), + f32::NEG_INFINITY + ); + } +}