From b07070631005a5ca69b558fcdaff190afa6fab23 Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Thu, 21 Dec 2023 14:12:49 -0500 Subject: [PATCH] Fix/wgpu/tanh (#1090) --- burn-wgpu/src/codegen/function.rs | 20 ++++++++++++++++++++ burn-wgpu/src/codegen/kernel.rs | 4 ++++ burn-wgpu/src/codegen/operator.rs | 7 ++++++- 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/burn-wgpu/src/codegen/function.rs b/burn-wgpu/src/codegen/function.rs index ceddb95e06..e4540f4ace 100644 --- a/burn-wgpu/src/codegen/function.rs +++ b/burn-wgpu/src/codegen/function.rs @@ -6,6 +6,8 @@ use std::fmt::Display; pub enum Function { Powf(Elem), Erf(Elem), + #[cfg(target_os = "macos")] + SafeTanh(Elem), } impl Display for Function { @@ -13,6 +15,8 @@ impl Display for Function { match self { Function::Powf(elem) => format_powf(f, elem), Function::Erf(elem) => format_erf(f, elem), + #[cfg(target_os = "macos")] + Function::SafeTanh(elem) => format_safe_tanh(f, elem), } } } @@ -69,3 +73,19 @@ fn erf(x: {elem}) -> {elem} {{ " )) } + +#[cfg(target_os = "macos")] +fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, elem: &Elem) -> core::fmt::Result { + f.write_fmt(format_args!( + " +/// Metal has a weird numerical behaviour with tanh for inputs over 43.0 +fn safe_tanh(x: {elem}) -> {elem} {{ + if x > 43.0 {{ + return 1.0; + }} else {{ + return tanh(x); + }} +}} +" + )) +} diff --git a/burn-wgpu/src/codegen/kernel.rs b/burn-wgpu/src/codegen/kernel.rs index 6f0c2addde..f233776cdb 100644 --- a/burn-wgpu/src/codegen/kernel.rs +++ b/burn-wgpu/src/codegen/kernel.rs @@ -162,6 +162,10 @@ impl ElemWiseKernelCodegen { Operator::Erf { input: _, out: _ } => { register_function(Function::Erf(Elem::F32)); } + #[cfg(target_os = "macos")] + Operator::Tanh { input: _, out: _ } => { + register_function(Function::SafeTanh(Elem::F32)) + } _ => {} } self.operations.push(ops.clone()); diff --git a/burn-wgpu/src/codegen/operator.rs b/burn-wgpu/src/codegen/operator.rs index a24b32c79e..ea8375b611 100644 --- a/burn-wgpu/src/codegen/operator.rs +++ b/burn-wgpu/src/codegen/operator.rs @@ -164,7 +164,12 @@ impl Display for Operator { Operator::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")), Operator::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")), Operator::Tanh { input, out } => { - f.write_fmt(format_args!("let {out} = tanh({input});")) + #[cfg(target_os = "macos")] + let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});")); + #[cfg(not(target_os = "macos"))] + let result = f.write_fmt(format_args!("let {out} = tanh({input});")); + + result } Operator::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")), Operator::Recip { input, out } => {