From 459811d3aceecb2bc677311c80677a9f014d395f Mon Sep 17 00:00:00 2001 From: "David A. Ramos" Date: Sat, 17 Aug 2024 15:16:53 -0700 Subject: [PATCH] Accept null device code interval Section 3.2 of RFC 8628 states that the Device Authorization Response's `interval` field is optional. Previously, this crate accepted responses without an `interval` field, but rejected responses in which the `interval` field was `null`. This change treats `null` values as if they were omitted. Fixes #278. --- src/devicecode.rs | 92 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 3 deletions(-) diff --git a/src/devicecode.rs b/src/devicecode.rs index 53e9140..543dcc2 100644 --- a/src/devicecode.rs +++ b/src/devicecode.rs @@ -475,6 +475,37 @@ fn default_devicecode_interval() -> u64 { 5 } +fn deserialize_devicecode_interval<'de, D>(deserializer: D) -> Result +where + D: serde::de::Deserializer<'de>, +{ + struct NumOrNull; + + impl<'de> serde::de::Visitor<'de> for NumOrNull { + type Value = u64; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("non-negative integer or null") + } + + fn visit_u64(self, v: u64) -> Result + where + E: Error, + { + Ok(v) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(default_devicecode_interval()) + } + } + + deserializer.deserialize_any(NumOrNull) +} + /// Trait for adding extra fields to the `DeviceAuthorizationResponse`. pub trait ExtraDeviceAuthorizationFields: DeserializeOwned + Debug + Serialize {} @@ -516,7 +547,10 @@ where /// The minimum amount of time in seconds that the client SHOULD wait /// between polling requests to the token endpoint. If no value is /// provided, clients MUST use 5 as the default. - #[serde(default = "default_devicecode_interval")] + #[serde( + default = "default_devicecode_interval", + deserialize_with = "deserialize_devicecode_interval" + )] interval: u64, #[serde(bound = "EF: ExtraDeviceAuthorizationFields", flatten)] @@ -674,10 +708,12 @@ where #[cfg(test)] mod tests { use crate::basic::BasicTokenType; + use crate::devicecode::default_devicecode_interval; use crate::tests::{mock_http_client, mock_http_client_success_fail, new_client}; use crate::{ - DeviceAuthorizationUrl, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, - RequestTokenError, Scope, StandardDeviceAuthorizationResponse, TokenResponse, + DeviceAuthorizationResponse, DeviceAuthorizationUrl, DeviceCodeErrorResponse, + DeviceCodeErrorResponseType, EmptyExtraDeviceAuthorizationFields, RequestTokenError, Scope, + StandardDeviceAuthorizationResponse, TokenResponse, }; use chrono::{DateTime, Utc}; @@ -1120,4 +1156,54 @@ mod tests { _ => unreachable!("Error should be ExpiredToken"), } } + + #[test] + fn test_device_auth_response_default_interval() { + let response: DeviceAuthorizationResponse = + serde_json::from_str( + r#"{ + "device_code": "12345", + "verification_uri": "https://verify/here", + "user_code": "abcde", + "expires_in": 300 + }"#, + ) + .unwrap(); + + assert_eq!(response.interval, default_devicecode_interval()); + } + + #[test] + fn test_device_auth_response_non_default_interval() { + let response: DeviceAuthorizationResponse = + serde_json::from_str( + r#"{ + "device_code": "12345", + "verification_uri": "https://verify/here", + "user_code": "abcde", + "expires_in": 300, + "interval": 10 + }"#, + ) + .unwrap(); + + assert_eq!(response.interval, 10); + } + + #[test] + fn test_device_auth_response_null_interval() { + let response: DeviceAuthorizationResponse = + serde_json::from_str( + r#"{ + "device_code": "12345", + "verification_uri": "https://verify/here", + "user_code": "abcde", + "expires_in": 300, + "interval": null + }"#, + ) + .unwrap(); + + assert_eq!(response.interval, default_devicecode_interval()); + } }