Skip to content

Commit

Permalink
implement IpAddrs (#64)
Browse files Browse the repository at this point in the history
* implement IpAddrs

* remove some unnecessary leftovers from ip addr implementation

* more docs on parameters and reflect

* cargo fmt

* fixup

* implement IpAddr eq via Rust's implementation

* fix returning ZST from runtime function

* macros register functions and (static) methods
  • Loading branch information
tertsdiepraam authored Sep 30, 2024
1 parent e57d465 commit c45f7e1
Show file tree
Hide file tree
Showing 20 changed files with 802 additions and 207 deletions.
13 changes: 6 additions & 7 deletions examples/simple.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
use roto::{read_files, Runtime, Verdict};
use roto_macros::roto_function;
use roto_macros::roto_method;

struct Bla {
_x: u16,
y: u32,
_z: u32,
}

#[roto_function]
fn get_y(bla: *const Bla) -> u32 {
unsafe { &*bla }.y
}

fn main() -> Result<(), roto::RotoReport> {
env_logger::init();

let mut runtime = Runtime::basic().unwrap();

runtime.register_type::<Bla>().unwrap();
runtime.register_method::<Bla, _, _>("y", get_y).unwrap();

#[roto_method(runtime, Bla, y)]
fn get_y(bla: *const Bla) -> u32 {
unsafe { &*bla }.y
}

let mut compiled = read_files(["examples/simple.roto"])?
.compile(runtime, usize::BITS / 8)
Expand Down
186 changes: 157 additions & 29 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,153 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};

///
///
/// ```rust,no_run
/// fn foo(a1: A1, a2: A2) -> Ret {
/// /* ... */
/// }
/// ```
///
///
use syn::{parse_macro_input, Token};

struct Intermediate {
function: proc_macro2::TokenStream,
name: syn::Ident,
identifier: proc_macro2::TokenStream,
}

struct FunctionArgs {
runtime_ident: syn::Ident,
name: Option<syn::Ident>,
}

impl syn::parse::Parse for FunctionArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let runtime_ident = input.parse()?;

let mut name = None;
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
if input.peek(syn::Ident) {
name = input.parse()?;
}
}

Ok(Self {
runtime_ident,
name,
})
}
}

#[proc_macro_attribute]
pub fn roto_function(attr: TokenStream, item: TokenStream) -> TokenStream {
let item = parse_macro_input!(item as syn::ItemFn);
let Intermediate {
function,
identifier,
name: function_ident,
} = generate_function(item);

let FunctionArgs {
runtime_ident,
name,
} = syn::parse(attr).unwrap();

let name = name.unwrap_or(function_ident);

let expanded = quote! {
#function

#runtime_ident.register_function(stringify!(#name), #identifier).unwrap();
};

TokenStream::from(expanded)
}

struct MethodArgs {
runtime_ident: syn::Ident,
ty: syn::Type,
name: Option<syn::Ident>,
}

impl syn::parse::Parse for MethodArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let runtime_ident = input.parse()?;
input.parse::<Token![,]>()?;
let ty = input.parse()?;

let mut name = None;
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
if input.peek(syn::Ident) {
name = input.parse()?;
}
}
Ok(Self {
runtime_ident,
ty,
name,
})
}
}

#[proc_macro_attribute]
pub fn roto_method(attr: TokenStream, item: TokenStream) -> TokenStream {
let item = parse_macro_input!(item as syn::ItemFn);
let Intermediate {
function,
identifier,
name: function_name,
} = generate_function(item);

let MethodArgs {
runtime_ident,
ty,
name,
} = parse_macro_input!(attr as MethodArgs);

let name = name.unwrap_or(function_name);

let expanded = quote! {
#function

#runtime_ident.register_method::<#ty, _, _>(stringify!(#name), #identifier).unwrap();
};

TokenStream::from(expanded)
}

#[proc_macro_attribute]
pub fn roto_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
pub fn roto_static_method(
attr: TokenStream,
item: TokenStream,
) -> TokenStream {
let item = parse_macro_input!(item as syn::ItemFn);
let Intermediate {
function,
identifier,
name: function_name,
} = generate_function(item);

let MethodArgs {
runtime_ident,
ty,
name,
} = parse_macro_input!(attr as MethodArgs);

let ItemFn {
let name = name.unwrap_or(function_name);

let expanded = quote! {
#function

#runtime_ident.register_static_method::<#ty, _, _>(stringify!(#name), #identifier).unwrap();
};

TokenStream::from(expanded)
}

fn generate_function(item: syn::ItemFn) -> Intermediate {
let syn::ItemFn {
attrs,
vis,
sig,
block: _,
} = input.clone();
} = item.clone();

assert!(sig.unsafety.is_none());
assert!(sig.generics.params.is_empty());
assert!(sig.generics.where_clause.is_none());
assert!(sig.variadic.is_none());

let ident = sig.ident;
Expand All @@ -35,6 +158,7 @@ pub fn roto_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
pat
});

let generics = sig.generics;
let inputs = sig.inputs.clone().into_iter();
let ret = match sig.output {
syn::ReturnType::Default => quote!(()),
Expand All @@ -53,21 +177,25 @@ pub fn roto_function(_attr: TokenStream, item: TokenStream) -> TokenStream {
})
.collect();

let arg_types = quote!(*mut #ret, #(#input_types,)*);
let underscored_types = input_types.iter().map(|_| quote!(_));
let arg_types = quote!(_, #(#underscored_types,)*);

let expanded = quote! {
#[allow(non_upper_case_globals)]
#vis const #ident: extern "C" fn(#arg_types) = {
#(#attrs)*
extern "C" fn #ident ( out: *mut #ret, #(#inputs,)* ) {
#input
let function = quote! {
#(#attrs)*
#vis extern "C" fn #ident #generics ( out: *mut #ret, #(#inputs,)* ) {
#item

unsafe { *out = #ident(#(#args),*) };
}
unsafe { *out = #ident(#(#args),*) };
}
};

#ident as extern "C" fn(#arg_types)
};
let identifier = quote! {
#ident as extern "C" fn(#arg_types)
};

TokenStream::from(expanded)
Intermediate {
function,
name: ident,
identifier,
}
}
8 changes: 1 addition & 7 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ pub enum Literal {
#[allow(dead_code)]
String(String),
Asn(Asn),
IpAddress(IpAddress),
IpAddress(std::net::IpAddr),
Integer(i64),
Bool(bool),
}
Expand Down Expand Up @@ -313,9 +313,3 @@ impl std::fmt::Display for BinOp {
)
}
}

#[derive(Clone, Debug)]
pub enum IpAddress {
Ipv4(std::net::Ipv4Addr),
Ipv6(std::net::Ipv6Addr),
}
48 changes: 42 additions & 6 deletions src/codegen/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
types::{Primitive, Type},
},
};
use std::{any::TypeId, fmt::Display, mem::MaybeUninit};
use std::{any::TypeId, fmt::Display, mem::MaybeUninit, net::IpAddr};

#[derive(Debug)]
pub enum FunctionRetrievalError {
Expand Down Expand Up @@ -84,6 +84,7 @@ fn check_roto_type(
let I64: TypeId = TypeId::of::<i64>();
let UNIT: TypeId = TypeId::of::<()>();
let ASN: TypeId = TypeId::of::<Asn>();
let IPADDR: TypeId = TypeId::of::<IpAddr>();

let Some(rust_ty) = registry.get(rust_ty) else {
return Err(TypeMismatch {
Expand Down Expand Up @@ -117,6 +118,7 @@ fn check_roto_type(
x if x == I64 => Type::Primitive(Primitive::I64),
x if x == UNIT => Type::Primitive(Primitive::Unit),
x if x == ASN => Type::Primitive(Primitive::Asn),
x if x == IPADDR => Type::Primitive(Primitive::IpAddr),
_ => panic!(),
};
if expected_roto == roto_ty {
Expand Down Expand Up @@ -151,29 +153,55 @@ pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool {
#[allow(clippy::match_like_matches_macro)]
match rust_ty.description {
TypeDescription::Verdict(_, _) => true,
_ => false,
_ => todo!(),
}
}

/// Parameters of a Roto function
///
/// This trait allows for checking the types against Roto types and converting
/// the values into values appropriate for Roto.
///
/// The `invoke` method can (unsafely) invoke a pointer as if it were a function
/// with these parameters.
///
/// This trait is implemented on tuples of various sizes.
pub trait RotoParams {
/// This type but with [`Reflect::AsParam`] applied to each element.
type AsParams;

/// Convert to `Self::AsParams`.
fn as_params(&mut self) -> Self::AsParams;

/// Check whether these parameters match a parameter list from Roto.
fn check(
type_info: &mut TypeInfo,
ty: &[Type],
) -> Result<(), FunctionRetrievalError>;

/// Call a function pointer as if it were a function with these parameters.
///
/// This is _extremely_ unsafe, do not pass this arbitrary pointers and
/// always call `RotoParams::check` before calling this function. Don't
/// forget to also check the return type.
///
/// A [`TypedFunc`](super::TypedFunc) is a safe abstraction around this
/// function.
unsafe fn invoke<R: Reflect>(
self,
func_ptr: *const u8,
params: Self,
return_by_ref: bool,
) -> R;
}

/// Little helper macro to create a unit
macro_rules! unit {
($t:tt) => {
()
};
}

/// Implement the [`RotoParams`] trait for a tuple with some type parameters.
macro_rules! params {
($($t:ident),*) => {
#[allow(non_snake_case)]
Expand All @@ -183,6 +211,13 @@ macro_rules! params {
where
$($t: Reflect,)*
{
type AsParams = ($($t::AsParam,)*);

fn as_params(&mut self) -> Self::AsParams {
let ($($t,)*) = self;
return ($($t.as_param(),)*);
}

fn check(
type_info: &mut TypeInfo,
ty: &[Type]
Expand All @@ -205,17 +240,18 @@ macro_rules! params {
Ok(())
}

unsafe fn invoke<R: Reflect>(func_ptr: *const u8, ($($t,)*): Self, return_by_ref: bool) -> R {
unsafe fn invoke<R: Reflect>(mut self, func_ptr: *const u8, return_by_ref: bool) -> R {
let ($($t,)*) = self.as_params();
if return_by_ref {
let func_ptr = unsafe {
std::mem::transmute::<*const u8, fn(*mut R, $($t),*) -> ()>(func_ptr)
std::mem::transmute::<*const u8, fn(*mut R, $($t::AsParam),*) -> ()>(func_ptr)
};
let mut ret = MaybeUninit::<R>::uninit();
func_ptr(ret.as_mut_ptr(), $($t),*);
unsafe { ret.assume_init() }
} else {
let func_ptr = unsafe {
std::mem::transmute::<*const u8, fn($($t),*) -> R>(func_ptr)
std::mem::transmute::<*const u8, fn($($t::AsParam),*) -> R>(func_ptr)
};
func_ptr($($t),*)
}
Expand Down
Loading

0 comments on commit c45f7e1

Please sign in to comment.