Skip to content

Commit

Permalink
Revert "Fix/wgpu/tanh (tracel-ai#1090)"
Browse files Browse the repository at this point in the history
This reverts commit b070706.
  • Loading branch information
syl20bnr committed Jan 4, 2024
1 parent 8db982d commit 71230cf
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 30 deletions.
20 changes: 0 additions & 20 deletions burn-wgpu/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ use std::fmt::Display;
pub enum Function {
Powf(Elem),
Erf(Elem),
#[cfg(target_os = "macos")]
SafeTanh(Elem),
}

impl Display for Function {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Function::Powf(elem) => format_powf(f, elem),
Function::Erf(elem) => format_erf(f, elem),
#[cfg(target_os = "macos")]
Function::SafeTanh(elem) => format_safe_tanh(f, elem),
}
}
}
Expand Down Expand Up @@ -73,19 +69,3 @@ fn erf(x: {elem}) -> {elem} {{
"
))
}

#[cfg(target_os = "macos")]
fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result {
f.write_fmt(format_args!(
"
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
fn safe_tanh(x: {elem}) -> {elem} {{
if x > 43.0 {{
return 1.0;
}} else {{
return tanh(x);
}}
}}
"
))
}
4 changes: 0 additions & 4 deletions burn-wgpu/src/codegen/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,6 @@ impl ElemWiseKernelCodegen<BodyPhase> {
Operator::Erf { input: _, out: _ } => {
register_function(Function::Erf(Elem::F32));
}
#[cfg(target_os = "macos")]
Operator::Tanh { input: _, out: _ } => {
register_function(Function::SafeTanh(Elem::F32))
}
_ => {}
}
self.operations.push(ops.clone());
Expand Down
7 changes: 1 addition & 6 deletions burn-wgpu/src/codegen/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,7 @@ impl Display for Operator {
Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
Operator::Tanh { input, out } => {
#[cfg(target_os = "macos")]
let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});"));
#[cfg(not(target_os = "macos"))]
let result = f.write_fmt(format_args!("let {out} = tanh({input});"));

result
f.write_fmt(format_args!("let {out} = tanh({input});"))
}
Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
Operator::Recip { input, out } => {
Expand Down

0 comments on commit 71230cf

Please sign in to comment.