Skip to content

Commit

Permalink
Fix visibility of impl fns getting dropped by the cube macro (tracel-…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbelanich authored Oct 29, 2024
1 parent 0481ca3 commit 1226222
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 6 deletions.
23 changes: 23 additions & 0 deletions crates/cubecl-core/tests/frontend/cube_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,26 @@ impl<C: Numeric> ComplexType<C, f32> {
lhs * f32::cast_from(rhs)
}
}

mod foo {
use super::*;

#[derive(CubeType)]
pub struct TypeInModule {
pub a: u32,
}

#[cube]
impl TypeInModule {
#[allow(dead_code)]
pub fn simple_method(&self, lhs: u32) -> u32 {
self.a * lhs
}
}
}

#[cube]
fn call_from_outside_module() {
let bar = foo::TypeInModule { a: 0u32 };
let _ = bar.simple_method(5);
}
3 changes: 2 additions & 1 deletion crates/cubecl-macros/src/generate/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ use crate::{
impl KernelFn {
pub fn to_tokens_mut(&mut self) -> TokenStream {
let prelude_path = prelude_path();
let vis = &self.vis;
let sig = &self.sig;
let body = match &self.body {
KernelBody::Block(block) => &block.to_tokens(&mut self.context),
KernelBody::Verbatim(tokens) => tokens,
};

let out = quote! {
#sig {
#vis #sig {
use #prelude_path::IntoRuntime as _;

#body
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-macros/src/generate/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl ToTokens for Launch {
use super::*;

#[allow(unused, clippy::all)]
pub #func
#func

#kernel
#launch
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-macros/src/parse/cube_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl CubeImplItem {
pub fn from_impl_item(struct_ty_name: &Type, item: ImplItem) -> syn::Result<Vec<Self>> {
let res = match item {
ImplItem::Fn(func) => {
let mut func = KernelFn::from_sig_and_block(func.sig, func.block)?;
let mut func = KernelFn::from_sig_and_block(func.vis, func.sig, func.block)?;
let func_name_expand = format_ident!("__expand_{}", func.sig.name);

let is_method = func
Expand Down Expand Up @@ -123,6 +123,7 @@ impl CubeImplItem {
core::mem::swap(&mut func.body, &mut body);

KernelFn {
vis: func.vis.clone(),
sig: method_sig,
body,
context: Context::new(func.context.return_type.clone()),
Expand Down Expand Up @@ -172,6 +173,7 @@ impl CubeImplItem {
};

KernelFn {
vis: func.vis.clone(),
sig: func_sig,
body: KernelBody::Verbatim(body),
context: Context::new(func.context.return_type.clone()),
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-macros/src/parse/cube_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl CubeTraitImplItem {
pub fn from_impl_item(item: ImplItem) -> syn::Result<Self> {
let res = match item {
ImplItem::Fn(func) => {
let mut func = KernelFn::from_sig_and_block(func.sig, func.block)?;
let mut func = KernelFn::from_sig_and_block(func.vis, func.sig, func.block)?;
func.sig.name = format_ident!("__expand_{}", func.sig.name);
CubeTraitImplItem::Fn(func)
}
Expand Down
18 changes: 16 additions & 2 deletions crates/cubecl-macros/src/parse/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub struct Launch {

#[derive(Clone)]
pub struct KernelFn {
pub vis: Visibility,
pub sig: KernelSignature,
pub body: KernelBody,
pub context: Context,
Expand Down Expand Up @@ -207,7 +208,11 @@ impl KernelSignature {
}

impl KernelFn {
pub fn from_sig_and_block(sig: Signature, mut block: syn::Block) -> syn::Result<Self> {
pub fn from_sig_and_block(
vis: Visibility,
sig: Signature,
mut block: syn::Block,
) -> syn::Result<Self> {
let sig = KernelSignature::from_signature(sig)?;
Desugar.visit_block_mut(&mut block);

Expand All @@ -216,6 +221,7 @@ impl KernelFn {
let (block, _) = context.in_scope(|ctx| Block::from_block(block, ctx))?;

Ok(KernelFn {
vis,
sig,
body: KernelBody::Block(block),
context,
Expand All @@ -228,7 +234,15 @@ impl Launch {
let runtime = prelude_type("Runtime");

let vis = function.vis;
let func = KernelFn::from_sig_and_block(function.sig, *function.block)?;
let func = KernelFn::from_sig_and_block(
// When generating code, this function will be wrapped in
// a module. By setting the visibility to pub here, we
// ensure that the function is visibile outside that
// module.
Visibility::Public(parse_quote![pub]),
function.sig,
*function.block,
)?;
let mut kernel_generics = func.sig.generics.clone();
kernel_generics.params.push(parse_quote![__R: #runtime]);
let mut expand_generics = kernel_generics.clone();
Expand Down

0 comments on commit 1226222

Please sign in to comment.