Skip to content

Commit

Permalink
feat: Derive Ord and Hash in the stdlib; add `std::meta::make_imp…
Browse files Browse the repository at this point in the history
…l` helper (#5683)

# Description

## Problem\*

## Summary\*

`Ord` and `Hash` are the last two traits in the stdlib that can be
derived - and now we can.
I've also added `std::meta::make_impl` so that there's not so much
repeated code for each of this function. This also makes writing these
derive functions somewhat easier for users.

## Additional Context

## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Michael J Klein <[email protected]>
  • Loading branch information
jfecher and michaeljklein authored Aug 6, 2024
1 parent 07ea107 commit 38397d3
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 47 deletions.
42 changes: 20 additions & 22 deletions noir_stdlib/src/cmp.nr
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,10 @@ trait Eq {
// docs:end:eq-trait

comptime fn derive_eq(s: StructDefinition) -> Quoted {
let typ = s.as_type();

let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,});

let where_clause = s.generics().map(|name| quote { $name: Eq }).join(quote {,});

// `(self.a == other.a) & (self.b == other.b) & ...`
let equalities = s.fields().map(
|f: (Quoted, Type)| {
let name = f.0;
quote { (self.$name == other.$name) }
}
);
let body = equalities.join(quote { & });

quote {
impl<$impl_generics> Eq for $typ where $where_clause {
fn eq(self, other: Self) -> bool {
$body
}
}
}
let signature = quote { fn eq(_self: Self, _other: Self) -> bool };
let for_each_field = |name| quote { (_self.$name == _other.$name) };
let body = |fields| fields;
crate::meta::make_trait_impl(s, quote { Eq }, signature, for_each_field, quote { & }, body)
}

impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } }
Expand Down Expand Up @@ -127,12 +109,28 @@ impl Ordering {
}
}

#[derive_via(derive_ord)]
// docs:start:ord-trait
trait Ord {
fn cmp(self, other: Self) -> Ordering;
}
// docs:end:ord-trait

comptime fn derive_ord(s: StructDefinition) -> Quoted {
let signature = quote { fn cmp(_self: Self, _other: Self) -> std::cmp::Ordering };
let for_each_field = |name| quote {
if result == std::cmp::Ordering::equal() {
result = _self.$name.cmp(_other.$name);
}
};
let body = |fields| quote {
let mut result = std::cmp::Ordering::equal();
$fields
result
};
crate::meta::make_trait_impl(s, quote { Ord }, signature, for_each_field, quote {}, body)
}

// Note: Field deliberately does not implement Ord

impl Ord for u64 {
Expand Down
27 changes: 5 additions & 22 deletions noir_stdlib/src/default.nr
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,11 @@ trait Default {
// docs:end:default-trait

comptime fn derive_default(s: StructDefinition) -> Quoted {
let typ = s.as_type();

let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,});

let where_clause = s.generics().map(|name| quote { $name: Default }).join(quote {,});

// `foo: Default::default(), bar: Default::default(), ...`
let fields = s.fields().map(
|f: (Quoted, Type)| {
let name = f.0;
quote { $name: Default::default() }
}
);
let fields = fields.join(quote {,});

quote {
impl<$impl_generics> Default for $typ where $where_clause {
fn default() -> Self {
Self { $fields }
}
}
}
let name = quote { Default };
let signature = quote { fn default() -> Self };
let for_each_field = |name| quote { $name: Default::default() };
let body = |fields| quote { Self { $fields } };
crate::meta::make_trait_impl(s, name, signature, for_each_field, quote { , }, body)
}

impl Default for Field { fn default() -> Field { 0 } }
Expand Down
11 changes: 10 additions & 1 deletion noir_stdlib/src/hash/mod.nr
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::uint128::U128;
use crate::sha256::{digest, sha256_var};
use crate::collections::vec::Vec;
use crate::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul, multi_scalar_mul_slice};
use crate::meta::derive_via;

#[foreign(sha256)]
// docs:start:sha256
Expand Down Expand Up @@ -141,10 +142,18 @@ pub fn sha256_compression(_input: [u32; 16], _state: [u32; 8]) -> [u32; 8] {}
// Partially ported and impacted by rust.

// Hash trait shall be implemented per type.
trait Hash{
#[derive_via(derive_hash)]
trait Hash {
fn hash<H>(self, state: &mut H) where H: Hasher;
}

comptime fn derive_hash(s: StructDefinition) -> Quoted {
let name = quote { Hash };
let signature = quote { fn hash<H>(_self: Self, _state: &mut H) where H: std::hash::Hasher };
let for_each_field = |name| quote { _self.$name.hash(_state); };
crate::meta::make_trait_impl(s, name, signature, for_each_field, quote {}, |fields| fields)
}

// Hasher trait shall be implemented by algorithms to provide hash-agnostic means.
// TODO: consider making the types generic here ([u8], [Field], etc.)
trait Hasher{
Expand Down
42 changes: 42 additions & 0 deletions noir_stdlib/src/meta/mod.nr
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,45 @@ pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted
unconstrained pub comptime fn derive_via(t: TraitDefinition, f: DeriveFunction) {
HANDLERS.insert(t, f);
}

/// `make_impl` is a helper function to make a simple impl, usually while deriving a trait.
/// This impl has a couple assumptions:
/// 1. The impl only has one function, with the signature `function_signature`
/// 2. The trait itself does not have any generics.
///
/// While these assumptions are met, `make_impl` will create an impl from a StructDefinition,
/// automatically filling in the required generics from the struct, along with the where clause.
/// The function body is created by mapping each field with `for_each_field` and joining the
/// results with `join_fields_with`. The result of this is passed to the `body` function for
/// any final processing - e.g. wrapping each field in a `StructConstructor { .. }` expression.
///
/// See `derive_eq` and `derive_default` for example usage.
pub comptime fn make_trait_impl<Env1, Env2>(
s: StructDefinition,
trait_name: Quoted,
function_signature: Quoted,
for_each_field: fn[Env1](Quoted) -> Quoted,
join_fields_with: Quoted,
body: fn[Env2](Quoted) -> Quoted
) -> Quoted {
let typ = s.as_type();
let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,});
let where_clause = s.generics().map(|name| quote { $name: $trait_name }).join(quote {,});

// `for_each_field(field1) $join_fields_with for_each_field(field2) $join_fields_with ...`
let fields = s.fields().map(
|f: (Quoted, Type)| {
let name = f.0;
for_each_field(name)
}
);
let body = body(fields.join(join_fields_with));

quote {
impl<$impl_generics> $trait_name for $typ where $where_clause {
$function_signature {
$body
}
}
}
}
30 changes: 28 additions & 2 deletions test_programs/execution_success/derive/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::hash::Hash;

#[derive_via(derive_do_nothing)]
trait DoNothing {
fn do_nothing(self);
Expand All @@ -20,14 +22,15 @@ comptime fn derive_do_nothing(s: StructDefinition) -> Quoted {
}

// Test stdlib derive fns & multiple traits
#[derive(Eq, Default)]
// - We can derive Ord and Hash even though std::cmp::Ordering and std::hash::Hasher aren't imported
#[derive(Eq, Default, Hash, Ord)]
struct MyOtherStruct<A, B> {
field1: A,
field2: B,
field3: MyOtherOtherStruct<B>,
}

#[derive(Eq, Default)]
#[derive(Eq, Default, Hash, Ord)]
struct MyOtherOtherStruct<T> {
x: T,
}
Expand All @@ -41,4 +44,27 @@ fn main() {

let o: MyOtherStruct<u8, [str<2>]> = MyOtherStruct::default();
assert_eq(o, o);

// Field & str<2> above don't implement Ord
let o1 = MyOtherStruct { field1: 12 as u32, field2: 24 as i8, field3: MyOtherOtherStruct { x: 54 as i8 } };
let o2 = MyOtherStruct { field1: 12 as u32, field2: 24 as i8, field3: MyOtherOtherStruct { x: 55 as i8 } };
assert(o1 < o2);

let mut hasher = TestHasher { result: 0 };
o1.hash(&mut hasher);
assert_eq(hasher.finish(), 12 + 24 + 54);
}

struct TestHasher {
result: Field,
}

impl std::hash::Hasher for TestHasher {
fn finish(self) -> Field {
self.result
}

fn write(&mut self, input: Field) {
self.result += input;
}
}

0 comments on commit 38397d3

Please sign in to comment.