Skip to content

Commit

Permalink
Avoid custom implementation of (de)serialization for IndexMap
Browse files Browse the repository at this point in the history
  • Loading branch information
lgiussan committed Oct 9, 2024
1 parent 35debf9 commit ee3e708
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 80 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ ark-bls12-381 = { version = "0.4.0", default-features = false }
ark-ec = { version = "0.4.0", default-features = false }
ark-serialize = { version = "0.4.0", default-features = false }
ciborium = { version = "0.2.2", default-features = false }
indexmap = { version = "2.1", default-features = false }
indexmap = { version = "2.1", default-features = false, features = ["serde"] }
proof-of-sql = { version = "0.28.6", default-features = false }
proof-of-sql-parser = { version = "0.28.6", default-features = false }
rand = { version = "0.8.0", optional = true }
serde = { version = "1.0", default-features = false }
serde_with = { version = "3.11.0", default-features = false, features = ["macros", "alloc", "indexmap_2"] }
snafu = { version = "0.8.0", default-features = false }

[dev-dependencies]
ark-std = { version = "0.4.0" }
Expand All @@ -29,6 +31,8 @@ std = [
"serde/std",
"ciborium/std",
"proof-of-sql/std",
"serde_with/std",
"snafu/std",
]
test = [
"std",
Expand Down
4 changes: 3 additions & 1 deletion src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
// limitations under the License.

/// This module defines errors used across the verification library.
#[derive(Debug, PartialEq)]
use snafu::Snafu;

#[derive(Debug, Snafu)]
pub enum VerifyError {
/// Provided data has invalid public inputs.
InvalidInput,
Expand Down
165 changes: 87 additions & 78 deletions src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::errors::VerifyError;
use alloc::{string::String, vec::Vec};
use proof_of_sql::{
base::{
Expand All @@ -26,9 +27,10 @@ use proof_of_sql_parser::{
posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
Identifier,
};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_with::{serde_as, DeserializeAs, MapPreventDuplicates, SerializeAs};

pub type IndexMap = indexmap::IndexMap<
type IndexMap = indexmap::IndexMap<
Identifier,
OwnedColumn<DoryScalar>,
core::hash::BuildHasherDefault<ahash::AHasher>,
Expand All @@ -42,102 +44,109 @@ pub(crate) struct QueryDataDef {
verification_hash: [u8; 32],
}

#[serde_as]
#[derive(Serialize, Deserialize)]
#[serde(remote = "OwnedTable<DoryScalar>")]
pub struct OwnedTableDef {
#[serde(getter = "OwnedTable::inner_table", with = "index_map_serde")]
struct RaggedTable {
#[serde_as(as = "MapPreventDuplicates<_, OwnedColumnDef>")]
table: IndexMap,
}

impl From<OwnedTableDef> for OwnedTable<DoryScalar> {
fn from(value: OwnedTableDef) -> Self {
Self::try_new(value.table).unwrap()
}
#[serde_as]
#[derive(Serialize, Deserialize)]
#[serde(remote = "OwnedTable<DoryScalar>", try_from = "RaggedTable")]
struct OwnedTableDef {
#[serde_as(as = "MapPreventDuplicates<_, OwnedColumnDef>")]
#[serde(getter = "OwnedTable::inner_table")]
table: IndexMap,
}

mod index_map_serde {
use super::*;
use core::{fmt, marker::PhantomData};
use serde::{
de::{MapAccess, Visitor},
ser::SerializeMap,
Deserializer, Serializer,
};
impl TryFrom<RaggedTable> for OwnedTable<DoryScalar> {
type Error = VerifyError;

#[derive(Serialize, Deserialize)]
struct OwnedColumnWrap(#[serde(with = "OwnedColumnDef")] OwnedColumn<DoryScalar>);

#[derive(Serialize, Deserialize)]
#[serde(remote = "OwnedColumn<DoryScalar>")]
#[non_exhaustive]
pub enum OwnedColumnDef {
Boolean(Vec<bool>),
SmallInt(Vec<i16>),
Int(Vec<i32>),
BigInt(Vec<i64>),
VarChar(Vec<String>),
Int128(Vec<i128>),
Decimal75(Precision, i8, Vec<DoryScalar>),
Scalar(Vec<DoryScalar>),
TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, Vec<i64>),
fn try_from(value: RaggedTable) -> Result<Self, Self::Error> {
Self::try_new(value.table).map_err(|_| VerifyError::InvalidInput)
}
}

#[derive(Serialize, Deserialize)]
#[serde(remote = "OwnedColumn<DoryScalar>")]
#[non_exhaustive]
enum OwnedColumnDef {
Boolean(Vec<bool>),
SmallInt(Vec<i16>),
Int(Vec<i32>),
BigInt(Vec<i64>),
VarChar(Vec<String>),
Int128(Vec<i128>),
Decimal75(Precision, i8, Vec<DoryScalar>),
Scalar(Vec<DoryScalar>),
TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, Vec<i64>),
}

pub fn serialize<S>(index_map: &IndexMap, serializer: S) -> Result<S::Ok, S::Error>
impl SerializeAs<OwnedColumn<DoryScalar>> for OwnedColumnDef {
fn serialize_as<S>(source: &OwnedColumn<DoryScalar>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(index_map.len()))?;
for (k, v) in index_map {
map.serialize_entry(k, &OwnedColumnWrap(v.clone()))?;
}
map.end()
OwnedColumnDef::serialize(source, serializer)
}
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<IndexMap, D::Error>
impl<'de> DeserializeAs<'de, OwnedColumn<DoryScalar>> for OwnedColumnDef {
fn deserialize_as<D>(deserializer: D) -> Result<OwnedColumn<DoryScalar>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(IndexMapVisitor::new())
OwnedColumnDef::deserialize(deserializer)
}
}

struct IndexMapVisitor {
marker: PhantomData<fn() -> IndexMap>,
}
#[cfg(test)]
mod owned_table {
use super::*;

impl IndexMapVisitor {
fn new() -> Self {
IndexMapVisitor {
marker: PhantomData,
}
}
}
use core::str::FromStr;

impl<'de> Visitor<'de> for IndexMapVisitor {
// The type that our Visitor is going to produce.
type Value = IndexMap;

// Format a message stating what data this Visitor expects to receive.
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a very special map")
}

// Deserialize MyMap from an abstract "map" provided by the
// Deserializer. The MapAccess input is a callback provided by
// the Deserializer to let us see each entry in the map.
fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
let mut map =
IndexMap::with_capacity_and_hasher(access.size_hint().unwrap_or(0), <_>::default());

// While there are entries remaining in the input, add them
// into our map.
while let Some((key, OwnedColumnWrap(value))) = access.next_entry()? {
map.insert(key, value);
}

Ok(map)
}
use indexmap::IndexMap;
use proof_of_sql::base::scalar::Scalar;

Check warning on line 111 in src/serde.rs

View workflow job for this annotation

GitHub Actions / build-test-check

unused import: `proof_of_sql::base::scalar::Scalar`

#[derive(Serialize, Deserialize)]
#[serde(transparent)]
struct Wrapper(#[serde(with = "OwnedTableDef")] OwnedTable<DoryScalar>);

#[test]
fn serialization_should_preserve_order() {
let mut table = IndexMap::default();
table.insert(
Identifier::from_str("b").unwrap(),
OwnedColumn::try_from_scalars(
&vec![DoryScalar::ONE, DoryScalar::ZERO],

Check failure on line 123 in src/serde.rs

View workflow job for this annotation

GitHub Actions / build-test-check

cannot find macro `vec` in this scope
proof_of_sql::base::database::ColumnType::Boolean,
)
.unwrap(),
);
table.insert(
Identifier::from_str("a").unwrap(),
OwnedColumn::try_from_scalars(
&vec![DoryScalar::ZERO, DoryScalar::ONE],

Check failure on line 131 in src/serde.rs

View workflow job for this annotation

GitHub Actions / build-test-check

cannot find macro `vec` in this scope
proof_of_sql::base::database::ColumnType::Boolean,
)
.unwrap(),
);
let owned_table = OwnedTable::try_new(table).unwrap();

let mut buffer = Vec::new();
ciborium::into_writer(&Wrapper(owned_table.clone()), &mut buffer).unwrap();
let Wrapper(deserialized_owned_table) = ciborium::from_reader(&buffer[..]).unwrap();

assert_eq!(
owned_table.inner_table().len(),
deserialized_owned_table.inner_table().len()
);
assert!(owned_table
.inner_table()
.iter()
.zip(deserialized_owned_table.inner_table().iter())
.all(|((k_0, _), (k_1, _))| k_0 == k_1))
}
}

0 comments on commit ee3e708

Please sign in to comment.