diff --git a/mbedtls/src/rng/mod.rs b/mbedtls/src/rng/mod.rs index 31270bf16..edb620e93 100644 --- a/mbedtls/src/rng/mod.rs +++ b/mbedtls/src/rng/mod.rs @@ -45,7 +45,7 @@ pub unsafe extern "C" fn mbedtls_psa_external_get_random( _user_data: *mut mbedtls_sys::types::raw_types::c_void, data: *mut c_uchar, len: size_t, - olen: * mut size_t) -> mbedtls_sys::types::int32_t { - olen = &mut len as * mut size_t; - rng_call(data, len) + _olen: * mut size_t) -> mbedtls_sys::types::int32_t { + *_olen = len; + self::rdrand::rng_call(data, len) } \ No newline at end of file diff --git a/mbedtls/tests/hyper.rs b/mbedtls/tests/hyper.rs index dec61a10f..4040ddf30 100644 --- a/mbedtls/tests/hyper.rs +++ b/mbedtls/tests/hyper.rs @@ -172,6 +172,7 @@ mod tests { use mbedtls::ssl::CipherSuite::*; use std::io::Write; use mbedtls::ssl::TicketContext; + use mbedtls::Error as MbedtlsError; use ntest::test_case; @@ -181,6 +182,8 @@ mod tests { #[cfg(target_env = "sgx")] use mbedtls::rng::{Rdrand}; + static OUTBOUND_REQUEST_RETRY_TIMES: usize = 3; + #[cfg(not(target_env = "sgx"))] pub fn rng_new() -> Arc { let entropy = Arc::new(OsEntropy::new()); @@ -193,6 +196,27 @@ mod tests { Arc::new(Rdrand) } + fn request_with_retry(client: Arc, url: &str, body: Option<&str>, expected_status: StatusCode, retry_times: usize) + { + let mut try_times = 0; + while try_times < retry_times { + let mut rq = client.head(url); + if let Some(body_str) = body { + rq = rq.body(body_str); + } + match rq.send() { + Ok(response) => { + assert_eq!(response.status, expected_status); + return; + }, + Err(hyper::Error::Ssl(err)) if err.downcast_ref::() == Some(&MbedtlsError::SslHwAccelFailed) => try_times += 1, + Err(err) => assert!(false, "{:?}", err), + } + std::thread::sleep(core::time::Duration::from_millis(100)); + } + assert!(false, "Try request to {} with expecting {:?} exceed max try times {}", url, expected_status, retry_times); + } + #[test_case(0, test_name="test_simple_request_tls1_2")] #[test_case(1, test_name="test_simple_request_tls1_3")] fn test_simple_request(ver_int : u32) { @@ -216,9 +240,7 @@ mod tests { let connector = HttpsConnector::new(ssl); let client = hyper::Client::with_connector(Pool::with_connector(Default::default(), connector)); - let response = client.head("https://www.google.com/").send().unwrap(); - - assert_eq!(response.status, hyper::status::StatusCode::Ok); + request_with_retry(client.into(), "https://www.google.com/", None, hyper::status::StatusCode::Ok, OUTBOUND_REQUEST_RETRY_TIMES); } @@ -246,16 +268,13 @@ mod tests { let ssl = MbedSSLClient::new(Arc::new(config), true); let client1 = hyper::Client::with_connector(Pool::with_connector(Default::default(), HttpsConnector::new(ssl.clone()))); - let response = client1.head("https://www.google.com/").send().unwrap(); - assert_eq!(response.status, hyper::status::StatusCode::Ok); + request_with_retry(client1.into(), "https://cloud.google.com/", None, hyper::status::StatusCode::Ok, OUTBOUND_REQUEST_RETRY_TIMES); let client2 = hyper::Client::with_connector(Pool::with_connector(Default::default(), HttpsConnector::new(ssl.clone()))); - let response = client2.head("https://www.google.com/").send().unwrap(); - assert_eq!(response.status, hyper::status::StatusCode::Ok); + request_with_retry(client2.into(), "https://www.youtube.com/", None, hyper::status::StatusCode::Ok, OUTBOUND_REQUEST_RETRY_TIMES); let client3 = hyper::Client::with_connector(Pool::with_connector(Default::default(), HttpsConnector::new(ssl.clone()))); - let response = client3.head("https://www.google.com/").send().unwrap(); - assert_eq!(response.status, hyper::status::StatusCode::Ok); + request_with_retry(client3.into(), "https://www.android.com/", None, hyper::status::StatusCode::Ok, OUTBOUND_REQUEST_RETRY_TIMES); } #[test_case(0, test_name="test_hyper_multithread_tls1_2")] @@ -284,12 +303,11 @@ mod tests { let clone1 = client.clone(); let clone2 = client.clone(); let t1 = std::thread::spawn(move || { - let response = clone1.head("https://google.com").send().unwrap(); - assert_eq!(response.status, hyper::status::StatusCode::Ok); + request_with_retry(clone1, "https://www.android.com/", None, hyper::status::StatusCode::Ok, OUTBOUND_REQUEST_RETRY_TIMES); }); let t2 = std::thread::spawn(move || { - let response = clone2.post("https://google.com").body("foo=bar").send().unwrap(); + let response = clone2.post("https://www.google.com").body("foo=bar").send().unwrap(); assert_eq!(response.status, hyper::status::StatusCode::MethodNotAllowed); });