Skip to content

Commit

Permalink
Implement support of derive(Arbitrary) for generic newtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
greyblake committed Jun 29, 2024
1 parent 7c423f7 commit abd542b
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 25 deletions.
16 changes: 11 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,36 @@ jobs:
with:
toolchain: stable

- name: cargo test --features nutype_test
- name: cargo test
uses: actions-rs/cargo@v1
with:
command: test

- name: cargo test --features nutype_test,serde
- name: cargo test --features serde
uses: actions-rs/cargo@v1
with:
command: test
args: --features serde

- name: cargo test --features nutype_test,regex
- name: cargo test --features regex
uses: actions-rs/cargo@v1
with:
command: test
args: --features regex

- name: cargo test --features nutype_test,new_unchecked
- name: cargo test --features new_unchecked
uses: actions-rs/cargo@v1
with:
command: test
args: --features new_unchecked

- name: cargo test --features nutype_test,schemars08
- name: cargo test --features arbitrary
uses: actions-rs/cargo@v1
with:
command: test
args: --features arbitrary

- name: cargo test --features schemars08
uses: actions-rs/cargo@v1
with:
command: test
Expand Down
1 change: 1 addition & 0 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ test-all:
cargo test --features regex
cargo test --features new_unchecked
cargo test --features schemars08
cargo test --features arbitrary
cargo test --all-features

test:
Expand Down
18 changes: 12 additions & 6 deletions dummy/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use arbitrary::Arbitrary;
use nutype::nutype;

#[nutype(
validate(predicate = |n| n.is_even()),
derive(Debug, FromStr),
)]
struct Even<T: ::num::Integer>(T);
#[nutype(derive(Debug, Arbitrary))]
struct Wrapper<T>(Vec<T>);

fn main() {}
fn main() {
fn gen(bytes: &[u8]) -> Wrapper<bool> {
let mut u = arbitrary::Unstructured::new(bytes);
Wrapper::<bool>::arbitrary(&mut u).unwrap()
}
assert_eq!(gen(&[]).into_inner(), vec![]);
assert_eq!(gen(&[1]).into_inner(), vec![false]);
assert_eq!(gen(&[1, 3, 5]).into_inner(), vec![true, false]);
}
23 changes: 16 additions & 7 deletions nutype_macros/src/any/gen/traits/arbitrary.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::Generics;

use crate::{
any::models::{AnyGuard, AnyInnerType},
common::gen::{add_bound_to_all_type_params, add_param, strip_trait_bounds_on_generics},
common::models::TypeName,
};

pub fn gen_impl_trait_arbitrary(
type_name: &TypeName,
generics: &Generics,
inner_type: &AnyInnerType,
guard: &AnyGuard,
) -> Result<TokenStream, syn::Error> {
Expand All @@ -22,18 +25,24 @@ pub fn gen_impl_trait_arbitrary(

// Generate implementation of `Arbitrary` trait, assuming that inner type implements Arbitrary
// too.
let generics_without_bounds = strip_trait_bounds_on_generics(generics);
let generics_with_lifetime = add_param(&generics_without_bounds, quote!('nu_arb));
let generics_with_bounds = add_bound_to_all_type_params(
&generics_with_lifetime,
quote!(::arbitrary::Arbitrary<'nu_arb>),
);
Ok(quote!(
impl ::arbitrary::Arbitrary<'_> for #type_name {
fn arbitrary(u: &mut ::arbitrary::Unstructured<'_>) -> ::arbitrary::Result<Self> {
impl #generics_with_bounds ::arbitrary::Arbitrary<'nu_arb> for #type_name #generics_without_bounds {
fn arbitrary(u: &mut ::arbitrary::Unstructured<'nu_arb>) -> ::arbitrary::Result<Self> {
let inner_value: #inner_type = u.arbitrary()?;
Ok(#type_name::new(inner_value))
}
}

#[inline]
fn size_hint(_depth: usize) -> (usize, Option<usize>) {
let n = ::core::mem::size_of::<#inner_type>();
(n, Some(n))
#[inline]
fn size_hint(_depth: usize) -> (usize, Option<usize>) {
let n = ::core::mem::size_of::<#inner_type>();
(n, Some(n))
}
}
))
}
2 changes: 1 addition & 1 deletion nutype_macros/src/any/gen/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ fn gen_implemented_traits(
AnyIrregularTrait::SerdeDeserialize => Ok(
gen_impl_trait_serde_deserialize(type_name, generics, inner_type, maybe_error_type_name.as_ref())
),
AnyIrregularTrait::ArbitraryArbitrary => arbitrary::gen_impl_trait_arbitrary(type_name, inner_type, guard),
AnyIrregularTrait::ArbitraryArbitrary => arbitrary::gen_impl_trait_arbitrary(type_name, generics, inner_type, guard),
})
.collect()
}
20 changes: 18 additions & 2 deletions nutype_macros/src/common/gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ pub fn gen_impl_into_inner(
///
/// Output:
/// <T, U>
fn strip_trait_bounds_on_generics(original: &Generics) -> Generics {
pub fn strip_trait_bounds_on_generics(original: &Generics) -> Generics {
let mut generics = original.clone();
for param in &mut generics.params {
if let syn::GenericParam::Type(syn::TypeParam { bounds, .. }) = param {
Expand All @@ -174,7 +174,7 @@ fn strip_trait_bounds_on_generics(original: &Generics) -> Generics {
///
/// Output:
/// <T: Serialize, U: Serialize>
fn add_bound_to_all_type_params(generics: &Generics, bound: TokenStream) -> Generics {
pub fn add_bound_to_all_type_params(generics: &Generics, bound: TokenStream) -> Generics {
let mut generics = generics.clone();
let parsed_bound: syn::TypeParamBound =
syn::parse2(bound).expect("Failed to parse TypeParamBound");
Expand All @@ -186,6 +186,22 @@ fn add_bound_to_all_type_params(generics: &Generics, bound: TokenStream) -> Gene
generics
}

/// Add a parameter to generics.
///
/// Input:
/// <T, U>
/// 'a
///
/// Output:
/// <'a, T, U>
///
pub fn add_param(generics: &Generics, param: TokenStream) -> Generics {
let mut generics = generics.clone();
let parsed_param: syn::GenericParam = syn::parse2(param).expect("Failed to parse GenericParam");
generics.params.push(parsed_param);
generics
}

pub trait GenerateNewtype {
type Sanitizer;
type Validator;
Expand Down
2 changes: 1 addition & 1 deletion nutype_macros/src/common/gen/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub struct GeneratedTraits {
pub implement_traits: TokenStream,
}

/// Split traits into 2 groups for generatation:
/// Split traits into 2 groups for generation:
/// * Transparent traits can be simply derived, e.g. `derive(Debug)`.
/// * Irregular traits requires implementation to be generated.
pub enum GeneratableTrait<TransparentTrait, IrregularTrait> {
Expand Down
1 change: 1 addition & 0 deletions test_suite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ num = "0.4.3"
[features]
serde = ["nutype/serde", "dep:serde", "dep:serde_json"]
regex = ["nutype/regex", "dep:regex", "dep:lazy_static", "dep:once_cell"]
arbitrary = ["nutype/arbitrary"]
schemars08 = ["schemars"]
new_unchecked = []
ui = []
22 changes: 19 additions & 3 deletions test_suite/tests/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,25 @@ mod with_generics {
}
}

#[test]
fn test_generic_boundaries_arbitrary() {
// TODO
mod generics_and_arbitrary {
use super::*;
use arbitrary::Arbitrary;

#[nutype(derive(Debug, Arbitrary))]
struct Arbaro<T>(Vec<T>);

fn gen(bytes: &[u8]) -> Vec<bool> {
let mut u = arbitrary::Unstructured::new(&bytes);
let arbraro = Arbaro::<bool>::arbitrary(&mut u).unwrap();
arbraro.into_inner()
}

#[test]
fn test_generic_boundaries_arbitrary() {
assert_eq!(gen(&[]), Vec::<bool>::new());
assert_eq!(gen(&[1]), vec![false]);
assert_eq!(gen(&[1, 3, 5]), vec![true, false]);
}
}

#[test]
Expand Down

0 comments on commit abd542b

Please sign in to comment.