Skip to content

Commit

Permalink
datastore: tidy up some byte slice conversions (#935)
Browse files Browse the repository at this point in the history
By adding a couple of `TryFrom<&[u8]>` implementations, we can make use
of `RowExt::get_bytea_and_convert` to get things like `TaskId`,
`ReportId` or other types that are simple wrappers around fixed-size
byte arrays from database rows.
  • Loading branch information
tgeoghegan authored Jan 24, 2023
1 parent aec3618 commit 0613aa0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 46 deletions.
52 changes: 13 additions & 39 deletions aggregator/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,13 +937,7 @@ impl<C: Clock> Transaction<'_, C> {

rows.into_iter()
.map(|row| {
let report_id_bytes: [u8; ReportId::LEN] = row
.get::<_, Vec<u8>>("report_id")
.try_into()
.map_err(|err| {
Error::DbState(format!("couldn't convert report_id value: {err:?}"))
})?;
let report_id = ReportId::from(report_id_bytes);
let report_id = row.get_bytea_and_convert::<ReportId>("report_id")?;
let time = Time::from_naive_date_time(&row.get("client_timestamp"));
Ok((report_id, time))
})
Expand Down Expand Up @@ -999,13 +993,7 @@ impl<C: Clock> Transaction<'_, C> {

rows.into_iter()
.map(|row| {
let report_id_bytes: [u8; ReportId::LEN] = row
.get::<_, Vec<u8>>("report_id")
.try_into()
.map_err(|err| {
Error::DbState(format!("couldn't convert report_id value: {0:?}", err))
})?;
let report_id = ReportId::from(report_id_bytes);
let report_id = row.get_bytea_and_convert::<ReportId>("report_id")?;
let time = Time::from_naive_date_time(&row.get("client_timestamp"));
let agg_param = A::AggregationParam::get_decoded(row.get("aggregation_param"))?;
Ok((report_id, time, agg_param))
Expand Down Expand Up @@ -1342,11 +1330,7 @@ impl<C: Clock> Transaction<'_, C> {
AggregationJobId::get_decoded(row.get("aggregation_job_id"))?;
let query_type = row.try_get::<_, Json<task::QueryType>>("query_type")?.0;
let vdaf = row.try_get::<_, Json<VdafInstance>>("vdaf")?.0;
let lease_token_bytes: [u8; LeaseToken::LEN] = row
.get::<_, Vec<u8>>("lease_token")
.try_into()
.map_err(|err| Error::DbState(format!("lease_token invalid: {:?}", err)))?;
let lease_token = LeaseToken::from(lease_token_bytes);
let lease_token = row.get_bytea_and_convert::<LeaseToken>("lease_token")?;
let lease_attempts = row.get_bigint_and_convert("lease_attempts")?;
Ok(Lease::new(
AcquiredAggregationJob::new(task_id, aggregation_job_id, query_type, vdaf),
Expand Down Expand Up @@ -1585,13 +1569,7 @@ impl<C: Clock> Transaction<'_, C> {
.await?
.into_iter()
.map(|row| {
let report_id_bytes: [u8; ReportId::LEN] = row
.get::<_, Vec<u8>>("report_id")
.try_into()
.map_err(|err| {
Error::DbState(format!("couldn't convert report_id value: {err:?}"))
})?;
let report_id = ReportId::from(report_id_bytes);
let report_id = row.get_bytea_and_convert::<ReportId>("report_id")?;
Self::report_aggregation_from_row(
vdaf,
role,
Expand Down Expand Up @@ -2190,11 +2168,7 @@ ORDER BY id DESC
let collect_job_id = row.get("collect_job_id");
let query_type = row.try_get::<_, Json<task::QueryType>>("query_type")?.0;
let vdaf = row.try_get::<_, Json<VdafInstance>>("vdaf")?.0;
let lease_token_bytes: [u8; LeaseToken::LEN] = row
.get::<_, Vec<u8>>("lease_token")
.try_into()
.map_err(|err| Error::DbState(format!("lease_token invalid: {:?}", err)))?;
let lease_token = LeaseToken::from(lease_token_bytes);
let lease_token = row.get_bytea_and_convert::<LeaseToken>("lease_token")?;
let lease_attempts = row.get_bigint_and_convert("lease_attempts")?;
Ok(Lease::new(
AcquiredCollectJob::new(task_id, collect_job_id, query_type, vdaf),
Expand Down Expand Up @@ -2274,11 +2248,7 @@ ORDER BY id DESC
let collect_job_id = row.get("collect_job_id");
let query_type = row.try_get::<_, Json<task::QueryType>>("query_type")?.0;
let vdaf = row.try_get::<_, Json<VdafInstance>>("vdaf")?.0;
let lease_token_bytes: [u8; LeaseToken::LEN] = row
.get::<_, Vec<u8>>("lease_token")
.try_into()
.map_err(|err| Error::DbState(format!("lease_token invalid: {:?}", err)))?;
let lease_token = LeaseToken::from(lease_token_bytes);
let lease_token = row.get_bytea_and_convert::<LeaseToken>("lease_token")?;
let lease_attempts = row.get_bigint_and_convert("lease_attempts")?;
Ok(Lease::new(
AcquiredCollectJob::new(task_id, collect_job_id, query_type, vdaf),
Expand Down Expand Up @@ -3599,9 +3569,13 @@ pub mod models {
}
}

impl From<[u8; Self::LEN]> for LeaseToken {
fn from(lease_token: [u8; Self::LEN]) -> Self {
Self(lease_token)
impl<'a> TryFrom<&'a [u8]> for LeaseToken {
type Error = &'static str;

fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
"byte slice has incorrect length for LeaseToken"
})?))
}
}

Expand Down
52 changes: 45 additions & 7 deletions messages/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,16 @@ impl From<[u8; Self::LEN]> for BatchId {
}
}

impl<'a> TryFrom<&'a [u8]> for BatchId {
type Error = &'static str;

fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
"byte slice has incorrect length for BatchId"
})?))
}
}

impl AsRef<[u8; Self::LEN]> for BatchId {
fn as_ref(&self) -> &[u8; Self::LEN] {
&self.0
Expand Down Expand Up @@ -243,6 +253,16 @@ impl From<[u8; Self::LEN]> for ReportId {
}
}

impl<'a> TryFrom<&'a [u8]> for ReportId {
type Error = &'static str;

fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
"byte slice has incorrect length for ReportId"
})?))
}
}

impl AsRef<[u8; Self::LEN]> for ReportId {
fn as_ref(&self) -> &[u8; Self::LEN] {
&self.0
Expand Down Expand Up @@ -300,6 +320,16 @@ impl From<[u8; Self::LEN]> for ReportIdChecksum {
}
}

impl<'a> TryFrom<&'a [u8]> for ReportIdChecksum {
type Error = &'static str;

fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
"byte slice has incorrect length for ReportIdChecksum"
})?))
}
}

impl AsRef<[u8]> for ReportIdChecksum {
fn as_ref(&self) -> &[u8] {
&self.0
Expand Down Expand Up @@ -488,11 +518,21 @@ impl Decode for TaskId {
}

impl From<[u8; Self::LEN]> for TaskId {
fn from(task_id: [u8; TaskId::LEN]) -> Self {
fn from(task_id: [u8; Self::LEN]) -> Self {
Self(task_id)
}
}

impl<'a> TryFrom<&'a [u8]> for TaskId {
type Error = &'static str;

fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
"byte slice has incorrect length for TaskId"
})?))
}
}

impl AsRef<[u8; Self::LEN]> for TaskId {
fn as_ref(&self) -> &[u8; Self::LEN] {
&self.0
Expand Down Expand Up @@ -534,10 +574,8 @@ impl<'de> Visitor<'de> for TaskIdVisitor {
let decoded = URL_SAFE_NO_PAD
.decode(value)
.map_err(|_| E::custom("invalid base64url value"))?;
let byte_array: [u8; TaskId::LEN] = decoded
.try_into()
.map_err(|_| E::custom("incorrect TaskId length"))?;
Ok(TaskId::from(byte_array))

TaskId::try_from(decoded.as_slice()).map_err(|e| E::custom(e))
}
}

Expand Down Expand Up @@ -4062,11 +4100,11 @@ mod tests {
);
assert_de_tokens_error::<TaskId>(
&[Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")],
"incorrect TaskId length",
"byte slice has incorrect length for TaskId",
);
assert_de_tokens_error::<TaskId>(
&[Token::Str("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")],
"incorrect TaskId length",
"byte slice has incorrect length for TaskId",
);
}

Expand Down

0 comments on commit 0613aa0

Please sign in to comment.