diff --git a/aggregator/src/datastore.rs b/aggregator/src/datastore.rs index 01b7cadb1..41988322d 100644 --- a/aggregator/src/datastore.rs +++ b/aggregator/src/datastore.rs @@ -937,13 +937,7 @@ impl Transaction<'_, C> { rows.into_iter() .map(|row| { - let report_id_bytes: [u8; ReportId::LEN] = row - .get::<_, Vec>("report_id") - .try_into() - .map_err(|err| { - Error::DbState(format!("couldn't convert report_id value: {err:?}")) - })?; - let report_id = ReportId::from(report_id_bytes); + let report_id = row.get_bytea_and_convert::("report_id")?; let time = Time::from_naive_date_time(&row.get("client_timestamp")); Ok((report_id, time)) }) @@ -999,13 +993,7 @@ impl Transaction<'_, C> { rows.into_iter() .map(|row| { - let report_id_bytes: [u8; ReportId::LEN] = row - .get::<_, Vec>("report_id") - .try_into() - .map_err(|err| { - Error::DbState(format!("couldn't convert report_id value: {0:?}", err)) - })?; - let report_id = ReportId::from(report_id_bytes); + let report_id = row.get_bytea_and_convert::("report_id")?; let time = Time::from_naive_date_time(&row.get("client_timestamp")); let agg_param = A::AggregationParam::get_decoded(row.get("aggregation_param"))?; Ok((report_id, time, agg_param)) @@ -1342,11 +1330,7 @@ impl Transaction<'_, C> { AggregationJobId::get_decoded(row.get("aggregation_job_id"))?; let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; - let lease_token_bytes: [u8; LeaseToken::LEN] = row - .get::<_, Vec>("lease_token") - .try_into() - .map_err(|err| Error::DbState(format!("lease_token invalid: {:?}", err)))?; - let lease_token = LeaseToken::from(lease_token_bytes); + let lease_token = row.get_bytea_and_convert::("lease_token")?; let lease_attempts = row.get_bigint_and_convert("lease_attempts")?; Ok(Lease::new( AcquiredAggregationJob::new(task_id, aggregation_job_id, query_type, vdaf), @@ -1585,13 +1569,7 @@ impl Transaction<'_, C> { .await? .into_iter() .map(|row| { - let report_id_bytes: [u8; ReportId::LEN] = row - .get::<_, Vec>("report_id") - .try_into() - .map_err(|err| { - Error::DbState(format!("couldn't convert report_id value: {err:?}")) - })?; - let report_id = ReportId::from(report_id_bytes); + let report_id = row.get_bytea_and_convert::("report_id")?; Self::report_aggregation_from_row( vdaf, role, @@ -2190,11 +2168,7 @@ ORDER BY id DESC let collect_job_id = row.get("collect_job_id"); let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; - let lease_token_bytes: [u8; LeaseToken::LEN] = row - .get::<_, Vec>("lease_token") - .try_into() - .map_err(|err| Error::DbState(format!("lease_token invalid: {:?}", err)))?; - let lease_token = LeaseToken::from(lease_token_bytes); + let lease_token = row.get_bytea_and_convert::("lease_token")?; let lease_attempts = row.get_bigint_and_convert("lease_attempts")?; Ok(Lease::new( AcquiredCollectJob::new(task_id, collect_job_id, query_type, vdaf), @@ -2274,11 +2248,7 @@ ORDER BY id DESC let collect_job_id = row.get("collect_job_id"); let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; - let lease_token_bytes: [u8; LeaseToken::LEN] = row - .get::<_, Vec>("lease_token") - .try_into() - .map_err(|err| Error::DbState(format!("lease_token invalid: {:?}", err)))?; - let lease_token = LeaseToken::from(lease_token_bytes); + let lease_token = row.get_bytea_and_convert::("lease_token")?; let lease_attempts = row.get_bigint_and_convert("lease_attempts")?; Ok(Lease::new( AcquiredCollectJob::new(task_id, collect_job_id, query_type, vdaf), @@ -3599,9 +3569,13 @@ pub mod models { } } - impl From<[u8; Self::LEN]> for LeaseToken { - fn from(lease_token: [u8; Self::LEN]) -> Self { - Self(lease_token) + impl<'a> TryFrom<&'a [u8]> for LeaseToken { + type Error = &'static str; + + fn try_from(value: &[u8]) -> Result { + Ok(Self(value.try_into().map_err(|_| { + "byte slice has incorrect length for LeaseToken" + })?)) } } diff --git a/messages/src/lib.rs b/messages/src/lib.rs index 68aeca761..de8395eab 100644 --- a/messages/src/lib.rs +++ b/messages/src/lib.rs @@ -186,6 +186,16 @@ impl From<[u8; Self::LEN]> for BatchId { } } +impl<'a> TryFrom<&'a [u8]> for BatchId { + type Error = &'static str; + + fn try_from(value: &[u8]) -> Result { + Ok(Self(value.try_into().map_err(|_| { + "byte slice has incorrect length for BatchId" + })?)) + } +} + impl AsRef<[u8; Self::LEN]> for BatchId { fn as_ref(&self) -> &[u8; Self::LEN] { &self.0 @@ -243,6 +253,16 @@ impl From<[u8; Self::LEN]> for ReportId { } } +impl<'a> TryFrom<&'a [u8]> for ReportId { + type Error = &'static str; + + fn try_from(value: &[u8]) -> Result { + Ok(Self(value.try_into().map_err(|_| { + "byte slice has incorrect length for ReportId" + })?)) + } +} + impl AsRef<[u8; Self::LEN]> for ReportId { fn as_ref(&self) -> &[u8; Self::LEN] { &self.0 @@ -300,6 +320,16 @@ impl From<[u8; Self::LEN]> for ReportIdChecksum { } } +impl<'a> TryFrom<&'a [u8]> for ReportIdChecksum { + type Error = &'static str; + + fn try_from(value: &[u8]) -> Result { + Ok(Self(value.try_into().map_err(|_| { + "byte slice has incorrect length for ReportIdChecksum" + })?)) + } +} + impl AsRef<[u8]> for ReportIdChecksum { fn as_ref(&self) -> &[u8] { &self.0 @@ -488,11 +518,21 @@ impl Decode for TaskId { } impl From<[u8; Self::LEN]> for TaskId { - fn from(task_id: [u8; TaskId::LEN]) -> Self { + fn from(task_id: [u8; Self::LEN]) -> Self { Self(task_id) } } +impl<'a> TryFrom<&'a [u8]> for TaskId { + type Error = &'static str; + + fn try_from(value: &[u8]) -> Result { + Ok(Self(value.try_into().map_err(|_| { + "byte slice has incorrect length for TaskId" + })?)) + } +} + impl AsRef<[u8; Self::LEN]> for TaskId { fn as_ref(&self) -> &[u8; Self::LEN] { &self.0 @@ -534,10 +574,8 @@ impl<'de> Visitor<'de> for TaskIdVisitor { let decoded = URL_SAFE_NO_PAD .decode(value) .map_err(|_| E::custom("invalid base64url value"))?; - let byte_array: [u8; TaskId::LEN] = decoded - .try_into() - .map_err(|_| E::custom("incorrect TaskId length"))?; - Ok(TaskId::from(byte_array)) + + TaskId::try_from(decoded.as_slice()).map_err(|e| E::custom(e)) } } @@ -4062,11 +4100,11 @@ mod tests { ); assert_de_tokens_error::( &[Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")], - "incorrect TaskId length", + "byte slice has incorrect length for TaskId", ); assert_de_tokens_error::( &[Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")], - "incorrect TaskId length", + "byte slice has incorrect length for TaskId", ); }