diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2113b6e79f..d19d844d70 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
**New This Week**
- :bug: Correctly encode HTTP Checksums using base64 instead of hex. Fixes aws-sdk-rust#164. (#615)
- (When complete) Add profile file provider for region (#594, #xyz)
+- Overhaul serialization/deserialization of numeric/boolean types. This resolves issues around serialization of NaN/Infinity and should also reduce the number of allocations required during serialization. (#618)
## v0.18.1 (July 27th 2021)
* Remove timestreamwrite and timestreamquery from the generated services (#613)
diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt
index 6db5a99be3..a3cb79dd0c 100644
--- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt
+++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/customize/s3/S3Decorator.kt
@@ -6,12 +6,8 @@
package software.amazon.smithy.rustsdk.customize.s3
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
-import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
-import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
-import software.amazon.smithy.model.shapes.StructureShape
-import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.Writable
import software.amazon.smithy.rust.codegen.rustlang.asType
@@ -28,7 +24,6 @@ import software.amazon.smithy.rust.codegen.smithy.letIf
import software.amazon.smithy.rust.codegen.smithy.protocols.ProtocolMap
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.smithy.protocols.RestXmlFactory
-import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.rustsdk.AwsRuntimeType
/**
@@ -59,20 +54,6 @@ class S3Decorator : RustCodegenDecorator {
it + S3PubUse()
}
}
-
- override fun transformModel(service: ServiceShape, model: Model): Model {
- return model.letIf(applies(service.id)) {
- ModelTransformer.create().mapShapes(model) { shape ->
- // Apply the S3UnwrappedXmlOutput customization to GetBucketLocation (more
- // details on the S3UnwrappedXmlOutputTrait)
- if (shape is StructureShape && shape.id == ShapeId.from("com.amazonaws.s3#GetBucketLocationOutput")) {
- shape.toBuilder().addTrait(S3UnwrappedXmlOutputTrait()).build()
- } else {
- shape
- }
- }
- }
- }
}
class S3(protocolConfig: ProtocolConfig) : RestXml(protocolConfig) {
diff --git a/aws/sdk/aws-models/accessanalyzer.json b/aws/sdk/aws-models/accessanalyzer.json
index f2a87af9d3..b261071289 100644
--- a/aws/sdk/aws-models/accessanalyzer.json
+++ b/aws/sdk/aws-models/accessanalyzer.json
@@ -4823,4 +4823,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/amp.json b/aws/sdk/aws-models/amp.json
index cd66167b68..84043fdc5d 100644
--- a/aws/sdk/aws-models/amp.json
+++ b/aws/sdk/aws-models/amp.json
@@ -935,4 +935,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/appmesh.json b/aws/sdk/aws-models/appmesh.json
index 9b2500c5a0..9273d629e0 100644
--- a/aws/sdk/aws-models/appmesh.json
+++ b/aws/sdk/aws-models/appmesh.json
@@ -8671,4 +8671,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/braket.json b/aws/sdk/aws-models/braket.json
index 209a8dadc6..c89e12506b 100644
--- a/aws/sdk/aws-models/braket.json
+++ b/aws/sdk/aws-models/braket.json
@@ -1404,4 +1404,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/codeguruprofiler.json b/aws/sdk/aws-models/codeguruprofiler.json
index 68b2066e3c..e54bb4e6d1 100644
--- a/aws/sdk/aws-models/codeguruprofiler.json
+++ b/aws/sdk/aws-models/codeguruprofiler.json
@@ -3174,4 +3174,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/groundstation.json b/aws/sdk/aws-models/groundstation.json
index 9fe4f04b6e..9af96283a8 100644
--- a/aws/sdk/aws-models/groundstation.json
+++ b/aws/sdk/aws-models/groundstation.json
@@ -3544,4 +3544,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/location.json b/aws/sdk/aws-models/location.json
index 3dd786826e..1d7718404d 100644
--- a/aws/sdk/aws-models/location.json
+++ b/aws/sdk/aws-models/location.json
@@ -5937,4 +5937,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/mgn.json b/aws/sdk/aws-models/mgn.json
index b0512d4832..10eb1f4f41 100644
--- a/aws/sdk/aws-models/mgn.json
+++ b/aws/sdk/aws-models/mgn.json
@@ -3944,4 +3944,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/mwaa.json b/aws/sdk/aws-models/mwaa.json
index f3e312eddc..ad18f1105b 100644
--- a/aws/sdk/aws-models/mwaa.json
+++ b/aws/sdk/aws-models/mwaa.json
@@ -1938,4 +1938,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/proton.json b/aws/sdk/aws-models/proton.json
index 3a569269dc..03616adff2 100644
--- a/aws/sdk/aws-models/proton.json
+++ b/aws/sdk/aws-models/proton.json
@@ -6482,4 +6482,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/rds-data.json b/aws/sdk/aws-models/rds-data.json
index 488e7ed124..2e3e75b220 100644
--- a/aws/sdk/aws-models/rds-data.json
+++ b/aws/sdk/aws-models/rds-data.json
@@ -1336,4 +1336,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/redshift-data.json b/aws/sdk/aws-models/redshift-data.json
index 9b4600e23b..6e009f4c6c 100644
--- a/aws/sdk/aws-models/redshift-data.json
+++ b/aws/sdk/aws-models/redshift-data.json
@@ -1423,4 +1423,4 @@
"type": "boolean"
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/aws-models/s3.json b/aws/sdk/aws-models/s3.json
index 15dc938600..b5293082c6 100644
--- a/aws/sdk/aws-models/s3.json
+++ b/aws/sdk/aws-models/s3.json
@@ -4073,6 +4073,7 @@
"target": "com.amazonaws.s3#GetBucketLocationOutput"
},
"traits": {
+ "aws.customizations#s3UnwrappedXmlOutput": {},
"smithy.api#documentation": "
Returns the Region the bucket resides in. You set the bucket's Region using the\n LocationConstraint
request parameter in a CreateBucket
\n request. For more information, see CreateBucket .
\n\n To use this implementation of the operation, you must be the bucket owner.
\n\n The following operations are related to GetBucketLocation
:
\n ",
"smithy.api#http": {
"method": "GET",
diff --git a/aws/sdk/aws-models/ssm-incidents.json b/aws/sdk/aws-models/ssm-incidents.json
index e2ab5381f2..20f4f25d93 100644
--- a/aws/sdk/aws-models/ssm-incidents.json
+++ b/aws/sdk/aws-models/ssm-incidents.json
@@ -4041,4 +4041,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/aws/sdk/examples/batch/Cargo.toml b/aws/sdk/examples/batch/Cargo.toml
index 4be66d845d..7998b9d8d7 100644
--- a/aws/sdk/examples/batch/Cargo.toml
+++ b/aws/sdk/examples/batch/Cargo.toml
@@ -11,4 +11,4 @@ batch = { package = "aws-sdk-batch", path = "../../build/aws-sdk/batch" }
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = "0.2.18"
\ No newline at end of file
+tracing-subscriber = "0.2.18"
diff --git a/aws/sdk/examples/ebs/Cargo.toml b/aws/sdk/examples/ebs/Cargo.toml
index f33fe0c87a..75e5c29d0f 100644
--- a/aws/sdk/examples/ebs/Cargo.toml
+++ b/aws/sdk/examples/ebs/Cargo.toml
@@ -14,4 +14,4 @@ tokio = { version = "1", features = ["full"]}
base64 = "0.13.0"
sha2 = "0.9.5"
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = "0.2.19"
\ No newline at end of file
+tracing-subscriber = "0.2.19"
diff --git a/aws/sdk/examples/kinesis/Cargo.toml b/aws/sdk/examples/kinesis/Cargo.toml
index 35c5884944..57ca5fa25c 100644
--- a/aws/sdk/examples/kinesis/Cargo.toml
+++ b/aws/sdk/examples/kinesis/Cargo.toml
@@ -11,4 +11,4 @@ kinesis = { package = "aws-sdk-kinesis", path = "../../build/aws-sdk/kinesis" }
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
\ No newline at end of file
+tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
diff --git a/aws/sdk/examples/medialive/Cargo.toml b/aws/sdk/examples/medialive/Cargo.toml
index 22c535b6ad..552cd313d0 100644
--- a/aws/sdk/examples/medialive/Cargo.toml
+++ b/aws/sdk/examples/medialive/Cargo.toml
@@ -11,4 +11,4 @@ medialive = { package = "aws-sdk-medialive", path = "../../build/aws-sdk/mediali
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
\ No newline at end of file
+tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
diff --git a/aws/sdk/examples/mediapackage/Cargo.toml b/aws/sdk/examples/mediapackage/Cargo.toml
index dc486ffe7d..fca8c6e851 100644
--- a/aws/sdk/examples/mediapackage/Cargo.toml
+++ b/aws/sdk/examples/mediapackage/Cargo.toml
@@ -11,4 +11,4 @@ mediapackage = { package = "aws-sdk-mediapackage", path = "../../build/aws-sdk/m
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
\ No newline at end of file
+tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
diff --git a/aws/sdk/examples/polly/Cargo.toml b/aws/sdk/examples/polly/Cargo.toml
index d4991e4537..ffdc05b492 100644
--- a/aws/sdk/examples/polly/Cargo.toml
+++ b/aws/sdk/examples/polly/Cargo.toml
@@ -12,4 +12,3 @@ aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
-
diff --git a/aws/sdk/examples/qldb/Cargo.toml b/aws/sdk/examples/qldb/Cargo.toml
index 8fd299aa61..e1f4e18ac6 100644
--- a/aws/sdk/examples/qldb/Cargo.toml
+++ b/aws/sdk/examples/qldb/Cargo.toml
@@ -13,5 +13,3 @@ aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
-
-
diff --git a/aws/sdk/examples/qldb/README.md b/aws/sdk/examples/qldb/README.md
index ffafc8a37a..49b61fb6a4 100644
--- a/aws/sdk/examples/qldb/README.md
+++ b/aws/sdk/examples/qldb/README.md
@@ -51,4 +51,4 @@ where:
If the environment variable is not set, defaults to **us-west-2**.
- __-v__ enables displaying additional information.
-##
+##
diff --git a/aws/sdk/examples/rds/Cargo.toml b/aws/sdk/examples/rds/Cargo.toml
index 0b42125c94..88fa8475dd 100644
--- a/aws/sdk/examples/rds/Cargo.toml
+++ b/aws/sdk/examples/rds/Cargo.toml
@@ -11,4 +11,4 @@ rds = {package = "aws-sdk-rds", path = "../../build/aws-sdk/rds"}
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = {version = "1", features = ["full"]}
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
\ No newline at end of file
+tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
diff --git a/aws/sdk/examples/rdsdata/Cargo.toml b/aws/sdk/examples/rdsdata/Cargo.toml
index c6ff30216d..d9cf9d33f5 100644
--- a/aws/sdk/examples/rdsdata/Cargo.toml
+++ b/aws/sdk/examples/rdsdata/Cargo.toml
@@ -11,4 +11,4 @@ rdsdata = {package = "aws-sdk-rdsdata", path = "../../build/aws-sdk/rdsdata"}
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = {version = "1", features = ["full"]}
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
\ No newline at end of file
+tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
diff --git a/aws/sdk/examples/sagemaker/Cargo.toml b/aws/sdk/examples/sagemaker/Cargo.toml
index c4090b603e..ef74a222d3 100644
--- a/aws/sdk/examples/sagemaker/Cargo.toml
+++ b/aws/sdk/examples/sagemaker/Cargo.toml
@@ -15,4 +15,4 @@ tokio = { version = "1", features = ["full"] }
env_logger = "0.8.2"
chrono = "0.4.19"
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = "0.2.18"
\ No newline at end of file
+tracing-subscriber = "0.2.18"
diff --git a/aws/sdk/examples/secretsmanager/Cargo.toml b/aws/sdk/examples/secretsmanager/Cargo.toml
index cb1b2b715f..9aac062109 100644
--- a/aws/sdk/examples/secretsmanager/Cargo.toml
+++ b/aws/sdk/examples/secretsmanager/Cargo.toml
@@ -14,4 +14,3 @@ tokio = { version = "1", features = ["full"]}
structopt = { version = "0.3", default-features = false }
tracing-subscriber = { version = "0.2.16", features = ["fmt"] }
-
diff --git a/aws/sdk/examples/snowball/Cargo.toml b/aws/sdk/examples/snowball/Cargo.toml
index a0044c1e1c..3c862a76a9 100644
--- a/aws/sdk/examples/snowball/Cargo.toml
+++ b/aws/sdk/examples/snowball/Cargo.toml
@@ -11,4 +11,4 @@ aws-sdk-snowball = { path = "../../build/aws-sdk/snowball" }
aws-types = { path = "../../build/aws-sdk/aws-types" }
tokio = { version = "1", features = ["full"] }
structopt = { version = "0.3", default-features = false }
-tracing-subscriber = "0.2.18"
\ No newline at end of file
+tracing-subscriber = "0.2.18"
diff --git a/aws/sdk/examples/sqs/Cargo.toml b/aws/sdk/examples/sqs/Cargo.toml
index 9fadf19c30..5196651b5b 100644
--- a/aws/sdk/examples/sqs/Cargo.toml
+++ b/aws/sdk/examples/sqs/Cargo.toml
@@ -9,4 +9,4 @@ edition = "2018"
[dependencies]
sqs = { package = "aws-sdk-sqs", path = "../../build/aws-sdk/sqs" }
tokio = { version = "1", features = ["full"] }
-tracing-subscriber = "0.2.18"
\ No newline at end of file
+tracing-subscriber = "0.2.18"
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt
index fafd16332d..27f2e73516 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/rustlang/RustTypes.kt
@@ -8,6 +8,17 @@ package software.amazon.smithy.rust.codegen.rustlang
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.util.dq
+/**
+ * Dereference [input]
+ *
+ * Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&'
+ */
+fun autoDeref(input: String) = if (input.startsWith("&")) {
+ input.removePrefix("&")
+} else {
+ "*$input"
+}
+
/**
* A hierarchy of types handled by Smithy codegen
*/
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt
index 22c329a535..109feb8b82 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/HttpProtocolTestGenerator.kt
@@ -7,6 +7,8 @@ package software.amazon.smithy.rust.codegen.smithy.generators
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.OperationIndex
+import software.amazon.smithy.model.shapes.DoubleShape
+import software.amazon.smithy.model.shapes.FloatShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
@@ -306,7 +308,21 @@ class HttpProtocolTestGenerator(
);"""
)
} else {
- rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""")
+ when (protocolConfig.model.expectShape(member.target)) {
+ is DoubleShape, is FloatShape -> {
+ addUseImports(
+ RuntimeType.ProtocolTestHelper(protocolConfig.runtimeConfig, "FloatEquals").toSymbol()
+ )
+ rust(
+ """
+ assert!(parsed.$memberName.float_equals(&expected_output.$memberName),
+ "Unexpected value for `$memberName` {:?} vs. {:?}", expected_output.$memberName, parsed.$memberName);
+ """
+ )
+ }
+ else ->
+ rust("""assert_eq!(parsed.$memberName, expected_output.$memberName, "Unexpected value for `$memberName`");""")
+ }
}
}
}
@@ -428,7 +444,11 @@ class HttpProtocolTestGenerator(
private val RestXml = "aws.protocoltests.restxml#RestXml"
private val AwsQuery = "aws.protocoltests.query#AwsQuery"
private val Ec2Query = "aws.protocoltests.ec2#AwsEc2"
- private val ExpectFail = setOf()
+ private val ExpectFail = setOf(
+ FailingTest(
+ service = RestJson, id = "RestJsonHostWithPath", action = Action.Request
+ )
+ )
private val RunOnly: Set? = null
// These tests are not even attempted to be generated, either because they will not compile
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt
index 72fade3271..541732e8f0 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/Instantiator.kt
@@ -114,7 +114,18 @@ class Instantiator(
// Simple Shapes
is StringShape -> renderString(writer, shape, arg as StringNode)
- is NumberShape -> writer.write(arg.asNumberNode().get())
+ is NumberShape -> when (arg) {
+ is StringNode -> {
+ val numberSymbol = symbolProvider.toSymbol(shape)
+ // support Smithy custom values, such as Infinity
+ writer.rust(
+ """<#T as #T>::parse_smithy_primitive(${arg.value.dq()}).expect("invalid string for number")""",
+ numberSymbol,
+ CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Parse")
+ )
+ }
+ is NumberNode -> writer.write(arg.value)
+ }
is BooleanShape -> writer.write(arg.asBooleanNode().get().toString())
is DocumentShape -> writer.rustBlock("") {
val smithyJson = CargoDependency.smithyJson(runtimeConfig).asType()
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt
index fbc33f1ae0..da9d690bae 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/RequestBindingGenerator.kt
@@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.smithy.generators.http
+import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
@@ -19,8 +20,10 @@ import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.rust.codegen.rustlang.Attribute
+import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
-import software.amazon.smithy.rust.codegen.rustlang.assignment
+import software.amazon.smithy.rust.codegen.rustlang.asType
+import software.amazon.smithy.rust.codegen.rustlang.autoDeref
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
@@ -36,6 +39,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait
+import software.amazon.smithy.rust.codegen.util.isPrimitive
fun HttpTrait.uriFormatString(): String {
return uri.rustFormatString("/", "/")
@@ -71,6 +75,7 @@ class RequestBindingGenerator(
) {
private val index = HttpBindingIndex.of(model)
private val buildError = runtimeConfig.operationBuildError()
+ private val Encoder = CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder")
constructor(
protocolConfig: ProtocolConfig,
@@ -193,6 +198,9 @@ class RequestBindingGenerator(
ifSet(memberType, memberSymbol, "&self.$memberName") { field ->
listForEach(memberType, field) { innerField, targetId ->
val innerMemberType = model.expectShape(targetId)
+ if (innerMemberType.isPrimitive()) {
+ rust("let mut encoder = #T::from(${autoDeref(innerField)});", Encoder)
+ }
val formatted = headerFmtFun(this, innerMemberType, memberShape, innerField)
val safeName = safeName("formatted")
write("let $safeName = $formatted;")
@@ -241,10 +249,10 @@ class RequestBindingGenerator(
target.isListShape || target.isMemberShape -> {
throw IllegalArgumentException("lists should be handled at a higher level")
}
- else -> {
- val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_default"))
- "$func(&$targetName)"
+ target.isPrimitive() -> {
+ "encoder.encode()"
}
+ else -> throw CodegenException("unexpected shape: $target")
}
}
@@ -263,12 +271,13 @@ class RequestBindingGenerator(
}
val combinedArgs = listOf(formatString, *args.toTypedArray())
writer.addImport(RuntimeType.stdfmt.member("Write").toSymbol(), null)
- writer.rustBlock("fn uri_base(&self, output: &mut String) -> Result<(), #T>", runtimeConfig.operationBuildError()) {
+ writer.rustBlock(
+ "fn uri_base(&self, output: &mut String) -> Result<(), #T>",
+ runtimeConfig.operationBuildError()
+ ) {
httpTrait.uri.labels.map { label ->
val member = inputShape.expectMember(label.content)
- assignment(local(member)) {
- serializeLabel(member, label)
- }
+ serializeLabel(member, label, local(member))
}
rust("""write!(output, ${combinedArgs.joinToString(", ")}).expect("formatting should succeed");""")
rust("Ok(())")
@@ -374,13 +383,12 @@ class RequestBindingGenerator(
throw IllegalArgumentException("lists should be handled at a higher level")
}
else -> {
- val func = writer.format(RuntimeType.QueryFormat(runtimeConfig, "fmt_default"))
- "$func(&$targetName)"
+ "${writer.format(Encoder)}::from(${autoDeref(targetName)}).encode()"
}
}
}
- private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment) {
+ private fun RustWriter.serializeLabel(member: MemberShape, label: SmithyPattern.Segment, outputVar: String) {
val target = model.expectShape(member.target)
val symbol = symbolProvider.toSymbol(member)
val buildError = {
@@ -390,37 +398,37 @@ class RequestBindingGenerator(
"cannot be empty or unset"
)
}
- rustBlock("") {
- rust("let input = &self.${symbolProvider.toMemberName(member)};")
- if (symbol.isOptional()) {
- rust("let input = input.as_ref().ok_or(${buildError()})?;")
+ val input = safeName("input")
+ rust("let $input = &self.${symbolProvider.toMemberName(member)};")
+ if (symbol.isOptional()) {
+ rust("let $input = $input.as_ref().ok_or(${buildError()})?;")
+ }
+ when {
+ target.isStringShape -> {
+ val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string"))
+ rust("let $outputVar = $func($input, ${label.isGreedyLabel});")
}
- when {
- target.isStringShape -> {
- val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_string"))
- rust("let formatted = $func(input, ${label.isGreedyLabel});")
- }
- target.isTimestampShape -> {
- val timestampFormat =
- index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat)
- val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
- val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
- rust("let formatted = $func(&input, ${format(timestampFormatType)});")
- }
- else -> {
- val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_default"))
- rust("let formatted = $func(input);")
- }
+ target.isTimestampShape -> {
+ val timestampFormat =
+ index.determineTimestampFormat(member, HttpBinding.Location.LABEL, defaultTimestampFormat)
+ val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
+ val func = format(RuntimeType.LabelFormat(runtimeConfig, "fmt_timestamp"))
+ rust("let $outputVar = $func(&$input, ${format(timestampFormatType)});")
}
- rust(
- """
- if formatted.is_empty() {
+ else -> {
+ rust(
+ "let mut ${outputVar}_encoder = #T::from(${autoDeref(input)}); let $outputVar = ${outputVar}_encoder.encode();",
+ Encoder
+ )
+ }
+ }
+ rust(
+ """
+ if $outputVar.is_empty() {
return Err(${buildError()})
}
- formatted
"""
- )
- }
+ )
}
/** End URI generation **/
}
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt
index 14c0b465b0..2d70ea3264 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/generators/http/ResponseBindingGenerator.kt
@@ -38,6 +38,7 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.hasTrait
+import software.amazon.smithy.rust.codegen.util.isPrimitive
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.toSnakeCase
@@ -238,17 +239,22 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
headerUtil,
timestampFormatType
)
+ } else if (coreShape.isPrimitive()) {
+ rust(
+ "let $parsedValue = #T::read_many_primitive::<${coreType.render(fullyQualified = true)}>(headers)?;",
+ headerUtil
+ )
} else {
rust(
- "let $parsedValue: Vec<${coreType.render(true)}> = #T::read_many(headers)?;",
+ "let $parsedValue: Vec<${coreType.render(fullyQualified = true)}> = #T::read_many_from_str(headers)?;",
headerUtil
)
if (coreShape.hasTrait()) {
rustTemplate(
"""let $parsedValue: std::result::Result, _> = $parsedValue
.iter().map(|s|
- #{base_64_decode}(s).map_err(|_|#{header}::ParseError)
- .and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError))
+ #{base_64_decode}(s).map_err(|_|#{header}::ParseError::new_with_message("failed to decode base64"))
+ .and_then(|bytes|String::from_utf8(bytes).map_err(|_|#{header}::ParseError::new_with_message("base64 encoded data was not valid utf-8")))
).collect();""",
"base_64_decode" to RuntimeType.Base64Decode(runtimeConfig),
"header" to headerUtil
@@ -281,7 +287,7 @@ class ResponseBindingGenerator(protocolConfig: ProtocolConfig, private val opera
else -> rustTemplate(
"""
if $parsedValue.len() > 1 {
- Err(#{header_util}::ParseError)
+ Err(#{header_util}::ParseError::new_with_message(format!("expected one item but found {}", $parsedValue.len())))
} else {
let mut $parsedValue = $parsedValue;
Ok($parsedValue.pop())
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt
index 1a2cb050f0..0ab8d68ca7 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/parse/XmlBindingTraitParserGenerator.kt
@@ -5,6 +5,7 @@
package software.amazon.smithy.rust.codegen.smithy.protocols.parse
+import software.amazon.smithy.aws.traits.customizations.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
@@ -34,6 +35,7 @@ import software.amazon.smithy.rust.codegen.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.rustlang.withBlock
+import software.amazon.smithy.rust.codegen.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.smithy.generators.ProtocolConfig
import software.amazon.smithy.rust.codegen.smithy.generators.StructureGenerator
@@ -44,7 +46,6 @@ import software.amazon.smithy.rust.codegen.smithy.isOptional
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlMemberIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.XmlNameIndex
import software.amazon.smithy.rust.codegen.smithy.protocols.deserializeFunctionName
-import software.amazon.smithy.rust.codegen.smithy.traits.S3UnwrappedXmlOutputTrait
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectMember
import software.amazon.smithy.rust.codegen.util.hasTrait
@@ -101,7 +102,8 @@ class XmlBindingTraitParserGenerator(
"XmlError" to xmlError,
"next_start_element" to smithyXml.member("decode::next_start_element"),
"try_data" to smithyXml.member("decode::try_data"),
- "ScopedDecoder" to scopedDecoder
+ "ScopedDecoder" to scopedDecoder,
+ "smithy_types" to CargoDependency.SmithyTypes(runtimeConfig).asType()
)
private val model = protocolConfig.model
private val index = HttpBindingIndex.of(model)
@@ -192,7 +194,7 @@ class XmlBindingTraitParserGenerator(
*codegenScope
)
val context = OperationWrapperContext(operationShape, shapeName, xmlError)
- if (outputShape.hasTrait()) {
+ if (operationShape.hasTrait()) {
unwrappedResponseParser("builder", "decoder", "start_el", outputShape.members())
} else {
writeOperationWrapper(context) { tagName ->
@@ -561,8 +563,12 @@ class XmlBindingTraitParserGenerator(
is StringShape -> parseStringInner(shape, provider)
is NumberShape, is BooleanShape -> {
rustBlock("") {
- rust("use std::str::FromStr;")
- withBlock("#T::from_str(", ")", symbolProvider.toSymbol(shape)) {
+ withBlockTemplate(
+ "<#{shape} as #{smithy_types}::primitive::Parse>::parse_smithy_primitive(",
+ ")",
+ *codegenScope,
+ "shape" to symbolProvider.toSymbol(shape)
+ ) {
provider()
}
rustTemplate(
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt
index f2b4ce53ee..469c8e824f 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/protocols/serialize/XmlBindingTraitSerializerGenerator.kt
@@ -27,6 +27,7 @@ import software.amazon.smithy.rust.codegen.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.rustlang.RustType
import software.amazon.smithy.rust.codegen.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.rustlang.asType
+import software.amazon.smithy.rust.codegen.rustlang.autoDeref
import software.amazon.smithy.rust.codegen.rustlang.render
import software.amazon.smithy.rust.codegen.rustlang.rust
import software.amazon.smithy.rust.codegen.rustlang.rustBlock
@@ -195,17 +196,6 @@ class XmlBindingTraitSerializerGenerator(
rust("scope.finish();")
}
- /**
- * Dereference [input]
- *
- * Clippy is upset about `*&`, so if [input] is already referenced, simply strip the leading '&'
- */
- private fun autoDeref(input: String) = if (input.startsWith("&")) {
- input.removePrefix("&")
- } else {
- "*$input"
- }
-
private fun RustWriter.serializeRawMember(member: MemberShape, input: String) {
when (val shape = model.expectShape(member.target)) {
is StringShape -> if (shape.hasTrait()) {
@@ -213,8 +203,9 @@ class XmlBindingTraitSerializerGenerator(
} else {
rust("$input.as_ref()")
}
- is NumberShape -> rust("$input.to_string().as_ref()")
- is BooleanShape -> rust("""if ${autoDeref(input)} { "true" } else { "false" }""")
+ is BooleanShape, is NumberShape -> {
+ rust("#T::from(${autoDeref(input)}).encode()", CargoDependency.SmithyTypes(runtimeConfig).asType().member("primitive::Encoder"))
+ }
is BlobShape -> rust("#T($input.as_ref()).as_ref()", RuntimeType.Base64Encode(runtimeConfig))
is TimestampShape -> {
val timestampFormat =
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt
deleted file mode 100644
index 7d7eb41b25..0000000000
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/smithy/traits/S3UnwrappedXmlOutputTrait.kt
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
- * SPDX-License-Identifier: Apache-2.0.
- */
-
-package software.amazon.smithy.rust.codegen.smithy.traits
-
-import software.amazon.smithy.model.node.Node
-import software.amazon.smithy.model.shapes.ShapeId
-import software.amazon.smithy.model.traits.AnnotationTrait
-
-/**
- * S3's GetBucketLocation response shape can't be represented with Smithy's restXml protocol
- * without customization. We add this trait to the S3 model at codegen time so that a different
- * code path is taken in the XML deserialization codegen to generate code that parses the S3
- * response shape correctly.
- *
- * From what the S3 model states, the generated parser would expect:
- * ```
- *
- * us-west-2
- *
- * ```
- *
- * But S3 actually responds with:
- * ```
- * us-west-2
- * ```
- */
-class S3UnwrappedXmlOutputTrait : AnnotationTrait(ID, Node.objectNode()) {
- companion object {
- val ID = ShapeId.from("smithy.api.internal#s3UnwrappedXmlOutputTrait")
- }
-}
diff --git a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt
index ea5c9af3fa..4391e6e7aa 100644
--- a/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt
+++ b/codegen/src/main/kotlin/software/amazon/smithy/rust/codegen/util/Smithy.kt
@@ -7,7 +7,9 @@ package software.amazon.smithy.rust.codegen.util
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.model.Model
+import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.MemberShape
+import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
@@ -65,3 +67,10 @@ inline fun Shape.expectTrait(): T = expectTrait(T::class.jav
/** Kotlin sugar for getTrait() check. e.g. shape.getTrait() instead of shape.getTrait(EnumTrait::class.java) */
inline fun Shape.getTrait(): T? = getTrait(T::class.java).orNull()
+
+fun Shape.isPrimitive(): Boolean {
+ return when (this) {
+ is NumberShape, is BooleanShape -> true
+ else -> false
+ }
+}
diff --git a/gradle.properties b/gradle.properties
index abc15cd091..061d0c3249 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -6,7 +6,7 @@
kotlin.code.style=official
# codegen
-smithyVersion=1.8.0
+smithyVersion=1.10.0
# kotlin
kotlinVersion=1.4.21
diff --git a/rust-runtime/inlineable/Cargo.toml b/rust-runtime/inlineable/Cargo.toml
index 07aac64313..5464f6769f 100644
--- a/rust-runtime/inlineable/Cargo.toml
+++ b/rust-runtime/inlineable/Cargo.toml
@@ -20,4 +20,4 @@ are to allow this crate to be compilable and testable in isolation, no client co
[dev-dependencies]
proptest = "1"
-regex = "1"
\ No newline at end of file
+regex = "1"
diff --git a/rust-runtime/protocol-test-helpers/src/lib.rs b/rust-runtime/protocol-test-helpers/src/lib.rs
index 1852c01f7f..a690ebe6ab 100644
--- a/rust-runtime/protocol-test-helpers/src/lib.rs
+++ b/rust-runtime/protocol-test-helpers/src/lib.rs
@@ -15,6 +15,39 @@ use std::fmt::{self, Debug};
use thiserror::Error;
use urlencoded::try_url_encoded_form_equivalent;
+/// Helper trait for tests for float comparisons
+///
+/// This trait differs in float's default `PartialEq` implementation by considering all `NaN` values to
+/// be equal.
+pub trait FloatEquals {
+ fn float_equals(&self, other: &Self) -> bool;
+}
+
+impl FloatEquals for f64 {
+ fn float_equals(&self, other: &Self) -> bool {
+ (self.is_nan() && other.is_nan()) || self.eq(other)
+ }
+}
+
+impl FloatEquals for f32 {
+ fn float_equals(&self, other: &Self) -> bool {
+ (self.is_nan() && other.is_nan()) || self.eq(other)
+ }
+}
+
+impl FloatEquals for Option
+where
+ T: FloatEquals,
+{
+ fn float_equals(&self, other: &Self) -> bool {
+ match (self, other) {
+ (Some(this), Some(other)) => this.float_equals(other),
+ (None, None) => true,
+ _else => false,
+ }
+ }
+}
+
#[derive(Debug, PartialEq, Eq, Error)]
pub enum ProtocolTestFailure {
#[error("missing query param: expected `{expected}`, found {found:?}")]
@@ -326,7 +359,7 @@ fn try_json_eq(actual: &str, expected: &str) -> Result<(), ProtocolTestFailure>
mod tests {
use crate::{
forbid_headers, forbid_query_params, require_headers, require_query_params, validate_body,
- validate_headers, validate_query_string, MediaType, ProtocolTestFailure,
+ validate_headers, validate_query_string, FloatEquals, MediaType, ProtocolTestFailure,
};
use http::Request;
@@ -472,4 +505,20 @@ mod tests {
validate_body(&expected, expected, MediaType::from("something/else"))
.expect("inputs matched exactly")
}
+
+ #[test]
+ fn test_float_equals() {
+ let a = f64::NAN;
+ let b = f64::NAN;
+ assert_ne!(a, b);
+ assert!(a.float_equals(&b));
+ assert!(!a.float_equals(&5_f64));
+
+ assert!(5.0.float_equals(&5.0));
+ assert!(!5.0.float_equals(&5.1));
+
+ assert!(f64::INFINITY.float_equals(&f64::INFINITY));
+ assert!(!f64::INFINITY.float_equals(&f64::NEG_INFINITY));
+ assert!(f64::NEG_INFINITY.float_equals(&f64::NEG_INFINITY));
+ }
}
diff --git a/rust-runtime/smithy-http/src/header.rs b/rust-runtime/smithy-http/src/header.rs
index eb44c926da..64c56ac495 100644
--- a/rust-runtime/smithy-http/src/header.rs
+++ b/rust-runtime/smithy-http/src/header.rs
@@ -8,18 +8,40 @@
use http::header::{HeaderName, ValueIter};
use http::HeaderValue;
use smithy_types::instant::Format;
+use smithy_types::primitive::Parse;
use smithy_types::Instant;
+use std::borrow::Cow;
use std::error::Error;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
-#[derive(Debug)]
-pub struct ParseError;
+#[derive(Debug, Eq, PartialEq)]
+#[non_exhaustive]
+pub struct ParseError {
+ message: Option>,
+}
+
+impl ParseError {
+ #[allow(clippy::new_without_default)]
+ pub fn new() -> Self {
+ Self { message: None }
+ }
+
+ pub fn new_with_message(message: impl Into>) -> Self {
+ Self {
+ message: Some(message.into()),
+ }
+ }
+}
impl Display for ParseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
- write!(f, "Output failed to parse in headers")
+ write!(f, "Output failed to parse in headers")?;
+ if let Some(message) = &self.message {
+ write!(f, ". {}", message)?;
+ }
+ Ok(())
}
}
@@ -35,9 +57,13 @@ pub fn many_dates(
) -> Result, ParseError> {
let mut out = vec![];
for header in values {
- let mut header = header.to_str().map_err(|_| ParseError)?;
+ let mut header = header
+ .to_str()
+ .map_err(|_| ParseError::new_with_message("header was not valid utf-8 string"))?;
while !header.is_empty() {
- let (v, next) = Instant::read(header, format, ',').map_err(|_| ParseError)?;
+ let (v, next) = Instant::read(header, format, ',').map_err(|err| {
+ ParseError::new_with_message(format!("header could not be parsed as date: {}", err))
+ })?;
out.push(v);
header = next;
}
@@ -56,16 +82,36 @@ pub fn headers_for_prefix<'a>(
.map(move |h| (&h.as_str()[key.len()..], h))
}
+pub fn read_many_from_str(
+ values: ValueIter,
+) -> Result, ParseError> {
+ read_many(values, |v: &str| {
+ v.parse()
+ .map_err(|_err| ParseError::new_with_message("failed during FromString conversion"))
+ })
+}
+
+pub fn read_many_primitive(values: ValueIter) -> Result, ParseError> {
+ read_many(values, |v: &str| {
+ T::parse_smithy_primitive(v).map_err(|primitive| {
+ ParseError::new_with_message(format!(
+ "failed reading a list of primitives: {}",
+ primitive
+ ))
+ })
+ })
+}
+
/// Read many comma / header delimited values from HTTP headers for `FromStr` types
-pub fn read_many(values: ValueIter) -> Result, ParseError>
-where
- T: FromStr,
-{
+fn read_many(
+ values: ValueIter,
+ f: impl Fn(&str) -> Result,
+) -> Result, ParseError> {
let mut out = vec![];
for header in values {
let mut header = header.as_bytes();
while !header.is_empty() {
- let (v, next) = read_one::(&header)?;
+ let (v, next) = read_one(&header, &f)?;
out.push(v);
header = next;
}
@@ -83,10 +129,15 @@ pub fn one_or_none(
Some(v) => v,
None => return Ok(None),
};
- let value = std::str::from_utf8(first.as_bytes()).map_err(|_| ParseError)?;
+ let value = std::str::from_utf8(first.as_bytes())
+ .map_err(|_| ParseError::new_with_message("invalid utf-8"))?;
match values.next() {
- None => T::from_str(value.trim()).map_err(|_| ParseError).map(Some),
- Some(_) => Err(ParseError),
+ None => T::from_str(value.trim())
+ .map_err(|_| ParseError::new())
+ .map(Some),
+ Some(_) => Err(ParseError::new_with_message(
+ "expected a single value but found multiple",
+ )),
}
}
@@ -107,13 +158,14 @@ pub fn set_header_if_absent(
}
/// Read one comma delimited value for `FromStr` types
-fn read_one(s: &[u8]) -> Result<(T, &[u8]), ParseError>
-where
- T: FromStr,
-{
+fn read_one<'a, T>(
+ s: &'a [u8],
+ f: &impl Fn(&str) -> Result,
+) -> Result<(T, &'a [u8]), ParseError> {
let (head, rest) = split_at_delim(s);
- let head = std::str::from_utf8(head).map_err(|_| ParseError)?;
- Ok((T::from_str(head.trim()).map_err(|_| ParseError)?, rest))
+ let head = std::str::from_utf8(head)
+ .map_err(|_| ParseError::new_with_message("header was not valid utf8"))?;
+ Ok((f(head.trim())?, rest))
}
fn split_at_delim(s: &[u8]) -> (&[u8], &[u8]) {
@@ -128,13 +180,15 @@ fn then_delim(s: &[u8]) -> Result<&[u8], ParseError> {
} else if s.starts_with(b",") {
Ok(&s[1..])
} else {
- Err(ParseError)
+ Err(ParseError::new_with_message("expected delimiter `,`"))
}
}
#[cfg(test)]
mod test {
- use crate::header::{headers_for_prefix, read_many, set_header_if_absent, ParseError};
+ use crate::header::{
+ headers_for_prefix, read_many_primitive, set_header_if_absent, ParseError,
+ };
use std::collections::HashMap;
#[test]
@@ -153,6 +207,27 @@ mod test {
);
}
+ #[test]
+ fn parse_floats() {
+ let test_request = http::Request::builder()
+ .header("X-Float-Multi", "0.0,Infinity,-Infinity,5555.5")
+ .header("X-Float-Error", "notafloat")
+ .body(())
+ .unwrap();
+ assert_eq!(
+ read_many_primitive::(test_request.headers().get_all("X-Float-Multi").iter())
+ .expect("valid"),
+ vec![0.0, f32::INFINITY, f32::NEG_INFINITY, 5555.5]
+ );
+ assert_eq!(
+ read_many_primitive::(test_request.headers().get_all("X-Float-Error").iter())
+ .expect_err("invalid"),
+ ParseError::new_with_message(
+ "failed reading a list of primitives: failed to parse input as f32"
+ )
+ )
+ }
+
#[test]
fn read_many_bools() {
let test_request = http::Request::builder()
@@ -164,47 +239,50 @@ mod test {
.body(())
.unwrap();
assert_eq!(
- read_many::(test_request.headers().get_all("X-Bool-Multi").iter())
+ read_many_primitive::(test_request.headers().get_all("X-Bool-Multi").iter())
.expect("valid"),
vec![true, false, true]
);
assert_eq!(
- read_many::(test_request.headers().get_all("X-Bool").iter()).unwrap(),
+ read_many_primitive::(test_request.headers().get_all("X-Bool").iter()).unwrap(),
vec![true]
);
assert_eq!(
- read_many::(test_request.headers().get_all("X-Bool-Single").iter()).unwrap(),
+ read_many_primitive::(test_request.headers().get_all("X-Bool-Single").iter())
+ .unwrap(),
vec![true, false, true, true]
);
- read_many::(test_request.headers().get_all("X-Bool-Invalid").iter())
+ read_many_primitive::(test_request.headers().get_all("X-Bool-Invalid").iter())
.expect_err("invalid");
}
#[test]
- fn read_many_u16() {
+ fn check_read_many_i16() {
let test_request = http::Request::builder()
.header("X-Multi", "123,456")
.header("X-Multi", "789")
.header("X-Num", "777")
.header("X-Num-Invalid", "12ef3")
- .header("X-Num-Single", "1,2,3,4,5")
+ .header("X-Num-Single", "1,2,3,-4,5")
.body(())
.unwrap();
assert_eq!(
- read_many::(test_request.headers().get_all("X-Multi").iter()).expect("valid"),
+ read_many_primitive::(test_request.headers().get_all("X-Multi").iter())
+ .expect("valid"),
vec![123, 456, 789]
);
assert_eq!(
- read_many::(test_request.headers().get_all("X-Num").iter()).unwrap(),
+ read_many_primitive::(test_request.headers().get_all("X-Num").iter()).unwrap(),
vec![777]
);
assert_eq!(
- read_many::(test_request.headers().get_all("X-Num-Single").iter()).unwrap(),
- vec![1, 2, 3, 4, 5]
+ read_many_primitive::(test_request.headers().get_all("X-Num-Single").iter())
+ .unwrap(),
+ vec![1, 2, 3, -4, 5]
);
- read_many::(test_request.headers().get_all("X-Num-Invalid").iter())
+ read_many_primitive::(test_request.headers().get_all("X-Num-Invalid").iter())
.expect_err("invalid");
}
@@ -217,15 +295,14 @@ mod test {
.header("X-Prefix-C", "777")
.body(())
.unwrap();
- let resp: Result>, ParseError> =
+ let resp: Result>, ParseError> =
headers_for_prefix(test_request.headers(), "X-Prefix-")
.map(|(key, header_name)| {
let values = test_request.headers().get_all(header_name);
- read_many(values.iter()).map(|v| (key.to_string(), v))
+ read_many_primitive(values.iter()).map(|v| (key.to_string(), v))
})
.collect();
let resp = resp.expect("valid");
- println!("{:?}", resp);
- assert_eq!(resp.get("a"), Some(&vec![123_u16, 456_u16]));
+ assert_eq!(resp.get("a"), Some(&vec![123_i16, 456_i16]));
}
}
diff --git a/rust-runtime/smithy-http/src/label.rs b/rust-runtime/smithy-http/src/label.rs
index 48b61755c4..4a8a5f0584 100644
--- a/rust-runtime/smithy-http/src/label.rs
+++ b/rust-runtime/smithy-http/src/label.rs
@@ -9,14 +9,9 @@
use crate::urlencode::BASE_SET;
use percent_encoding::AsciiSet;
use smithy_types::Instant;
-use std::fmt::Debug;
const GREEDY: &AsciiSet = &BASE_SET.remove(b'/');
-pub fn fmt_default(t: T) -> String {
- format!("{:?}", t)
-}
-
pub fn fmt_string>(t: T, greedy: bool) -> String {
let uri_set = if greedy { GREEDY } else { BASE_SET };
percent_encoding::utf8_percent_encode(t.as_ref(), &uri_set).to_string()
diff --git a/rust-runtime/smithy-http/src/query.rs b/rust-runtime/smithy-http/src/query.rs
index 821376177e..59dc9cc73e 100644
--- a/rust-runtime/smithy-http/src/query.rs
+++ b/rust-runtime/smithy-http/src/query.rs
@@ -8,11 +8,6 @@ use percent_encoding::utf8_percent_encode;
/// Formatting values into the query string as specified in
/// [httpQuery](https://awslabs.github.io/smithy/1.0/spec/core/http-traits.html#httpquery-trait)
use smithy_types::Instant;
-use std::fmt::Debug;
-
-pub fn fmt_default(t: T) -> String {
- format!("{:?}", t)
-}
pub fn fmt_string>(t: T) -> String {
utf8_percent_encode(t.as_ref(), BASE_SET).to_string()
diff --git a/rust-runtime/smithy-json/Cargo.toml b/rust-runtime/smithy-json/Cargo.toml
index 5f0ca2d3a8..243bc1dc2e 100644
--- a/rust-runtime/smithy-json/Cargo.toml
+++ b/rust-runtime/smithy-json/Cargo.toml
@@ -5,8 +5,6 @@ authors = ["AWS Rust SDK Team ", "John DiSanti JsonTokenIterator<'a> {
offset,
value: if floating {
Number::Float(
- f64::from_str(&number_str).map_err(|_| self.error_at(start, InvalidNumber))?,
+ f64::from_str(&number_str)
+ .map_err(|_| self.error_at(start, InvalidNumber))
+ .and_then(|f| {
+ must_be_finite(f).map_err(|_| self.error_at(start, InvalidNumber))
+ })?,
)
} else if negative {
// If the negative value overflows, then stuff it into an f64
@@ -484,6 +488,22 @@ impl<'a> Iterator for JsonTokenIterator<'a> {
}
}
+fn must_be_finite(f: f64) -> Result {
+ if f.is_finite() {
+ Ok(f)
+ } else {
+ Err(())
+ }
+}
+
+fn must_not_be_finite(f: f64) -> Result {
+ if !f.is_finite() {
+ Ok(f)
+ } else {
+ Err(())
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::deserialize::token::test::{
diff --git a/rust-runtime/smithy-json/src/deserialize/token.rs b/rust-runtime/smithy-json/src/deserialize/token.rs
index 626d4b3456..c50bebb00f 100644
--- a/rust-runtime/smithy-json/src/deserialize/token.rs
+++ b/rust-runtime/smithy-json/src/deserialize/token.rs
@@ -9,7 +9,9 @@ use smithy_types::instant::Format;
use smithy_types::{base64, Blob, Document, Instant, Number};
use std::borrow::Cow;
+use crate::deserialize::must_not_be_finite;
pub use crate::escape::Error as EscapeError;
+use smithy_types::primitive::Parse;
use std::collections::HashMap;
use std::iter::Peekable;
@@ -151,9 +153,45 @@ macro_rules! expect_value_or_null_fn {
}
expect_value_or_null_fn!(expect_bool_or_null, ValueBool, bool, "Expects a [Token::ValueBool] or [Token::ValueNull], and returns the bool value if it's not null.");
-expect_value_or_null_fn!(expect_number_or_null, ValueNumber, Number, "Expects a [Token::ValueNumber] or [Token::ValueNull], and returns the [Number] value if it's not null.");
expect_value_or_null_fn!(expect_string_or_null, ValueString, EscapedStr, "Expects a [Token::ValueString] or [Token::ValueNull], and returns the [EscapedStr] value if it's not null.");
+/// Expects a [Token::ValueString], [Token::ValueNumber] or [Token::ValueNull].
+///
+/// If the value is a string, it MUST be `Infinity`, `-Infinity` or `Nan`.
+/// If the value is a number, it is returned directly
+pub fn expect_number_or_null(
+ token: Option, Error>>,
+) -> Result, Error> {
+ match token.transpose()? {
+ Some(Token::ValueNull { .. }) => Ok(None),
+ Some(Token::ValueNumber { value, .. }) => Ok(Some(value)),
+ Some(Token::ValueString { value, offset }) => match value.to_unescaped() {
+ Err(err) => Err(Error::new(
+ ErrorReason::Custom(format!("expected a valid string, escape was invalid: {}", err).into()), Some(offset.0))
+ ),
+ Ok(v) => f64::parse_smithy_primitive(v.as_ref())
+ // disregard the exact error
+ .map_err(|_|())
+ // only infinite / NaN can be used as strings
+ .and_then(must_not_be_finite)
+ .map(|float| Some(smithy_types::Number::Float(float)))
+ // convert to a helpful error
+ .map_err(|_| {
+ Error::new(
+ ErrorReason::Custom(Cow::Owned(format!(
+ "only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `{}`",
+ v
+ ))),
+ Some(offset.0),
+ )
+ }),
+ },
+ _ => Err(Error::custom(
+ "expected ValueString, ValueNumber, or ValueNull",
+ )),
+ }
+}
+
/// Expects a [Token::ValueString] or [Token::ValueNull]. If the value is a string, it interprets it as a base64 encoded [Blob] value.
pub fn expect_blob_or_null(token: Option, Error>>) -> Result, Error> {
Ok(match expect_string_or_null(token)? {
@@ -386,6 +424,15 @@ pub mod test {
))
}
+ #[test]
+ fn test_non_finite_floats() {
+ let mut tokens = json_token_iter(b"inf");
+ tokens
+ .next()
+ .expect("there is a token")
+ .expect_err("but it is invalid, ensure that Rust float boundary cases don't parse");
+ }
+
#[test]
fn mismatched_braces() {
// The skip_value function doesn't need to explicitly handle these cases since
@@ -466,9 +513,27 @@ pub mod test {
expect_number_or_null(value_number(0, Number::PosInt(5)))
);
assert_eq!(
- Err(Error::custom("expected ValueNumber or ValueNull")),
+ Err(Error::custom(
+ "expected ValueString, ValueNumber, or ValueNull"
+ )),
expect_number_or_null(value_bool(0, true))
);
+ assert_eq!(
+ Ok(Some(Number::Float(f64::INFINITY))),
+ expect_number_or_null(value_string(0, "Infinity"))
+ );
+ assert_eq!(
+ Err(Error::new(ErrorReason::Custom("only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `123`".into()), Some(0))),
+ expect_number_or_null(value_string(0, "123"))
+ );
+ match expect_number_or_null(value_string(0, "NaN")) {
+ Ok(Some(Number::Float(v))) if v.is_nan() => {
+ // ok
+ }
+ not_ok => {
+ panic!("expected nan, found: {:?}", not_ok)
+ }
+ }
}
#[test]
@@ -505,8 +570,14 @@ pub mod test {
Ok(Some(Instant::from_f64(1445412480.0))),
expect_timestamp_or_null(value_string(0, "2015-10-21T07:28:00Z"), Format::DateTime)
);
+ let err = Error::new(
+ ErrorReason::Custom(
+ "only `Infinity`, `-Infinity`, `NaN` can represent a float as a string but found `wrong`".into(),
+ ),
+ Some(0),
+ );
assert_eq!(
- Err(Error::custom("expected ValueNumber or ValueNull")),
+ Err(err),
expect_timestamp_or_null(value_string(0, "wrong"), Format::EpochSeconds)
);
assert_eq!(
diff --git a/rust-runtime/smithy-json/src/serialize.rs b/rust-runtime/smithy-json/src/serialize.rs
index 4eefce52e6..bf356e31f6 100644
--- a/rust-runtime/smithy-json/src/serialize.rs
+++ b/rust-runtime/smithy-json/src/serialize.rs
@@ -5,6 +5,7 @@
use crate::escape::escape_string;
use smithy_types::instant::Format;
+use smithy_types::primitive::Encoder;
use smithy_types::{Document, Instant, Number};
use std::borrow::Cow;
@@ -76,19 +77,18 @@ impl<'a> JsonValueWriter<'a> {
match value {
Number::PosInt(value) => {
// itoa::Buffer is a fixed-size stack allocation, so this is cheap
- self.output.push_str(itoa::Buffer::new().format(value));
+ self.output.push_str(Encoder::from(value).encode());
}
Number::NegInt(value) => {
- self.output.push_str(itoa::Buffer::new().format(value));
+ self.output.push_str(Encoder::from(value).encode());
}
Number::Float(value) => {
- // If the value is NaN, Infinity, or -Infinity
- if value.is_nan() || value.is_infinite() {
- self.output.push_str("null");
+ let mut encoder: Encoder = value.into();
+ // Nan / infinite values actually get written in quotes as a string value
+ if value.is_infinite() || value.is_nan() {
+ self.string_unchecked(encoder.encode())
} else {
- // ryu::Buffer is a fixed-size stack allocation, so this is cheap
- self.output
- .push_str(ryu::Buffer::new().format_finite(value));
+ self.output.push_str(encoder.encode())
}
}
}
@@ -394,18 +394,15 @@ mod tests {
assert_eq!("10000000000.0", format_test_number(Number::Float(1e10)));
assert_eq!("-1.2", format_test_number(Number::Float(-1.2)));
- // JSON doesn't support NaN, Infinity, or -Infinity, so we're matching
+ // Smithy has specific behavior for infinity & NaN
// the behavior of the serde_json crate in these cases.
+ assert_eq!("\"NaN\"", format_test_number(Number::Float(f64::NAN)));
assert_eq!(
- serde_json::to_string(&f64::NAN).unwrap(),
- format_test_number(Number::Float(f64::NAN))
- );
- assert_eq!(
- serde_json::to_string(&f64::INFINITY).unwrap(),
+ "\"Infinity\"",
format_test_number(Number::Float(f64::INFINITY))
);
assert_eq!(
- serde_json::to_string(&f64::NEG_INFINITY).unwrap(),
+ "\"-Infinity\"",
format_test_number(Number::Float(f64::NEG_INFINITY))
);
}
diff --git a/rust-runtime/smithy-query/Cargo.toml b/rust-runtime/smithy-query/Cargo.toml
index 35810b3f1a..256f49428b 100644
--- a/rust-runtime/smithy-query/Cargo.toml
+++ b/rust-runtime/smithy-query/Cargo.toml
@@ -5,7 +5,5 @@ authors = ["AWS Rust SDK Team ", "John DiSanti QueryValueWriter<'a> {
match value {
Number::PosInt(value) => {
// itoa::Buffer is a fixed-size stack allocation, so this is cheap
- self.string(itoa::Buffer::new().format(value));
+ self.string(Encoder::from(value).encode());
}
Number::NegInt(value) => {
- self.string(itoa::Buffer::new().format(value));
- }
- Number::Float(value) => {
- // If the value is NaN, Infinity, or -Infinity
- if value.is_nan() || value.is_infinite() {
- self.string("");
- } else {
- // ryu::Buffer is a fixed-size stack allocation, so this is cheap
- self.string(ryu::Buffer::new().format_finite(value));
- }
+ self.string(Encoder::from(value).encode());
}
+ Number::Float(value) => self.string(Encoder::from(value).encode()),
}
}
@@ -378,9 +371,9 @@ mod tests {
&Version=1.0\
&PosInt=5\
&NegInt=-5\
- &Infinity=\
- &NegInfinity=\
- &NaN=\
+ &Infinity=Infinity\
+ &NegInfinity=-Infinity\
+ &NaN=NaN\
&Floating=5.2\
",
out
diff --git a/rust-runtime/smithy-types/Cargo.toml b/rust-runtime/smithy-types/Cargo.toml
index 636bb28d18..0c08c47137 100644
--- a/rust-runtime/smithy-types/Cargo.toml
+++ b/rust-runtime/smithy-types/Cargo.toml
@@ -11,6 +11,8 @@ default = ["chrono-conversions"]
[dependencies]
chrono = { version = "0.4", default-features = false, features = [] }
+ryu = "1.0.5"
+itoa = "0.4.0"
[dev-dependencies]
base64 = "0.13.0"
diff --git a/rust-runtime/smithy-types/src/lib.rs b/rust-runtime/smithy-types/src/lib.rs
index 733d1fc2fe..1bcfd42c1d 100644
--- a/rust-runtime/smithy-types/src/lib.rs
+++ b/rust-runtime/smithy-types/src/lib.rs
@@ -5,6 +5,7 @@
pub mod base64;
pub mod instant;
+pub mod primitive;
pub mod retry;
use std::collections::HashMap;
diff --git a/rust-runtime/smithy-types/src/primitive.rs b/rust-runtime/smithy-types/src/primitive.rs
new file mode 100644
index 0000000000..65aef4d51e
--- /dev/null
+++ b/rust-runtime/smithy-types/src/primitive.rs
@@ -0,0 +1,276 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0.
+ */
+
+//! Utilities for formatting and parsing primitives
+//!
+//! Smithy protocols have specific behavior for serializing
+//! & deserializing floats, specifically:
+//! - NaN should be serialized as `NaN`
+//! - Positive infinity should be serialized as `Infinity`
+//! - Negative infinity should be serialized as `-Infinity`
+//!
+//! This module defines the [`Parse`](Parse) trait which
+//! enables parsing primitive values (numbers & booleans) that follow
+//! these rules and [`Encoder`](Encoder), a struct that enables
+//! allocation-free serialization.
+//!
+//! # Examples
+//! ## Parsing
+//! ```rust
+//! use smithy_types::primitive::Parse;
+//! let parsed = f64::parse_smithy_primitive("123.4").expect("valid float");
+//! ```
+//!
+//! ## Encoding
+//! ```
+//! use smithy_types::primitive::Encoder;
+//! assert_eq!("123.4", Encoder::from(123.4).encode());
+//! assert_eq!("Infinity", Encoder::from(f64::INFINITY).encode());
+//! assert_eq!("true", Encoder::from(true).encode());
+//! ```
+use crate::primitive::private::Sealed;
+use std::error::Error;
+use std::fmt::{Display, Formatter};
+use std::str::FromStr;
+
+/// An error during primitive parsing
+#[non_exhaustive]
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub struct PrimitiveParseError(&'static str);
+impl Display for PrimitiveParseError {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "failed to parse input as {}", self.0)
+ }
+}
+impl Error for PrimitiveParseError {}
+
+/// Sealed trait for custom parsing of primitive types
+pub trait Parse: Sealed {
+ fn parse_smithy_primitive(input: &str) -> Result
+ where
+ Self: Sized;
+}
+
+mod private {
+ pub trait Sealed {}
+ impl Sealed for i8 {}
+ impl Sealed for i16 {}
+ impl Sealed for i32 {}
+ impl Sealed for i64 {}
+ impl Sealed for f32 {}
+ impl Sealed for f64 {}
+ impl Sealed for u64 {}
+ impl Sealed for bool {}
+}
+
+macro_rules! parse_from_str {
+ ($t: ty) => {
+ impl Parse for $t {
+ fn parse_smithy_primitive(input: &str) -> Result {
+ FromStr::from_str(input).map_err(|_| PrimitiveParseError(stringify!($t)))
+ }
+ }
+ };
+}
+
+parse_from_str!(bool);
+parse_from_str!(i8);
+parse_from_str!(i16);
+parse_from_str!(i32);
+parse_from_str!(i64);
+
+impl Parse for f32 {
+ fn parse_smithy_primitive(input: &str) -> Result {
+ float::parse_f32(input).map_err(|_| PrimitiveParseError("f32"))
+ }
+}
+
+impl Parse for f64 {
+ fn parse_smithy_primitive(input: &str) -> Result {
+ float::parse_f64(input).map_err(|_| PrimitiveParseError("f64"))
+ }
+}
+
+/// Primitive Type Encoder
+///
+/// This type implements `From` for all Smithy primitive types.
+#[non_exhaustive]
+pub enum Encoder {
+ #[non_exhaustive]
+ Bool(bool),
+ #[non_exhaustive]
+ I8(i8, itoa::Buffer),
+ #[non_exhaustive]
+ I16(i16, itoa::Buffer),
+ #[non_exhaustive]
+ I32(i32, itoa::Buffer),
+ #[non_exhaustive]
+ I64(i64, itoa::Buffer),
+ #[non_exhaustive]
+ U64(u64, itoa::Buffer),
+ #[non_exhaustive]
+ F32(f32, ryu::Buffer),
+ #[non_exhaustive]
+ F64(f64, ryu::Buffer),
+}
+
+impl Encoder {
+ pub fn encode(&mut self) -> &str {
+ match self {
+ Encoder::Bool(true) => "true",
+ Encoder::Bool(false) => "false",
+ Encoder::I8(v, buf) => buf.format(*v),
+ Encoder::I16(v, buf) => buf.format(*v),
+ Encoder::I32(v, buf) => buf.format(*v),
+ Encoder::I64(v, buf) => buf.format(*v),
+ Encoder::U64(v, buf) => buf.format(*v),
+ Encoder::F32(v, buf) => {
+ if v.is_nan() {
+ float::NAN
+ } else if *v == f32::INFINITY {
+ float::INFINITY
+ } else if *v == f32::NEG_INFINITY {
+ float::NEG_INFINITY
+ } else {
+ buf.format_finite(*v)
+ }
+ }
+ Encoder::F64(v, buf) => {
+ if v.is_nan() {
+ float::NAN
+ } else if *v == f64::INFINITY {
+ float::INFINITY
+ } else if *v == f64::NEG_INFINITY {
+ float::NEG_INFINITY
+ } else {
+ buf.format_finite(*v)
+ }
+ }
+ }
+ }
+}
+
+impl From for Encoder {
+ fn from(input: bool) -> Self {
+ Self::Bool(input)
+ }
+}
+
+impl From for Encoder {
+ fn from(input: i8) -> Self {
+ Self::I8(input, itoa::Buffer::new())
+ }
+}
+
+impl From for Encoder {
+ fn from(input: i16) -> Self {
+ Self::I16(input, itoa::Buffer::new())
+ }
+}
+
+impl From for Encoder {
+ fn from(input: i32) -> Self {
+ Self::I32(input, itoa::Buffer::new())
+ }
+}
+
+impl From for Encoder {
+ fn from(input: i64) -> Self {
+ Self::I64(input, itoa::Buffer::new())
+ }
+}
+
+impl From for Encoder {
+ fn from(input: u64) -> Self {
+ Self::U64(input, itoa::Buffer::new())
+ }
+}
+
+impl From for Encoder {
+ fn from(input: f32) -> Self {
+ Self::F32(input, ryu::Buffer::new())
+ }
+}
+
+impl From for Encoder {
+ fn from(input: f64) -> Self {
+ Self::F64(input, ryu::Buffer::new())
+ }
+}
+
+mod float {
+ use std::num::ParseFloatError;
+ pub const INFINITY: &str = "Infinity";
+ pub const NEG_INFINITY: &str = "-Infinity";
+ pub const NAN: &str = "NaN";
+
+ pub fn parse_f32(data: &str) -> Result {
+ match data {
+ INFINITY => Ok(f32::INFINITY),
+ NEG_INFINITY => Ok(f32::NEG_INFINITY),
+ NAN => Ok(f32::NAN),
+ other => other.parse::(),
+ }
+ }
+
+ pub fn parse_f64(data: &str) -> Result {
+ match data {
+ INFINITY => Ok(f64::INFINITY),
+ NEG_INFINITY => Ok(f64::NEG_INFINITY),
+ NAN => Ok(f64::NAN),
+ other => other.parse::(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::primitive::{Encoder, Parse};
+
+ #[test]
+ fn bool_format() {
+ assert_eq!(Encoder::from(true).encode(), "true");
+ assert_eq!(Encoder::from(false).encode(), "false");
+ let err = bool::parse_smithy_primitive("not a boolean").expect_err("should fail");
+ assert_eq!(err.0, "bool");
+ assert_eq!(bool::parse_smithy_primitive("true"), Ok(true));
+ assert_eq!(bool::parse_smithy_primitive("false"), Ok(false));
+ }
+
+ #[test]
+ fn float_format() {
+ assert_eq!(Encoder::from(55_f64).encode(), "55.0");
+ assert_eq!(Encoder::from(f64::INFINITY).encode(), "Infinity");
+ assert_eq!(Encoder::from(f32::INFINITY).encode(), "Infinity");
+ assert_eq!(Encoder::from(f32::NEG_INFINITY).encode(), "-Infinity");
+ assert_eq!(Encoder::from(f64::NEG_INFINITY).encode(), "-Infinity");
+ assert_eq!(Encoder::from(f32::NAN).encode(), "NaN");
+ assert_eq!(Encoder::from(f64::NAN).encode(), "NaN");
+ }
+
+ #[test]
+ fn float_parse() {
+ assert_eq!(f64::parse_smithy_primitive("1234.5"), Ok(1234.5));
+ assert!(f64::parse_smithy_primitive("NaN").unwrap().is_nan());
+ assert_eq!(
+ f64::parse_smithy_primitive("Infinity").unwrap(),
+ f64::INFINITY
+ );
+ assert_eq!(
+ f64::parse_smithy_primitive("-Infinity").unwrap(),
+ f64::NEG_INFINITY
+ );
+ assert_eq!(f32::parse_smithy_primitive("1234.5"), Ok(1234.5));
+ assert!(f32::parse_smithy_primitive("NaN").unwrap().is_nan());
+ assert_eq!(
+ f32::parse_smithy_primitive("Infinity").unwrap(),
+ f32::INFINITY
+ );
+ assert_eq!(
+ f32::parse_smithy_primitive("-Infinity").unwrap(),
+ f32::NEG_INFINITY
+ );
+ }
+}