Skip to content

Commit

Permalink
Merge pull request #61 from NLnetLabs/clean-up-typedfunc-call
Browse files Browse the repository at this point in the history
Implement `TypedFunc::call` as variadic-ish (as much as possible)
  • Loading branch information
tertsdiepraam authored Sep 11, 2024
2 parents 5dd4e61 + bdcb3e2 commit 897d4a0
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 30 deletions.
2 changes: 1 addition & 1 deletion examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn main() -> Result<(), roto::RotoReport> {

for y in 0..20 {
let mut bla = Bla { _x: 1, y, _z: 1 };
let res = func.call((&mut bla as *mut _,));
let res = func.call(&mut bla as *mut _);

let expected = if y > 10 {
Verdict::Accept(y * 2)
Expand Down
30 changes: 29 additions & 1 deletion src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,41 @@ pub struct TypedFunc<'module, Params, Return> {
impl<'module, Params: RotoParams, Return: Reflect>
TypedFunc<'module, Params, Return>
{
pub fn call(&self, params: Params) -> Return {
pub fn call_tuple(&self, params: Params) -> Return {
unsafe {
Params::invoke::<Return>(self.func, params, self.return_by_ref)
}
}
}

macro_rules! call_impl {
($($ty:ident),*) => {
impl<'module, $($ty,)* Return: Reflect> TypedFunc<'module, ($($ty,)*), Return>
where
($($ty,)*): RotoParams,
{
#[allow(non_snake_case)]
pub fn call(&self, $($ty: $ty,)*) -> Return {
self.call_tuple(($($ty,)*))
}

#[allow(non_snake_case)]
pub fn as_func(self) -> impl Fn($($ty,)*) -> Return + 'module {
move |$($ty,)*| self.call($($ty,)*)
}
}
}
}

call_impl!();
call_impl!(A);
call_impl!(A, B);
call_impl!(A, B, C);
call_impl!(A, B, C, D);
call_impl!(A, B, C, D, E);
call_impl!(A, B, C, D, E, F);
call_impl!(A, B, C, D, E, F, G);

pub struct FunctionInfo {
id: FuncId,
signature: types::Signature,
Expand Down
57 changes: 29 additions & 28 deletions src/codegen/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn accept() {
.get_function::<(), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(());
let res = f.call();
dbg!(std::mem::size_of::<Verdict<(), ()>>());
assert_eq!(res, Verdict::Accept(()));
}
Expand All @@ -70,7 +70,7 @@ fn reject() {
.get_function::<(), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call(());
let res = f.call();
assert_eq!(res, Verdict::Reject(()));
}

Expand All @@ -93,10 +93,10 @@ fn equal_to_10() {
.get_function::<(u32,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call((5,));
let res = f.call(5);
assert_eq!(res, Verdict::Reject(()));

let res = f.call((10,));
let res = f.call(10);
assert_eq!(res, Verdict::Accept(()));
}

Expand All @@ -123,10 +123,10 @@ fn equal_to_10_with_function() {
.get_function::<(i32,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call((5,));
let res = f.call(5);
assert_eq!(res, Verdict::Reject(()));

let res = f.call((10,));
let res = f.call(10);
assert_eq!(res, Verdict::Accept(()));
}

Expand Down Expand Up @@ -159,10 +159,10 @@ fn equal_to_10_with_two_functions() {
.get_function::<(u32,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

assert_eq!(f.call((5,)), Verdict::Reject(()));
assert_eq!(f.call((10,)), Verdict::Accept(()));
assert_eq!(f.call((15,)), Verdict::Reject(()));
assert_eq!(f.call((20,)), Verdict::Accept(()));
assert_eq!(f.call(5), Verdict::Reject(()));
assert_eq!(f.call(10), Verdict::Accept(()));
assert_eq!(f.call(15), Verdict::Reject(()));
assert_eq!(f.call(20), Verdict::Accept(()));
}

#[test]
Expand All @@ -185,7 +185,7 @@ fn negation() {
.expect("No function found (or mismatched types)");

for x in 0..20 {
let res = f.call((x,));
let res = f.call(x);
let exp = if x != 10 {
Verdict::Accept(())
} else {
Expand Down Expand Up @@ -227,7 +227,7 @@ fn a_bunch_of_comparisons() {
Verdict::Reject(())
};

let res = f.call((x,));
let res = f.call(x);
assert_eq!(res, expected);
}
}
Expand Down Expand Up @@ -262,7 +262,7 @@ fn record() {
} else {
Verdict::Reject(())
};
let res = f.call((x,));
let res = f.call(x);
assert_eq!(res, expected);
}
}
Expand Down Expand Up @@ -291,15 +291,16 @@ fn record_with_fields_flipped() {
let mut p = compile(s);
let f = p
.get_function::<(i32,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");
.expect("No function found (or mismatched types)")
.as_func();

for x in 0..100 {
let expected = if x == 20 {
Verdict::Accept(())
} else {
Verdict::Reject(())
};
let res = f.call((x,));
let res = f(x);
assert_eq!(res, expected);
}
}
Expand Down Expand Up @@ -336,7 +337,7 @@ fn nested_record() {
} else {
Verdict::Reject(())
};
let res = f.call((x,));
let res = f.call(x);
assert_eq!(res, expected, "for {x}");
}
}
Expand Down Expand Up @@ -372,7 +373,7 @@ fn misaligned_fields() {
} else {
Verdict::Reject(())
};
let res = f.call((x,));
let res = f.call(x);
assert_eq!(res, expected, "for {x}");
}
}
Expand Down Expand Up @@ -403,8 +404,8 @@ fn enum_match() {
.get_function::<(bool,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

assert_eq!(f.call((true,)), Verdict::Accept(()));
assert_eq!(f.call((false,)), Verdict::Reject(()));
assert_eq!(f.call(true), Verdict::Accept(()));
assert_eq!(f.call(false), Verdict::Reject(()));
}

#[test]
Expand All @@ -426,13 +427,13 @@ fn arithmetic() {
.get_function::<(i32,), Verdict<(), ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call((5,));
let res = f.call(5);
assert_eq!(res, Verdict::Accept(()));

let res = f.call((20,));
let res = f.call(20);
assert_eq!(res, Verdict::Accept(()));

let res = f.call((100,));
let res = f.call(100);
assert_eq!(res, Verdict::Reject(()));
}

Expand All @@ -458,7 +459,7 @@ fn call_runtime_function() {
for (value, expected) in
[(5, Verdict::Reject(())), (11, Verdict::Accept(()))]
{
let res = f.call((value,));
let res = f.call(value);
assert_eq!(res, expected);
}
}
Expand All @@ -485,7 +486,7 @@ fn call_runtime_method() {
for (value, expected) in
[(5, Verdict::Reject(())), (10, Verdict::Accept(()))]
{
let res = f.call((value,));
let res = f.call(value);
assert_eq!(res, expected);
}
}
Expand All @@ -505,7 +506,7 @@ fn int_var() {
.get_function::<(), Verdict<i32, ()>>("main")
.expect("No function found (or mismatched types)");

assert_eq!(f.call(()), Verdict::Accept(32));
assert_eq!(f.call(), Verdict::Accept(32));
}

#[test]
Expand Down Expand Up @@ -573,11 +574,11 @@ fn asn() {
.expect("No function found (or mismatched types)");

assert_eq!(
f.call((Asn::from_u32(1000),)),
f.call(Asn::from_u32(1000)),
Verdict::Accept(Asn::from_u32(1000))
);
assert_eq!(
f.call((Asn::from_u32(2000),)),
f.call(Asn::from_u32(2000)),
Verdict::Reject(Asn::from_u32(2000))
);
}
Expand Down Expand Up @@ -621,7 +622,7 @@ fn multiply() {
.get_function::<(u8,), Verdict<u8, ()>>("main")
.expect("No function found (or mismatched types)");

let res = f.call((20,));
let res = f.call(20);
assert_eq!(res, Verdict::Accept(40));
}

Expand Down

0 comments on commit 897d4a0

Please sign in to comment.