Skip to content

Commit

Permalink
[NPU] add FLAGS_npu_storage_format env to enable npu storage format, …
Browse files Browse the repository at this point in the history
…test=develop (#48774)
  • Loading branch information
qili93 authored Dec 7, 2022
1 parent c6a2b0f commit e5bc2ee
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 4 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,18 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor",
"Choose default funciton type in JitLayer.");

#ifdef PADDLE_WITH_CUSTOM_DEVICE
/**
* Custom Device NPU related FLAG
* Name: FLAGS_npu_storage_format
* Since Version: 2.5.0
* Value Range: bool, default=false
* Example:
* Note: Enable NPU Storage Format for Ascend910 performance improvement.
*/
PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, "");
#endif

#ifdef PADDLE_WITH_CUDNN_FRONTEND
/**
* CUDNNv8 related FLAG
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/fluid/dygraph/varbase_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import inspect
import numpy as np
import warnings
Expand Down Expand Up @@ -379,7 +380,11 @@ def gradient(self):

new_ivar = self._grad_ivar()
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type():
if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
new_ivar = new_ivar._copy_to(core.CPUPlace(), True)
if self._grad_ivar().type == core.VarDesc.VarType.SELECTED_ROWS:
Expand Down
14 changes: 12 additions & 2 deletions python/paddle/nn/functional/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode
from paddle.device import (
get_all_custom_device_type,
Expand Down Expand Up @@ -149,7 +151,11 @@ def _conv_nd(
new_shape[channel_dim] = -1
bias = bias.reshape(new_shape)
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type():
if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
with no_grad():
bias_storage = _C_ops.npu_identity(
bias, 3
Expand Down Expand Up @@ -747,7 +753,11 @@ def conv2d(
+ [1 for i in range(len(x.shape) - channel_dim - 1)],
)
# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type():
if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
with no_grad():
bias_storage = _C_ops.npu_identity(
bias, 3
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/nn/layer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# TODO: define normalization api

import numbers
import os
import warnings

import numpy as np
Expand Down Expand Up @@ -681,7 +682,11 @@ def __init__(
self._variance.stop_gradient = True

# TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
if 'npu' in get_all_custom_device_type():
if (
os.environ.get('FLAGS_npu_storage_format', None)
in [1, '1', True, 'True', 'true']
and 'npu' in get_all_custom_device_type()
):
with no_grad():
weight_trans = _C_ops.npu_identity(
self.weight, 3
Expand Down

0 comments on commit e5bc2ee

Please sign in to comment.