diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index e8800cedb376d9..b809e026544bb8 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -1041,6 +1041,18 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, "Predictor", "Choose default funciton type in JitLayer."); +#ifdef PADDLE_WITH_CUSTOM_DEVICE +/** + * Custom Device NPU related FLAG + * Name: FLAGS_npu_storage_format + * Since Version: 2.5.0 + * Value Range: bool, default=false + * Example: FLAGS_npu_storage_format=1 would enable the NPU storage format. + * Note: Enable NPU Storage Format for Ascend910 performance improvement. + */ +PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, ""); +#endif + #ifdef PADDLE_WITH_CUDNN_FRONTEND /** * CUDNNv8 related FLAG diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index e9b963a781db9f..0d2cd0cbf2db0c 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import inspect import numpy as np import warnings @@ -379,7 +380,11 @@ def gradient(self): new_ivar = self._grad_ivar() # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op - if 'npu' in get_all_custom_device_type(): + if ( + os.environ.get('FLAGS_npu_storage_format', None) + in [1, '1', True, 'True', 'true'] + and 'npu' in get_all_custom_device_type() + ): new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1) new_ivar = new_ivar._copy_to(core.CPUPlace(), True) if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS: diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index d29f91d035f288..face92190c0f52 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode from paddle.device import ( get_all_custom_device_type, @@ -149,7 +151,11 @@ def _conv_nd( new_shape[channel_dim] = -1 bias = bias.reshape(new_shape) # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op - if 'npu' in get_all_custom_device_type(): + if ( + os.environ.get('FLAGS_npu_storage_format', None) + in [1, '1', True, 'True', 'true'] + and 'npu' in get_all_custom_device_type() + ): with no_grad(): bias_storage = _C_ops.npu_identity( bias, 3 @@ -747,7 +753,11 @@ def conv2d( + [1 for i in range(len(x.shape) - channel_dim - 1)], ) # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op - if 'npu' in get_all_custom_device_type(): + if ( + os.environ.get('FLAGS_npu_storage_format', None) + in [1, '1', True, 'True', 'true'] + and 'npu' in get_all_custom_device_type() + ): with no_grad(): bias_storage = _C_ops.npu_identity( bias, 3 diff --git a/python/paddle/nn/layer/norm.py 
b/python/paddle/nn/layer/norm.py index 64f9f8913313de..c0117560f25e2a 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -28,6 +28,7 @@ # TODO: define normalization api import numbers +import os import warnings import numpy as np @@ -681,7 +682,11 @@ def __init__( self._variance.stop_gradient = True # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op - if 'npu' in get_all_custom_device_type(): + if ( + os.environ.get('FLAGS_npu_storage_format', None) + in [1, '1', True, 'True', 'true'] + and 'npu' in get_all_custom_device_type() + ): with no_grad(): weight_trans = _C_ops.npu_identity( self.weight, 3