diff --git a/Cargo.lock b/Cargo.lock index 4accb73ffd..edfe30c69d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -712,6 +712,9 @@ dependencies = [ name = "axvm-native-compiler-derive" version = "0.1.0" dependencies = [ + "axvm-native-compiler", + "axvm-recursion", + "p3-field", "proc-macro2", "quote", "syn 1.0.109", diff --git a/lib/recursion/src/hints.rs b/lib/recursion/src/hints.rs index aee24ab11c..2656304385 100644 --- a/lib/recursion/src/hints.rs +++ b/lib/recursion/src/hints.rs @@ -475,15 +475,10 @@ mod test { use axvm_native_compiler::{ asm::AsmBuilder, ir::{Ext, Felt, Var}, - prelude::*, }; - use axvm_native_compiler_derive::{DslVariable, Hintable}; use p3_field::AbstractField; - use crate::{ - hints::{Hintable, InnerChallenge, InnerVal}, - types::InnerConfig, - }; + use crate::hints::{Hintable, InnerChallenge, InnerVal}; #[test] fn test_var_array() { @@ -559,23 +554,4 @@ mod test { let program = builder.compile_isa(); execute_program(program, stream); } - - #[derive(Hintable)] - struct TestStruct { - a: usize, - b: usize, - c: usize, - } - - #[test] - fn test_macro() { - let x = TestStruct { a: 1, b: 2, c: 3 }; - let stream = Hintable::::write(&x); - assert_eq!( - stream, - [1, 2, 3] - .map(|x| vec![InnerVal::from_canonical_usize(x)]) - .to_vec() - ); - } } diff --git a/toolchain/native-compiler/derive/Cargo.toml b/toolchain/native-compiler/derive/Cargo.toml index 2d7f703e0f..8c08a6f6fa 100644 --- a/toolchain/native-compiler/derive/Cargo.toml +++ b/toolchain/native-compiler/derive/Cargo.toml @@ -12,3 +12,8 @@ proc-macro = true syn = { version = "1.0", features = ["parsing"] } quote = "1.0" proc-macro2 = "1.0" + +[dev-dependencies] +axvm-native-compiler = { workspace = true } +axvm-recursion = { workspace = true } +p3-field = { workspace = true } \ No newline at end of file diff --git a/toolchain/native-compiler/derive/src/hints.rs b/toolchain/native-compiler/derive/src/hints.rs index d7ec716911..43d05d8b60 100644 --- a/toolchain/native-compiler/derive/src/hints.rs +++ b/toolchain/native-compiler/derive/src/hints.rs @@ -4,7 +4,6 @@ use syn::ItemStruct; pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result { let name = &ast.ident; - let name_prefix = name.to_string(); let name_var = format!("{}Var", name_prefix); let name_var_ident = Ident::new(&name_var, Span::call_site()); @@ -28,18 +27,18 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result} + }; let input_struct_tokens: Vec<_> = field_names .iter() .zip(field_types.iter()) .map(|(name, field_type)| { quote! { - pub #name: <#field_type as Hintable >::HintVariable, + pub #name: <#field_type as axvm_recursion::hints::Hintable >::HintVariable, } }) .collect(); @@ -49,7 +48,7 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result>::read(builder); + let #name = <#field_type as axvm_recursion::hints::Hintable>::read(builder); } }) .collect(); @@ -58,21 +57,21 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result::write(&self.#name)); + stream.extend(axvm_recursion::hints::Hintable::::write(&self.#name)); } }) .collect(); Ok(quote! { - #[derive(DslVariable, Debug, Clone)] - pub struct #name_var_ident { + #[derive(axvm_native_compiler_derive::DslVariable, Debug, Clone)] + pub struct #name_var_ident #impl_generics { #(#input_struct_tokens)* } - impl Hintable for #name { + impl #impl_generics axvm_recursion::hints::Hintable for #name #ty_generics #where_clause { type HintVariable = #name_var_ident; - fn read(builder: &mut Builder) -> Self::HintVariable { + fn read(builder: &mut axvm_native_compiler::prelude::Builder) -> Self::HintVariable { #(#read_tokens)* #name_var_ident { @@ -80,7 +79,7 @@ pub fn create_new_struct_and_impl_hintable(ast: &ItemStruct) -> Result Vec::N>> { + fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); #(#write_tokens)* diff --git a/toolchain/native-compiler/derive/src/lib.rs b/toolchain/native-compiler/derive/src/lib.rs index a696bcc50c..a7bfb9d7d4 100644 --- a/toolchain/native-compiler/derive/src/lib.rs +++ b/toolchain/native-compiler/derive/src/lib.rs @@ -5,16 +5,16 @@ extern crate proc_macro; use hints::create_new_struct_and_impl_hintable; use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, ItemStruct, TypeParamBound}; +use syn::{ + parse_macro_input, Data, DeriveInput, Fields, GenericParam, Generics, ItemStruct, + TypeParamBound, +}; mod hints; -#[proc_macro_derive(DslVariable)] -pub fn derive_variable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let name = input.ident; // Struct name - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let has_config_generic = input.generics.params.iter().any(|param| match param { +/// Returns true if the generic parameter C: Config exists. +pub(crate) fn has_config_generic(generics: &Generics) -> bool { + generics.params.iter().any(|param| match param { GenericParam::Type(ty) => { ty.ident == "C" && ty.bounds.iter().any(|b| match b { @@ -23,9 +23,16 @@ pub fn derive_variable(input: TokenStream) -> TokenStream { }) } _ => false, - }); + }) +} + +#[proc_macro_derive(DslVariable)] +pub fn derive_variable(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = input.ident; // Struct name + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); assert!( - has_config_generic, + has_config_generic(&input.generics), "DslVariable requires a generic parameter C: Config" ); diff --git a/toolchain/native-compiler/derive/tests/hintable.rs b/toolchain/native-compiler/derive/tests/hintable.rs new file mode 100644 index 0000000000..dd1ad3ed32 --- /dev/null +++ b/toolchain/native-compiler/derive/tests/hintable.rs @@ -0,0 +1,23 @@ +use axvm_native_compiler::prelude::*; +use axvm_native_compiler_derive::Hintable; +use axvm_recursion::{hints::InnerVal, types::InnerConfig}; +use p3_field::AbstractField; + +#[derive(Hintable)] +struct TestStruct { + a: usize, + b: usize, + c: usize, +} + +#[test] +fn test_macro() { + let x = TestStruct { a: 1, b: 2, c: 3 }; + let stream = axvm_recursion::hints::Hintable::::write(&x); + assert_eq!( + stream, + [1, 2, 3] + .map(|x| vec![InnerVal::from_canonical_usize(x)]) + .to_vec() + ); +}