Skip to content

Commit

Permalink
Fix/wgpu/tanh (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Dec 21, 2023
1 parent d82e6b1 commit b070706
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
20 changes: 20 additions & 0 deletions burn-wgpu/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ use std::fmt::Display;
pub enum Function {
Powf(Elem),
Erf(Elem),
#[cfg(target_os = "macos")]
SafeTanh(Elem),
}

impl Display for Function {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Function::Powf(elem) => format_powf(f, elem),
Function::Erf(elem) => format_erf(f, elem),
#[cfg(target_os = "macos")]
Function::SafeTanh(elem) => format_safe_tanh(f, elem),
}
}
}
Expand Down Expand Up @@ -69,3 +73,19 @@ fn erf(x: {elem}) -> {elem} {{
"
))
}

#[cfg(target_os = "macos")]
fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
f.write_fmt(format_args!(
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
fn safe_tanh(x: {elem}) -> {elem} {{
if x > 43.0 {{
return 1.0;
}} else {{
return tanh(x);
}}
}}
"
))
}
4 changes: 4 additions & 0 deletions burn-wgpu/src/codegen/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ impl ElemWiseKernelCodegen<BodyPhase> {
Operator::Erf { input: _, out: _ } => {
register_function(Function::Erf(Elem::F32));
}
#[cfg(target_os = "macos")]
Operator::Tanh { input: _, out: _ } => {
register_function(Function::SafeTanh(Elem::F32))
}
_ => {}
}
self.operations.push(ops.clone());
Expand Down
7 changes: 6 additions & 1 deletion burn-wgpu/src/codegen/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ impl Display for Operator {
Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
Operator::Tanh { input, out } => {
f.write_fmt(format_args!("let {out} = tanh({input});"))
#[cfg(target_os = "macos")]
let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});"));
#[cfg(not(target_os = "macos"))]
let result = f.write_fmt(format_args!("let {out} = tanh({input});"));

result
}
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
Operator::Recip { input, out } => {
Expand Down

0 comments on commit b070706

Please sign in to comment.