Skip to content

Commit

Permalink
Improve Hintable struct to support generics (#742)
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu authored Nov 1, 2024
1 parent f5bac96 commit 09dad3b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 49 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 1 addition & 25 deletions lib/recursion/src/hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,15 +475,10 @@ mod test {
use axvm_native_compiler::{
asm::AsmBuilder,
ir::{Ext, Felt, Var},
prelude::*,
};
use axvm_native_compiler_derive::{DslVariable, Hintable};
use p3_field::AbstractField;

use crate::{
hints::{Hintable, InnerChallenge, InnerVal},
types::InnerConfig,
};
use crate::hints::{Hintable, InnerChallenge, InnerVal};

#[test]
fn test_var_array() {
Expand Down Expand Up @@ -559,23 +554,4 @@ mod test {
let program = builder.compile_isa();
execute_program(program, stream);
}

#[derive(Hintable)]
struct TestStruct {
a: usize,
b: usize,
c: usize,
}

#[test]
fn test_macro() {
let x = TestStruct { a: 1, b: 2, c: 3 };
let stream = Hintable::<InnerConfig>::write(&x);
assert_eq!(
stream,
[1, 2, 3]
.map(|x| vec![InnerVal::from_canonical_usize(x)])
.to_vec()
);
}
}
5 changes: 5 additions & 0 deletions toolchain/native-compiler/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@ proc-macro = true
syn = { version = "1.0", features = ["parsing"] }
quote = "1.0"
proc-macro2 = "1.0"

[dev-dependencies]
axvm-native-compiler = { workspace = true }
axvm-recursion = { workspace = true }
p3-field = { workspace = true }
29 changes: 14 additions & 15 deletions toolchain/native-compiler/derive/src/hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use syn::ItemStruct;

pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result<TokenStream, TokenStream> {
let name = &ast.ident;

let name_prefix = name.to_string();
let name_var = format!("{}Var", name_prefix);
let name_var_ident = Ident::new(&name_var, Span::call_site());
Expand All @@ -28,18 +27,18 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result<TokenStre
})
.collect();

let existing_generics = ast.generics.clone();
if !existing_generics.params.is_empty() {
return Err(quote! {
compile_error!("Hintable macro only supports structs with no generics for now");
});
}
let (_, ty_generics, where_clause) = ast.generics.split_for_impl();

let impl_generics = {
let params = &ast.generics.params;
quote! { < C: axvm_native_compiler::prelude::Config, #params >}
};
let input_struct_tokens: Vec<_> = field_names
.iter()
.zip(field_types.iter())
.map(|(name, field_type)| {
quote! {
pub #name: <#field_type as Hintable<C> >::HintVariable,
pub #name: <#field_type as axvm_recursion::hints::Hintable<C> >::HintVariable,
}
})
.collect();
Expand All @@ -49,7 +48,7 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result<TokenStre
.zip(field_types.iter())
.map(|(name, field_type)| {
quote! {
let #name = <#field_type as Hintable<C>>::read(builder);
let #name = <#field_type as axvm_recursion::hints::Hintable<C>>::read(builder);
}
})
.collect();
Expand All @@ -58,29 +57,29 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result<TokenStre
.iter()
.map(|name| {
quote! {
stream.extend(Hintable::<C>::write(&self.#name));
stream.extend(axvm_recursion::hints::Hintable::<C>::write(&self.#name));
}
})
.collect();

Ok(quote! {
#[derive(DslVariable, Debug, Clone)]
pub struct #name_var_ident <C: Config> {
#[derive(axvm_native_compiler_derive::DslVariable, Debug, Clone)]
pub struct #name_var_ident #impl_generics {
#(#input_struct_tokens)*
}

impl<C: Config> Hintable<C> for #name {
impl #impl_generics axvm_recursion::hints::Hintable<C> for #name #ty_generics #where_clause {
type HintVariable = #name_var_ident<C>;

fn read(builder: &mut Builder<C>) -> Self::HintVariable {
fn read(builder: &mut axvm_native_compiler::prelude::Builder<C>) -> Self::HintVariable {
#(#read_tokens)*

#name_var_ident {
#(#field_names,)*
}
}

fn write(&self) -> Vec<Vec<<C as Config>::N>> {
fn write(&self) -> Vec<Vec<<C as axvm_native_compiler::prelude::Config>::N>> {
let mut stream = Vec::new();

#(#write_tokens)*
Expand Down
25 changes: 16 additions & 9 deletions toolchain/native-compiler/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ extern crate proc_macro;
use hints::create_new_struct_and_impl_hintable;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, ItemStruct, TypeParamBound};
use syn::{
parse_macro_input, Data, DeriveInput, Fields, GenericParam, Generics, ItemStruct,
TypeParamBound,
};

mod hints;

#[proc_macro_derive(DslVariable)]
pub fn derive_variable(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident; // Struct name
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let has_config_generic = input.generics.params.iter().any(|param| match param {
/// Returns true if the generic parameter C: Config exists.
pub(crate) fn has_config_generic(generics: &Generics) -> bool {
generics.params.iter().any(|param| match param {
GenericParam::Type(ty) => {
ty.ident == "C"
&& ty.bounds.iter().any(|b| match b {
Expand All @@ -23,9 +23,16 @@ pub fn derive_variable(input: TokenStream) -> TokenStream {
})
}
_ => false,
});
})
}

#[proc_macro_derive(DslVariable)]
pub fn derive_variable(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident; // Struct name
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
assert!(
has_config_generic,
has_config_generic(&input.generics),
"DslVariable requires a generic parameter C: Config"
);

Expand Down
23 changes: 23 additions & 0 deletions toolchain/native-compiler/derive/tests/hintable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use axvm_native_compiler::prelude::*;
use axvm_native_compiler_derive::Hintable;
use axvm_recursion::{hints::InnerVal, types::InnerConfig};
use p3_field::AbstractField;

#[derive(Hintable)]
struct TestStruct {
a: usize,
b: usize,
c: usize,
}

#[test]
fn test_macro() {
let x = TestStruct { a: 1, b: 2, c: 3 };
let stream = axvm_recursion::hints::Hintable::<InnerConfig>::write(&x);
assert_eq!(
stream,
[1, 2, 3]
.map(|x| vec![InnerVal::from_canonical_usize(x)])
.to_vec()
);
}

0 comments on commit 09dad3b

Please sign in to comment.