diff --git a/src/cpu/rnn/postgemm_dispatcher.hpp b/src/cpu/rnn/postgemm_dispatcher.hpp
index 516b2bc77f8..9f5d05ba251 100644
--- a/src/cpu/rnn/postgemm_dispatcher.hpp
+++ b/src/cpu/rnn/postgemm_dispatcher.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2023 Intel Corporation
+* Copyright 2019-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -45,8 +45,8 @@ namespace dnnl {
 namespace impl {
 namespace cpu {
 
-template <alg_kind_t alg_kind, prop_kind_t prop_kind>
-float activation(float s, float alpha, float cliping);
+float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
+        float alpha, float cliping);
 
 template <prop_kind_t aprop, impl::data_type_t src_type,
         impl::data_type_t scratch_type, impl::data_type_t acc_type>
@@ -88,22 +88,6 @@ struct rnn_postgemm_dispatcher {
                 break;
             case alg_kind::vanilla_rnn:
                 postgemm_func = &class_name::rnn_postgemm;
-                switch (pd->activation_kind()) {
-                    case alg_kind::eltwise_relu:
-                        activation_func
-                                = &activation<alg_kind::eltwise_relu, aprop>;
-                        break;
-                    case alg_kind::eltwise_tanh:
-                        activation_func
-                                = &activation<alg_kind::eltwise_tanh, aprop>;
-                        break;
-                    case alg_kind::eltwise_logistic:
-                        activation_func
-                                = &activation<alg_kind::eltwise_logistic,
-                                        aprop>;
-                        break;
-                    default: assert(!"Unsupported activation function"); break;
-                }
                 break;
             case alg_kind::vanilla_gru:
             case alg_kind::vanilla_augru:
@@ -233,7 +217,6 @@ struct rnn_postgemm_dispatcher {
     }
 
 protected:
-    float (*activation_func)(float s, float alpha, float cliping);
     virtual rnn_postgemm_sig(rnn_postgemm) = 0;
     virtual rnn_postgemm_sig(lstm_postgemm) = 0;
     virtual rnn_postgemm_sig(lstm_projection_postgemm) = 0;
diff --git a/src/cpu/rnn/ref_postgemm_rnn.cpp b/src/cpu/rnn/ref_postgemm_rnn.cpp
index 3be80385e77..51cd8892521 100644
--- a/src/cpu/rnn/ref_postgemm_rnn.cpp
+++ b/src/cpu/rnn/ref_postgemm_rnn.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2023 Intel Corporation
+* Copyright 2018-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -31,40 +31,29 @@ using namespace dnnl::impl::utils;
 using namespace dnnl::impl::math;
 using namespace rnn_utils;
 
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return relu_fwd(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return relu_bwd(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return tanh_fwd(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return one_m_square(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return logistic_fwd(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return x_m_square(s);
+float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
+        float alpha, float cliping) {
+    using namespace dnnl::impl::alg_kind;
+
+    if (prop_kind == prop_kind::forward
+            || prop_kind == prop_kind::forward_inference) {
+        switch (alg_kind) {
+            case eltwise_relu: return relu_fwd(s, alpha);
+            case eltwise_tanh: return tanh_fwd(s);
+            case eltwise_logistic: return logistic_fwd(s);
+            default: assert(!"unsupported algorithm");
+        }
+    } else if (prop_kind == prop_kind::backward) {
+        switch (alg_kind) {
+            case eltwise_relu: return relu_bwd(s, alpha);
+            case eltwise_tanh: return one_m_square(s);
+            case eltwise_logistic: return x_m_square(s);
+            default: assert(!"unsupported algorithm");
+        }
+    } else {
+        assert(!"unsupported propagation kind");
+    }
+    return NAN;
 }
 
 constexpr float linear(float s, float alpha, float clipping) {
@@ -118,7 +107,8 @@ rnn_postgemm_sig(
         (rnn_postgemm_fwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
     const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
     const auto act_f = [this](float a, float alpha, float clipping) {
-        return gates_t(this->activation_func(a, alpha, clipping));
+        return gates_t(activation(this->pd_->activation_kind(),
+                this->pd_->get_prop_kind(), a, alpha, clipping));
     };
     const auto linear_f = [](float a, float alpha, float clipping) {
         return gates_t(linear(a, alpha, clipping));
@@ -178,7 +168,8 @@ rnn_postgemm_sig(
         (rnn_postgemm_bwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
     const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
     const auto act_f = [this](float a, float alpha, float clipping) {
-        return this->activation_func(a, alpha, 0);
+        return activation(this->pd_->activation_kind(),
+                this->pd_->get_prop_kind(), a, alpha, 0);
     };
     const auto linear_f = [](float a, float alpha, float clipping) {
         return linear(a, alpha, 0);