diff --git a/Cargo.toml b/Cargo.toml index c553cb5..b83cf12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,11 +12,13 @@ ark-bls12-381 = { version = "0.4.0", default-features = false } ark-ec = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4.0", default-features = false } ciborium = { version = "0.2.2", default-features = false } -indexmap = { version = "2.1", default-features = false } +indexmap = { version = "2.1", default-features = false, features = ["serde"] } proof-of-sql = { version = "0.28.6", default-features = false } proof-of-sql-parser = { version = "0.28.6", default-features = false } rand = { version = "0.8.0", optional = true } serde = { version = "1.0", default-features = false } +serde_with = { version = "3.11.0", default-features = false, features = ["macros", "alloc", "indexmap_2"] } +snafu = { version = "0.8.0", default-features = false } [dev-dependencies] ark-std = { version = "0.4.0" } @@ -29,6 +31,8 @@ std = [ "serde/std", "ciborium/std", "proof-of-sql/std", + "serde_with/std", + "snafu/std", ] test = [ "std", diff --git a/src/errors.rs b/src/errors.rs index 688448a..23b3e9e 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -14,7 +14,9 @@ // limitations under the License. /// This module defines errors used across the verification library. -#[derive(Debug, PartialEq)] +use snafu::Snafu; + +#[derive(Debug, Snafu)] pub enum VerifyError { /// Provided data has invalid public inputs. InvalidInput, diff --git a/src/serde.rs b/src/serde.rs index 4e7dc9c..0f8e42e 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::errors::VerifyError; use alloc::{string::String, vec::Vec}; use proof_of_sql::{ base::{ @@ -26,9 +27,10 @@ use proof_of_sql_parser::{ posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, Identifier, }; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_with::{serde_as, DeserializeAs, MapPreventDuplicates, SerializeAs}; -pub type IndexMap = indexmap::IndexMap< +type IndexMap = indexmap::IndexMap< Identifier, OwnedColumn, core::hash::BuildHasherDefault, @@ -42,102 +44,109 @@ pub(crate) struct QueryDataDef { verification_hash: [u8; 32], } +#[serde_as] #[derive(Serialize, Deserialize)] -#[serde(remote = "OwnedTable")] -pub struct OwnedTableDef { - #[serde(getter = "OwnedTable::inner_table", with = "index_map_serde")] +struct RaggedTable { + #[serde_as(as = "MapPreventDuplicates<_, OwnedColumnDef>")] table: IndexMap, } -impl From for OwnedTable { - fn from(value: OwnedTableDef) -> Self { - Self::try_new(value.table).unwrap() - } +#[serde_as] +#[derive(Serialize, Deserialize)] +#[serde(remote = "OwnedTable", try_from = "RaggedTable")] +struct OwnedTableDef { + #[serde_as(as = "MapPreventDuplicates<_, OwnedColumnDef>")] + #[serde(getter = "OwnedTable::inner_table")] + table: IndexMap, } -mod index_map_serde { - use super::*; - use core::{fmt, marker::PhantomData}; - use serde::{ - de::{MapAccess, Visitor}, - ser::SerializeMap, - Deserializer, Serializer, - }; +impl TryFrom for OwnedTable { + type Error = VerifyError; - #[derive(Serialize, Deserialize)] - struct OwnedColumnWrap(#[serde(with = "OwnedColumnDef")] OwnedColumn); - - #[derive(Serialize, Deserialize)] - #[serde(remote = "OwnedColumn")] - #[non_exhaustive] - pub enum OwnedColumnDef { - Boolean(Vec), - SmallInt(Vec), - Int(Vec), - BigInt(Vec), - VarChar(Vec), - Int128(Vec), - Decimal75(Precision, i8, Vec), - Scalar(Vec), - TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, Vec), + fn try_from(value: RaggedTable) -> Result { + Self::try_new(value.table).map_err(|_| VerifyError::InvalidInput) } +} + +#[derive(Serialize, Deserialize)] +#[serde(remote = "OwnedColumn")] +#[non_exhaustive] +enum OwnedColumnDef { + Boolean(Vec), + SmallInt(Vec), + Int(Vec), + BigInt(Vec), + VarChar(Vec), + Int128(Vec), + Decimal75(Precision, i8, Vec), + Scalar(Vec), + TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, Vec), +} - pub fn serialize(index_map: &IndexMap, serializer: S) -> Result +impl SerializeAs> for OwnedColumnDef { + fn serialize_as(source: &OwnedColumn, serializer: S) -> Result where S: Serializer, { - let mut map = serializer.serialize_map(Some(index_map.len()))?; - for (k, v) in index_map { - map.serialize_entry(k, &OwnedColumnWrap(v.clone()))?; - } - map.end() + OwnedColumnDef::serialize(source, serializer) } +} - pub fn deserialize<'de, D>(deserializer: D) -> Result +impl<'de> DeserializeAs<'de, OwnedColumn> for OwnedColumnDef { + fn deserialize_as(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserializer.deserialize_map(IndexMapVisitor::new()) + OwnedColumnDef::deserialize(deserializer) } +} - struct IndexMapVisitor { - marker: PhantomData IndexMap>, - } +#[cfg(test)] +mod owned_table { + use super::*; - impl IndexMapVisitor { - fn new() -> Self { - IndexMapVisitor { - marker: PhantomData, - } - } - } + use core::str::FromStr; - impl<'de> Visitor<'de> for IndexMapVisitor { - // The type that our Visitor is going to produce. - type Value = IndexMap; - - // Format a message stating what data this Visitor expects to receive. - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a very special map") - } - - // Deserialize MyMap from an abstract "map" provided by the - // Deserializer. The MapAccess input is a callback provided by - // the Deserializer to let us see each entry in the map. - fn visit_map(self, mut access: M) -> Result - where - M: MapAccess<'de>, - { - let mut map = - IndexMap::with_capacity_and_hasher(access.size_hint().unwrap_or(0), <_>::default()); - - // While there are entries remaining in the input, add them - // into our map. - while let Some((key, OwnedColumnWrap(value))) = access.next_entry()? { - map.insert(key, value); - } - - Ok(map) - } + use indexmap::IndexMap; + use proof_of_sql::base::scalar::Scalar; + + #[derive(Serialize, Deserialize)] + #[serde(transparent)] + struct Wrapper(#[serde(with = "OwnedTableDef")] OwnedTable); + + #[test] + fn serialization_should_preserve_order() { + let mut table = IndexMap::default(); + table.insert( + Identifier::from_str("b").unwrap(), + OwnedColumn::try_from_scalars( + &vec![DoryScalar::ONE, DoryScalar::ZERO], + proof_of_sql::base::database::ColumnType::Boolean, + ) + .unwrap(), + ); + table.insert( + Identifier::from_str("a").unwrap(), + OwnedColumn::try_from_scalars( + &vec![DoryScalar::ZERO, DoryScalar::ONE], + proof_of_sql::base::database::ColumnType::Boolean, + ) + .unwrap(), + ); + let owned_table = OwnedTable::try_new(table).unwrap(); + + let mut buffer = Vec::new(); + ciborium::into_writer(&Wrapper(owned_table.clone()), &mut buffer).unwrap(); + let Wrapper(deserialized_owned_table) = ciborium::from_reader(&buffer[..]).unwrap(); + + assert_eq!( + owned_table.inner_table().len(), + deserialized_owned_table.inner_table().len() + ); + assert!(owned_table + .inner_table() + .iter() + .zip(deserialized_owned_table.inner_table().iter()) + .all(|((k_0, _), (k_1, _))| k_0 == k_1)) } }