Skip to content

Commit

Permalink
scroll_derive: Custom ctx override for derive macro (#106)
Browse files Browse the repository at this point in the history
* feat: attribute to override ctx for particular fields on derive(Pread, Pwrite)

Enables using the syntax:
```
#[scroll(ctx = context_expr)]
field: T
```
To use the `context_expr` expression as context when for `Pread` for `field`, regardless of the context for the rest of the struct.
  • Loading branch information
Easyoakland authored Oct 22, 2024
1 parent a7db2f2 commit 01d8722
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 24 deletions.
81 changes: 81 additions & 0 deletions scroll_derive/examples/derive_custom_ctx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use scroll_derive::{Pread, Pwrite, SizeWith};

/// An example of using a method as the value for a ctx in a derive.
struct EndianDependent(Endian);
impl EndianDependent {
fn len(&self) -> usize {
match self.0 {
scroll::Endian::Little => 5,
scroll::Endian::Big => 6,
}
}
}

#[derive(Debug, PartialEq)]
struct VariableLengthData {
buf: Vec<u8>,
}

impl<'a> TryFromCtx<'a, usize> for VariableLengthData {
type Error = scroll::Error;

fn try_from_ctx(from: &'a [u8], ctx: usize) -> Result<(Self, usize), Self::Error> {
let offset = &mut 0;
let buf = from.gread_with::<&[u8]>(offset, ctx)?.to_owned();
Ok((Self { buf }, *offset))
}
}
impl<'a> TryIntoCtx<usize> for &'a VariableLengthData {
type Error = scroll::Error;
fn try_into_ctx(self, dst: &mut [u8], ctx: usize) -> Result<usize, Self::Error> {
let offset = &mut 0;
for i in 0..(ctx.min(self.buf.len())) {
dst.gwrite(self.buf[i], offset)?;
}
Ok(*offset)
}
}
impl SizeWith<usize> for VariableLengthData {
fn size_with(ctx: &usize) -> usize {
*ctx
}
}

#[derive(Debug, PartialEq, Pread, Pwrite, SizeWith)]
#[repr(C)]
struct Data {
id: u32,
timestamp: f64,
// You can fix the ctx regardless of what is passed in.
#[scroll(ctx = BE)]
arr: [u16; 2],
// You can use arbitrary expressions for the ctx.
// You have access to the `ctx` parameter of the `{pread/gread}_with` inside the expression.
// TODO(implement) you have access to previous fields.
// TODO(check) will this break structs with fields named `ctx`?.
#[scroll(ctx = EndianDependent(ctx.clone()).len())]
custom_ctx: VariableLengthData,
}

use scroll::{
ctx::{SizeWith, TryFromCtx, TryIntoCtx},
Endian, Pread, Pwrite, BE, LE,
};

fn main() {
let bytes = [
0xefu8, 0xbe, 0xad, 0xde, 0, 0, 0, 0, 0, 0, 224, 63, 0xad, 0xde, 0xef, 0xbe, 0xaa, 0xbb,
0xcc, 0xdd, 0xee,
];
let data: Data = bytes.pread_with(0, LE).unwrap();
println!("data: {data:?}");
assert_eq!(data.id, 0xdeadbeefu32);
assert_eq!(data.arr, [0xadde, 0xefbe]);
let mut bytes2 = vec![0; ::std::mem::size_of::<Data>()];
bytes2.pwrite_with(data, 0, LE).unwrap();
let data: Data = bytes.pread_with(0, LE).unwrap();
let data2: Data = bytes2.pread_with(0, LE).unwrap();
assert_eq!(data, data2);
// Not enough bytes because of ctx dependent length being too long.
assert!(bytes.pread_with::<Data>(0, BE).is_err())
}
116 changes: 93 additions & 23 deletions scroll_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

extern crate proc_macro;
use proc_macro2;
use quote::quote;
use quote::{quote, ToTokens};

use proc_macro::TokenStream;

fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream {
fn impl_field(
ident: &proc_macro2::TokenStream,
ty: &syn::Type,
custom_ctx: Option<&proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(&default_ctx);
match *ty {
syn::Type::Array(ref array) => match array.len {
syn::Expr::Lit(syn::ExprLit {
Expand All @@ -15,20 +21,63 @@ fn impl_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::
}) => {
let size = int.base10_parse::<usize>().unwrap();
quote! {
#ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, ctx)?; __tmp }
#ident: { let mut __tmp: #ty = [0u8.into(); #size]; src.gread_inout_with(offset, &mut __tmp, #ctx)?; __tmp }
}
}
_ => panic!("Pread derive with bad array constexpr"),
},
syn::Type::Group(ref group) => impl_field(ident, &group.elem),
syn::Type::Group(ref group) => impl_field(ident, &group.elem, custom_ctx),
_ => {
quote! {
#ident: src.gread_with::<#ty>(offset, ctx)?
#ident: src.gread_with::<#ty>(offset, #ctx)?
}
}
}
}

/// Retrieve the field attribute with given ident e.g:
/// ```ignore
/// #[attr_ident(..)]
/// field: T,
/// ```
fn get_attr<'a>(attr_ident: &str, field: &'a syn::Field) -> Option<&'a syn::Attribute> {
field
.attrs
.iter()
.filter(|attr| attr.path().is_ident(attr_ident))
.next()
}

/// Gets the `TokenStream` for the custom ctx set in the `ctx` attribute. e.g. `expr` in the following
/// ```ignore
/// #[scroll(ctx = expr)]
/// field: T,
/// ```
fn custom_ctx(field: &syn::Field) -> Option<proc_macro2::TokenStream> {
get_attr("scroll", field).and_then(|x| {
// parsed #[scroll..]
// `expr` is `None` if the `ctx` key is not used.
let mut expr = None;
let res = x.parse_nested_meta(|meta| {
// parsed #[scroll(..)]
if meta.path.is_ident("ctx") {
// parsed #[scroll(ctx..)]
let value = meta.value()?; // parsed #[scroll(ctx = ..)]
expr = Some(value.parse::<syn::Expr>()?.into_token_stream()); // parsed #[scroll(ctx = expr)]
return Ok(());
}
Err(meta.error(match meta.path.get_ident() {
Some(ident) => format!("unrecognized attribute: {ident}"),
None => "unrecognized and invalid attribute".to_owned(),
}))
});
match res {
Ok(()) => expr,
Err(e) => Some(e.into_compile_error()),
}
})
}

fn impl_struct(
name: &syn::Ident,
fields: &syn::punctuated::Punctuated<syn::Field, syn::Token![,]>,
Expand All @@ -43,7 +92,9 @@ fn impl_struct(
quote! {#t}
});
let ty = &f.ty;
impl_field(ident, ty)
// parse the `expr` out of #[scroll(ctx = expr)]
let custom_ctx = custom_ctx(f);
impl_field(ident, ty, custom_ctx.as_ref())
})
.collect();

Expand Down Expand Up @@ -104,14 +155,20 @@ fn impl_try_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(Pread)]
#[proc_macro_derive(Pread, attributes(scroll))]
pub fn derive_pread(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_try_from_ctx(&ast);
gen.into()
}

fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_macro2::TokenStream {
fn impl_pwrite_field(
ident: &proc_macro2::TokenStream,
ty: &syn::Type,
custom_ctx: Option<&proc_macro2::TokenStream>,
) -> proc_macro2::TokenStream {
let default_ctx = syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(&default_ctx);
match ty {
syn::Type::Array(ref array) => match array.len {
syn::Expr::Lit(syn::ExprLit {
Expand All @@ -121,24 +178,24 @@ fn impl_pwrite_field(ident: &proc_macro2::TokenStream, ty: &syn::Type) -> proc_m
let size = int.base10_parse::<usize>().unwrap();
quote! {
for i in 0..#size {
dst.gwrite_with(&self.#ident[i], offset, ctx)?;
dst.gwrite_with(&self.#ident[i], offset, #ctx)?;
}
}
}
_ => panic!("Pwrite derive with bad array constexpr"),
},
syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem),
syn::Type::Group(group) => impl_pwrite_field(ident, &group.elem, custom_ctx),
syn::Type::Reference(reference) => match *reference.elem {
syn::Type::Slice(_) => quote! {
dst.gwrite_with(self.#ident, offset, ())?
},
_ => quote! {
dst.gwrite_with(self.#ident, offset, ctx)?
dst.gwrite_with(self.#ident, offset, #ctx)?
},
},
_ => {
quote! {
dst.gwrite_with(&self.#ident, offset, ctx)?
dst.gwrite_with(&self.#ident, offset, #ctx)?
}
}
}
Expand All @@ -158,7 +215,8 @@ fn impl_try_into_ctx(
quote! {#t}
});
let ty = &f.ty;
impl_pwrite_field(ident, ty)
let custom_ctx = custom_ctx(f);
impl_pwrite_field(ident, ty, custom_ctx.as_ref())
})
.collect();

Expand Down Expand Up @@ -249,7 +307,7 @@ fn impl_pwrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(Pwrite)]
#[proc_macro_derive(Pwrite, attributes(scroll))]
pub fn derive_pwrite(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_pwrite(&ast);
Expand All @@ -265,6 +323,10 @@ fn size_with(
.iter()
.map(|f| {
let ty = &f.ty;
let custom_ctx = custom_ctx(f).map(|x| quote! {&#x});
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let elem = &array.elem;
Expand All @@ -275,15 +337,15 @@ fn size_with(
}) => {
let size = int.base10_parse::<usize>().unwrap();
quote! {
(#size * <#elem>::size_with(ctx))
(#size * <#elem>::size_with(#ctx))
}
}
_ => panic!("Pread derive with bad array constexpr"),
}
}
_ => {
quote! {
<#ty>::size_with(ctx)
<#ty>::size_with(#ctx)
}
}
}
Expand Down Expand Up @@ -341,7 +403,7 @@ fn impl_size_with(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(SizeWith)]
#[proc_macro_derive(SizeWith, attributes(scroll))]
pub fn derive_sizewith(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_size_with(&ast);
Expand All @@ -356,6 +418,10 @@ fn impl_cread_struct(
let items: Vec<_> = fields.iter().enumerate().map(|(i, f)| {
let ident = &f.ident.as_ref().map(|i|quote!{#i}).unwrap_or({let t = proc_macro2::Literal::usize_unsuffixed(i); quote!{#t}});
let ty = &f.ty;
let custom_ctx = custom_ctx(f);
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let arrty = &array.elem;
Expand All @@ -367,7 +433,7 @@ fn impl_cread_struct(
#ident: {
let mut __tmp: #ty = [0u8.into(); #size];
for i in 0..__tmp.len() {
__tmp[i] = src.cread_with(*offset, ctx);
__tmp[i] = src.cread_with(*offset, #ctx);
*offset += #incr;
}
__tmp
Expand All @@ -380,7 +446,7 @@ fn impl_cread_struct(
_ => {
let size = quote! { ::scroll::export::mem::size_of::<#ty>() };
quote! {
#ident: { let res = src.cread_with::<#ty>(*offset, ctx); *offset += #size; res }
#ident: { let res = src.cread_with::<#ty>(*offset, #ctx); *offset += #size; res }
}
}
}
Expand Down Expand Up @@ -440,7 +506,7 @@ fn impl_from_ctx(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(IOread)]
#[proc_macro_derive(IOread, attributes(scroll))]
pub fn derive_ioread(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_from_ctx(&ast);
Expand All @@ -462,20 +528,24 @@ fn impl_into_ctx(
});
let ty = &f.ty;
let size = quote! { ::scroll::export::mem::size_of::<#ty>() };
let custom_ctx = custom_ctx(f);
let default_ctx =
syn::Ident::new("ctx", proc_macro2::Span::call_site()).into_token_stream();
let ctx = custom_ctx.unwrap_or(default_ctx);
match *ty {
syn::Type::Array(ref array) => {
let arrty = &array.elem;
quote! {
let size = ::scroll::export::mem::size_of::<#arrty>();
for i in 0..self.#ident.len() {
dst.cwrite_with(self.#ident[i], *offset, ctx);
dst.cwrite_with(self.#ident[i], *offset, #ctx);
*offset += size;
}
}
}
_ => {
quote! {
dst.cwrite_with(self.#ident, *offset, ctx);
dst.cwrite_with(self.#ident, *offset, #ctx);
*offset += #size;
}
}
Expand Down Expand Up @@ -544,7 +614,7 @@ fn impl_iowrite(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(IOwrite)]
#[proc_macro_derive(IOwrite, attributes(scroll))]
pub fn derive_iowrite(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let gen = impl_iowrite(&ast);
Expand Down
Loading

0 comments on commit 01d8722

Please sign in to comment.