diff --git a/.gitignore b/.gitignore index 4308d82..3ad270c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# Rust target/ **/*.rs.bk Cargo.lock + +# Editors +/.idea diff --git a/Cargo.toml b/Cargo.toml index 3025798..fa43455 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,11 +21,12 @@ hyper = "^0.13" hyper-rustls = { version = "^0.19", optional = true } hyper-openssl = { version = "^0.8", optional = true } bytes = "^0.5" +cache_control = "0.1.0" +jsonwebtoken = "^7" +futures = "0.3.4" serde = "^1.0" serde_derive = "^1.0" serde_json = "^1.0" -jsonwebtoken = "^7" -cache_control = "0.1.0" [dev-dependencies] tokio = { version = "0.2", features = ["full"] } diff --git a/src/client.rs b/src/client.rs index 70c13eb..a7146ec 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,124 +1,84 @@ -use hyper::{client::{Client as HyperClient, HttpConnector}}; -#[cfg(feature = "with-rustls")] -use hyper_rustls::HttpsConnector; +use bytes::buf::ext::BufExt; +use futures::future::{FutureExt, Shared}; +use hyper::client::{Client as HyperClient, HttpConnector}; #[cfg(feature = "with-openssl")] use hyper_openssl::HttpsConnector; -use serde; -use serde_json; -use bytes::buf::ext::BufExt; +#[cfg(feature = "with-rustls")] +use hyper_rustls::HttpsConnector; +use std::collections::btree_map::Range; use std::collections::BTreeMap; +use std::ops::{ + Bound, + Bound::{Included, Unbounded}, +}; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; -use std::ops::{Bound, Bound::{Included, Unbounded}}; -use std::collections::btree_map::Range; use crate::error::Error; use crate::token::IdInfo; +type HttpClient = HyperClient>; + pub struct Client { - client: HyperClient>, + client: HttpClient, + cache: Cache, pub audiences: Vec, pub hosted_domains: Vec, } -#[derive(Debug, Clone, Deserialize)] -struct CertsObject { - keys: Vec, -} - -#[derive(Debug, Clone, Deserialize)] -struct Cert { - kid: String, - e: String, - kty: String, - alg: String, - n: String, - r#use: String, -} - -type Key = String; - -#[derive(Clone)] -pub struct CachedCerts { - keys: BTreeMap, - pub expiry: Option, -} - -impl CachedCerts { - pub fn new() -> Self { - Self { - keys: BTreeMap::new(), - expiry: None, - } - } - - fn certs_url() -> &'static str { - "https://www.googleapis.com/oauth2/v2/certs" - } - - fn get_range<'a>(&'a self, kid: &Option) -> Result, Error> { - match kid { - None => Ok(self.keys.range::, Bound<&String>)>((Unbounded, Unbounded))), - Some(kid) => { - if !self.keys.contains_key(kid) { - return Err(Error::InvalidKey); - } - Ok(self.keys.range::, Bound<&String>)>((Included(kid), Included(kid)))) - } - } - } - - /// Downloads the public Google certificates if it didn't do so already, or based on expiry of - /// their Cache-Control. Returns `true` if the certificates were updated. - pub async fn refresh_if_needed(&mut self) -> Result { - let check = match self.expiry { - None => true, - Some(expiry) => expiry <= Instant::now(), - }; - - if !check { - return Ok(false); - } - - let client = Client::new(); - let certs : CertsObject = client.get_any(Self::certs_url(), &mut self.expiry).await?; - self.keys = BTreeMap::new(); - - for cert in certs.keys { - self.keys.insert(cert.kid.clone(), cert); - } - - Ok(true) - } -} - impl Client { pub fn new() -> Client { #[cfg(feature = "with-rustls")] let ssl = HttpsConnector::new(); #[cfg(feature = "with-openssl")] let ssl = HttpsConnector::new().expect("unable to build HttpsConnector"); - let client = HyperClient::builder().http1_max_buf_size(0x2000).keep_alive(false).build(ssl); - Client { client, audiences: vec![], hosted_domains: vec![] } + let client = HyperClient::builder() + .http1_max_buf_size(0x2000) + .keep_alive(false) + .build(ssl); + Client { + client, + cache: Cache::new(), + audiences: vec![], + hosted_domains: vec![], + } } /// Verifies that the token is signed by Google's OAuth cerificate, /// and check that it has a valid issuer, audience, and hosted domain. /// /// Returns an error if the client has no configured audiences. - pub async fn verify(&self, id_token: &str, cached_certs: &CachedCerts) -> Result { - let unverified_header = jsonwebtoken::decode_header(&id_token)?; + pub async fn verify(&self, id_token: &str) -> Result { + let certs = self.cache.get_cached_or_refresh(&self.client).await?; + self.verify_with(id_token, &certs).await + } - use jsonwebtoken::{Algorithm, Validation, DecodingKey}; + /// Verifies the token using the same method as `Client::verify`, but allows you to manually + /// manage the lifetime of the certificates. + /// + /// This allows you to control when your application performs a network request (for example, + /// to avoid network requests after dropping OS capabilities or outside of initialization). + /// + /// It is recommended to use `Client::verify` directly instead. + pub async fn verify_with( + &self, + id_token: &str, + cached_certs: &Certificates, + ) -> Result { + use jsonwebtoken::{Algorithm, DecodingKey, Validation}; - for (_, cert) in cached_certs.get_range(&unverified_header.kid)? { - // Check each certificate + let unverified_header = jsonwebtoken::decode_header(&id_token)?; + // Check each certificate + for (_, cert) in cached_certs.get_range(&unverified_header.kid)? { let mut validation = Validation::new(Algorithm::RS256); validation.set_audience(&self.audiences); - let token_data = jsonwebtoken::decode::(&id_token, + let token_data = jsonwebtoken::decode::( + &id_token, &DecodingKey::from_rsa_components(&cert.n, &cert.e), - &validation)?; + &validation, + )?; token_data.claims.verify(self)?; @@ -136,14 +96,16 @@ impl Client { /// This is NOT the recommended way to use the library, but can be used in combination with /// [IdInfo.verify](https://docs.rs/google-signin/latest/google_signin/struct.IdInfo.html#impl) /// for applications with more complex error-handling requirements. - pub async fn get_slow_unverified(&self, id_token: &str) -> Result, Error> { - self.get_any(&format!("https://www.googleapis.com/oauth2/v3/tokeninfo?id_token={}", id_token), &mut None).await - } - - async fn get_any(&self, url: &str, cache: &mut Option) -> Result { + pub async fn get_slow_unverified( + &self, + id_token: &str, + ) -> Result, Error> { + let url = format!( + "https://www.googleapis.com/oauth2/v3/tokeninfo?id_token={}", + id_token + ); let url = url.parse().unwrap(); - let response = self.client.get(url).await.unwrap(); - + let response = self.client.get(url).await?; let status = response.status().as_u16(); match status { 200..=299 => {} @@ -151,21 +113,178 @@ impl Client { return Err(Error::InvalidToken); } } + let body = hyper::body::aggregate(response).await?; + let data = serde_json::from_reader(body.reader())?; + Ok(data) + } +} + +#[derive(Clone)] +struct Cache { + state: Arc>, +} + +impl Cache { + fn new() -> Cache { + Cache { + state: Arc::new(Mutex::new(RefreshState::Uninitialized)), + } + } - if let Some(value) = response.headers().get("Cache-Control") { - if let Ok(value) = value.to_str() { - if let Some(cc) = cache_control::CacheControl::from_value(value) { - if let Some(max_age) = cc.max_age { - let seconds = max_age.num_seconds(); - if seconds >= 0 { - *cache = Some(Instant::now() + Duration::from_secs(seconds as u64)); - } + async fn get_cached_or_refresh(&self, client: &HttpClient) -> Result, Error> { + // Acquire a lock in order to clone the Arc to the currently cached certificates, + // or initialize a new future but don't block on it until after releasing the lock. + let fut = { + let mut guard = self.state.lock().unwrap(); + let state: &mut RefreshState = &mut guard; + match state { + RefreshState::Expired(fut) => fut.clone(), + RefreshState::Uninitialized => { + let fut = Cache::refresh_with(self.state.clone(), client.clone()) + .boxed_local() + .shared(); + *state = RefreshState::Expired(fut.clone()); + fut + } + RefreshState::Ready(certs) => { + if certs.is_expired() { + let fut = Cache::refresh_with(self.state.clone(), client.clone()) + .boxed_local() + .shared(); + *state = RefreshState::Expired(fut.clone()); + fut + } else { + let certs = Arc::clone(certs); + (async move { Ok(certs) }).boxed_local().shared() } } } + }; + + fut.await + } + + async fn refresh_with( + state: Arc>, + client: HttpClient, + ) -> Result, Error> { + let certs = Certificates::get_with_http_client(&client).await?; + let certs = Arc::new(certs); + let mut state = state.lock().unwrap(); + *state = RefreshState::Ready(Arc::clone(&certs)); + Ok(certs) + } +} + +type Promise = + std::pin::Pin, Error>>>>; + +enum RefreshState { + Ready(Arc), + Expired(Shared), + Uninitialized, +} + +#[derive(Clone, Debug, Deserialize)] +struct CertsObject { + keys: Vec, +} + +#[derive(Clone, Debug, Deserialize)] +struct Cert { + kid: String, + e: String, + kty: String, + alg: String, + n: String, + r#use: String, +} + +type Key = String; + +#[derive(Clone)] +pub struct Certificates { + keys: BTreeMap, + pub expiry: Option, +} + +impl Certificates { + pub fn new() -> Self { + Self { + keys: BTreeMap::new(), + expiry: None, } + } + + /// Downloads the public Google certificates even if the current certificates have not expired. + pub async fn get(client: &Client) -> Result { + Certificates::get_with_http_client(&client.client).await + } + async fn get_with_http_client(client: &HttpClient) -> Result { + const URL: &str = "https://www.googleapis.com/oauth2/v2/certs"; + + let url = URL.parse().unwrap(); + let response = client.get(url).await?; + let expiry = response + .headers() + .get("Cache-Control") + .and_then(|val| val.to_str().ok()) + .and_then(cache_control::CacheControl::from_value) + .and_then(|cc| cc.max_age) + .and_then(|max_age| { + let seconds = max_age.num_seconds(); + if seconds >= 0 { + Some(Instant::now() + Duration::from_secs(seconds as u64)) + } else { + None + } + }); let body = hyper::body::aggregate(response).await?; - Ok(serde_json::from_reader(body.reader())?) + let certs: CertsObject = serde_json::from_reader(body.reader())?; + let mut keys = BTreeMap::new(); + for cert in certs.keys { + keys.insert(cert.kid.clone(), cert); + } + Ok(Certificates { keys, expiry }) + } + + /// Downloads the public Google certificates if it didn't do so already, or based on expiry of + /// their Cache-Control. Returns `true` if the certificates were updated. + pub async fn refresh(&mut self) -> Result { + if !self.is_expired() { + return Ok(false); + } + + let client = Client::new(); + *self = Certificates::get(&client).await?; + Ok(true) + } + + /// Returns true if all cached certificates are expired (or if there are no cached certificates). + pub fn is_expired(&self) -> bool { + match self.expiry { + Some(expiry) => expiry <= Instant::now(), + None => true, + } + } + + fn get_range<'a>(&'a self, kid: &Option) -> Result, Error> { + match kid { + None => Ok(self + .keys + .range::, Bound<&String>)>((Unbounded, Unbounded))), + Some(kid) => { + if !self.keys.contains_key(kid) { + return Err(Error::InvalidKey); + } + Ok(self + .keys + .range::, Bound<&String>)>(( + Included(kid), + Included(kid), + ))) + } + } } } diff --git a/src/error.rs b/src/error.rs index 9114462..d782de5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,13 +1,13 @@ -use std::{self, fmt, io}; use hyper; use serde_json; +use std::{self, fmt, io, sync::Arc}; /// A network or validation error -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Error { - DecodeJson(serde_json::Error), - JSONWebToken(jsonwebtoken::errors::Error), - ConnectionError(Box), + DecodeJson(Arc), + JSONWebToken(Arc), + ConnectionError(Arc), InvalidKey, InvalidToken, InvalidIssuer, @@ -31,7 +31,7 @@ impl std::error::Error for Error { fn cause(&self) -> Option<&dyn std::error::Error> { match *self { - Error::DecodeJson(ref err) => Some(err), + Error::DecodeJson(ref err) => Some(&**err), Error::ConnectionError(ref err) => Some(&**err), _ => None, } @@ -48,31 +48,33 @@ impl fmt::Display for Error { Error::InvalidToken => f.write_str("Token was not recognized by google"), Error::InvalidIssuer => f.write_str("Token was not issued by google"), Error::InvalidAudience => f.write_str("Token is for a different google application"), - Error::InvalidHostedDomain => f.write_str("User is not a member of the hosted domain(s)"), + Error::InvalidHostedDomain => { + f.write_str("User is not a member of the hosted domain(s)") + } } } } impl From for Error { fn from(err: io::Error) -> Error { - Error::ConnectionError(Box::new(err)) + Error::ConnectionError(Arc::new(err)) } } impl From for Error { fn from(err: hyper::Error) -> Error { - Error::ConnectionError(Box::new(err)) + Error::ConnectionError(Arc::new(err)) } } impl From for Error { fn from(err: serde_json::Error) -> Error { - Error::DecodeJson(err) + Error::DecodeJson(Arc::new(err)) } } impl From for Error { fn from(err: jsonwebtoken::errors::Error) -> Error { - Error::JSONWebToken(err) + Error::JSONWebToken(Arc::new(err)) } } diff --git a/src/lib.rs b/src/lib.rs index a51663c..13e600f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,9 +28,8 @@ //! } //! //! # async fn handler(client: &google_signin::Client, request: GoogleLogin) { -//! let mut certs_cache = google_signin::CachedCerts::new(); //! // Recommended: Let the crate handle everything for you -//! let id_info = client.verify(&request.token, &mut certs_cache).await.expect("Expected token to be valid"); +//! let id_info = client.verify(&request.token).await.expect("Expected token to be valid"); //! println!("Success! Signed-in as {}", id_info.sub); //! //! // Alternative: Inspect the ID before verifying it @@ -41,10 +40,10 @@ //! ``` extern crate hyper; -#[cfg(feature = "with-rustls")] -extern crate hyper_rustls; #[cfg(feature = "with-openssl")] extern crate hyper_openssl; +#[cfg(feature = "with-rustls")] +extern crate hyper_rustls; extern crate serde; #[macro_use] extern crate serde_derive; @@ -54,8 +53,7 @@ mod client; mod error; mod token; -pub use client::Client; -pub use client::CachedCerts; +pub use client::{Certificates, Client}; pub use error::Error; pub use token::IdInfo; diff --git a/src/token.rs b/src/token.rs index b70f910..5779c95 100644 --- a/src/token.rs +++ b/src/token.rs @@ -2,7 +2,7 @@ use crate::client::Client; use crate::error::Error; #[derive(Debug, Deserialize)] -pub struct IdInfo { +pub struct IdInfo { /// These six fields are included in all Google ID Tokens. pub iss: String, pub sub: String, @@ -17,7 +17,7 @@ pub struct IdInfo { /// These seven fields are only included when the user has granted the "profile" and /// "email" OAuth scopes to the application. pub email: Option, - pub email_verified: Option, // eg. "true" (but unusually as a string) + pub email_verified: Option, // eg. "true" (but unusually as a string) pub name: Option, pub picture: Option, pub given_name: Option, @@ -33,7 +33,9 @@ impl IdInfo { // Check the id was authorized by google match self.iss.as_str() { "accounts.google.com" | "https://accounts.google.com" => {} - _ => { return Err(Error::InvalidIssuer); } + _ => { + return Err(Error::InvalidIssuer); + } } // Check the token belongs to the application(s) @@ -45,7 +47,9 @@ impl IdInfo { if client.hosted_domains.len() > 0 { match self.hd { Some(ref domain) if client.hosted_domains.contains(domain) => {} - _ => { return Err(Error::InvalidHostedDomain); } + _ => { + return Err(Error::InvalidHostedDomain); + } } }