【PIR API adaptor No.77、288】 Migrate nn.initializer.Dirac, nn.initializer.Orthogonal into pir (#59911)
MarioLulab authored Dec 13, 2023
1 parent 472ad9f commit 46e3dfe
Showing 3 changed files with 178 additions and 35 deletions.
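Before the diffs, a minimal usage sketch of the two migrated initializers in dynamic mode. The layer sizes and gain value below are illustrative assumptions, not values taken from this commit.

import paddle

# Minimal sketch (assumed shapes/gain): both initializers are attached to a
# layer through paddle.ParamAttr, the same way the tests in this commit do.
dirac_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
ortho_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Orthogonal(gain=1.0))

conv = paddle.nn.Conv2D(4, 8, kernel_size=3, weight_attr=dirac_attr)   # weight: [8, 4, 3, 3]
linear = paddle.nn.Linear(16, 8, weight_attr=ortho_attr)               # weight: [16, 8]
print(conv.weight.shape, linear.weight.shape)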
107 changes: 79 additions & 28 deletions python/paddle/nn/initializer/dirac.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _C_ops, in_dynamic_mode
import paddle
from paddle import _C_ops, in_dynamic_mode, pir
from paddle.utils import unique_name

from ... import base
from ...base import framework
from ...base import core, framework
from ...base.core import VarDesc
from ...base.data_feeder import check_variable_and_dtype
from ...base.framework import _current_expected_place
@@ -106,8 +107,8 @@ def __call__(self, var, block=None):
The most critical op in this initializer is scatter; the whole initializer contains 7~8 ops in total.
"""
block = self._check_block(block)
assert isinstance(var, framework.Parameter)
assert isinstance(block, framework.Block)
assert isinstance(var, (framework.Variable, pir.core.ParameterMeta))
assert isinstance(block, (framework.Block, pir.Block))
check_variable_and_dtype(
var, "Out", ['float16', 'bfloat16', 'float32', 'float64'], 'Dirac'
)
@@ -121,24 +122,39 @@ def __call__(self, var, block=None):
var.shape[0] % self._groups
) == 0, "Tensor 0-dimension must be divisible by groups"

if var.dtype != VarDesc.VarType.FP32:
out_var = block.create_var(
name=unique_name.generate(".".join(['dirac', var.name, 'tmp'])),
shape=var.shape,
dtype=VarDesc.VarType.FP32,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
)
if framework.in_pir_mode():
if var.dtype != core.DataType.FLOAT32:
out_dtype = core.DataType.FLOAT32
out_var = var
else:
out_dtype = var.dtype
out_var = var
else:
out_var = var
if var.dtype != VarDesc.VarType.FP32:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(
".".join(['dirac', var.name, 'tmp'])
),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
)
else:
out_dtype = var.dtype
out_var = var

op = None
if framework.in_dygraph_mode():
with base.dygraph.no_grad():
place = _current_expected_place()
_C_ops.full_(
out_var, out_var.shape, str(float(0)), out_var.dtype, place
out_var, out_var.shape, str(float(0)), out_dtype, place
)

elif framework.in_pir_mode():
place = _current_expected_place()
out_var = _C_ops.full(out_var.shape, float(0), out_dtype, place)
else:
block.append_op(
type='fill_constant',
@@ -179,10 +195,12 @@ def __call__(self, var, block=None):
with base.dygraph.no_grad():
tmp_out = _C_ops.reshape(out_var, [-1])
tmp_out._share_underline_tensor_to(out_var)
elif framework.in_pir_mode():
out_var = _C_ops.reshape(out_var, [-1])
else:
x_shape = block.create_var(
name=unique_name.generate(".".join([out_var.name, "XShape"])),
dtype=out_var.dtype,
dtype=out_dtype,
shape=out_var.shape,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
Expand All @@ -196,11 +214,17 @@ def __call__(self, var, block=None):
stop_gradient=True,
)

index_tensor = block.create_var(
name=unique_name.generate('scatter_index'),
persistable=False,
stop_gradient=True,
)
if framework.in_pir_mode():
index_tensor = paddle.zeros(
[len(idx_list)], dtype=core.DataType.INT64
)
index_tensor.stop_gradient = True
else:
index_tensor = block.create_var(
name=unique_name.generate('scatter_index'),
persistable=False,
stop_gradient=True,
)

if framework.in_dygraph_mode():
with base.dygraph.no_grad():
@@ -213,6 +237,14 @@ def __call__(self, var, block=None):
_current_expected_place(),
)
tmp_tensor._share_underline_tensor_to(index_tensor)
elif framework.in_pir_mode():
_C_ops.assign_value_(
index_tensor,
[len(idx_list)],
core.DataType.INT64,
idx_list,
_current_expected_place(),
)
else:
block.append_op(
type='assign_value',
@@ -224,12 +256,17 @@ def __call__(self, var, block=None):
},
stop_gradient=True,
)

value_tensor = block.create_var(
name=unique_name.generate('scatter_value'),
persistable=False,
stop_gradient=True,
)
if framework.in_pir_mode():
value_tensor = paddle.zeros(
[len(value_list)], dtype=core.DataType.FLOAT32
)
value_tensor.stop_gradient = True
else:
value_tensor = block.create_var(
name=unique_name.generate('scatter_value'),
persistable=False,
stop_gradient=True,
)

if framework.in_dygraph_mode():
with base.dygraph.no_grad():
@@ -243,6 +280,14 @@ def __call__(self, var, block=None):
)

tmp_tensor._share_underline_tensor_to(value_tensor)
elif framework.in_pir_mode():
_C_ops.assign_value_(
value_tensor,
[len(value_list)],
core.DataType.FLOAT32,
value_list,
_current_expected_place(),
)
else:
block.append_op(
type='assign_value',
@@ -266,6 +311,12 @@ def __call__(self, var, block=None):
if var.dtype != VarDesc.VarType.FP32:
tmp_cast_out = _C_ops.cast(out_var, var.dtype)
tmp_cast_out._share_underline_tensor_to(var)
elif framework.in_pir_mode():
out_var = _C_ops.scatter(out_var, index_tensor, value_tensor, True)
out_var = _C_ops.reshape(out_var, origin_shape)
if var.dtype != core.DataType.FLOAT32:
return _C_ops.cast(out_var, var.dtype)
return out_var
else:
op = block.append_op(
type="scatter",
@@ -280,7 +331,7 @@ def __call__(self, var, block=None):
)
x_shape = block.create_var(
name=unique_name.generate(".".join([out_var.name, "XShape"])),
dtype=out_var.dtype,
dtype=out_dtype,
shape=out_var.shape,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False,
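A hedged sanity check of what the Dirac path above constructs, independent of IR mode: with groups=1, equal in/out channel counts, and "same" padding, a Dirac-initialized conv should act as an identity map on its input. The shapes below are assumptions for illustration.

import paddle

# Hedged sketch: only the center tap of each matching in/out channel pair is 1,
# so the convolution reproduces its input exactly.
conv = paddle.nn.Conv2D(
    in_channels=3,
    out_channels=3,
    kernel_size=3,
    padding=1,
    weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac()),
    bias_attr=False,
)
x = paddle.rand([1, 3, 8, 8])
y = conv(x)
print(paddle.allclose(x, y))  # expected: True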
27 changes: 24 additions & 3 deletions python/paddle/nn/initializer/orthogonal.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import _C_ops
from paddle import _C_ops, pir
from paddle.utils import unique_name

from ...base import framework
@@ -82,8 +82,8 @@ def __call__(self, var, block=None):
The last initialization op; the orthogonal initializer contains 8 ops in total.
"""
block = self._check_block(block)
assert isinstance(var, framework.Parameter)
assert isinstance(block, framework.Block)
assert isinstance(var, (framework.Variable, pir.core.ParameterMeta))
assert isinstance(block, (framework.Block, pir.Block))
self._seed = block.program.random_seed

shape = var.shape
@@ -122,6 +122,27 @@ def __call__(self, var, block=None):
tmp._share_underline_tensor_to(var)

return None
elif framework.in_pir_mode():
place = framework._current_expected_place()
normal_var = _C_ops.gaussian(
flatten_shape, 0.0, 1.0, self._seed, var.dtype, place
)
q, r = _C_ops.qr(normal_var, 'reduced')

r_diag = _C_ops.diag(r, 0, 0)

r_sign = _C_ops.sign(r_diag)

q = _C_ops.multiply(q, r_sign)

if row < col:
q = _C_ops.transpose(q, [1, 0])

q = _C_ops.reshape(q, var.shape)

tmp = _C_ops.scale(q, self._gain, 0.0, True)

return tmp

# 'qr' op only support float32/float64 now
check_variable_and_dtype(
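A hedged sketch of the property the PIR branch above is meant to reproduce: treated as a 2-D matrix, the weight has its shorter dimension orthonormal and scaled by gain, so one Gram product is approximately gain^2 * I. The Linear shape and gain below are illustrative assumptions.

import numpy as np
import paddle

# Hedged sketch: for a Linear weight of shape [in=10, out=6] (rows > cols),
# the columns should be orthonormal up to the gain, i.e. W.T @ W ~= gain^2 * I.
paddle.seed(2021)
linear = paddle.nn.Linear(
    10,
    6,
    weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.Orthogonal(gain=3.0)
    ),
)
w = linear.weight.numpy()
np.testing.assert_allclose(w.T @ w, 9.0 * np.eye(6), rtol=1e-5, atol=1e-5)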
79 changes: 75 additions & 4 deletions test/legacy_test/test_initializer.py
@@ -23,6 +23,7 @@
from paddle import base
from paddle.base import framework
from paddle.base.core import VarDesc
from paddle.pir_utils import test_with_pir_api
from paddle.regularizer import L2Decay

DELTA = 0.00001
@@ -1611,6 +1612,33 @@ def test_orthogonal(self):

self.check_result(res_dygraph, res_static)

def test_orthogonal_pir(self):
self.config()
paddle.set_default_dtype(self.dtype)

paddle.disable_static()
paddle.seed(2021)
linear = paddle.nn.Linear(
self.in_features, self.out_features, weight_attr=self.weight_attr
)
res_dygraph = linear.weight.numpy()

paddle.enable_static()
paddle.seed(2021)
start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
linear = paddle.nn.Linear(
self.in_features,
self.out_features,
weight_attr=self.weight_attr,
)

exe = paddle.static.Executor()
res_static = exe.run(start_prog, fetch_list=[linear.weight])[0]

self.check_result(res_dygraph, res_static)


# 2-D Parameter with shape: [15, 10]
class TestOrthogonalInitializer2(TestOrthogonalInitializer1):
@@ -1686,6 +1714,7 @@ def check_result(self, a, b):
np.matmul(a, a.T), 9 * np.eye(6), rtol=1e-5, atol=1e-8
)

@test_with_pir_api
def test_orthogonal(self):
self.config()
paddle.set_default_dtype(self.dtype)
@@ -1705,17 +1734,18 @@ def test_orthogonal(self):
start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
inp = paddle.rand(shape=[8, self.in_features, 10, 10])
conv2d = paddle.nn.Conv2D(
self.in_features,
self.out_features,
self.kernel_size,
weight_attr=self.weight_attr,
)
output = conv2d(inp)
exe = paddle.static.Executor()
res_static = exe.run(
paddle.static.default_startup_program(),
fetch_list=[conv2d.weight],
)[0]

exe.run(start_prog)
res_static = exe.run(main_prog, fetch_list=[conv2d.weight])[0]
self.check_result(res_dygraph, res_static)


@@ -1834,6 +1864,47 @@ def test_dirac(self):
weight_dygraph, weight_static, conv_input, conv_output
)

def test_dirac_pir(self):
self.config()
paddle.set_default_dtype(self.dtype)

paddle.disable_static()
conv = self.conv_layer(
self.in_channels,
self.out_channels,
self.kernel_size,
weight_attr=self.weight_attr,
)
weight_dygraph = conv.weight.numpy()

paddle.enable_static()
with paddle.pir_utils.IrGuard():
start_prog = paddle.static.Program()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
inp = paddle.rand(self.input_shape)
conv = self.conv_layer(
self.in_channels,
self.out_channels,
self.kernel_size,
weight_attr=self.weight_attr,
)

output = conv(inp)

exe = paddle.static.Executor()
exe.run(start_prog)
fetch = exe.run(
main_prog, fetch_list=[inp, output, conv.weight]
)
conv_input = fetch[0]
conv_output = fetch[1]
weight_static = fetch[2]

self.check_result(
weight_dygraph, weight_static, conv_input, conv_output
)


# initialize Conv2D weight
class TestDiracInitializer2(TestDiracInitializer1):
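For reference, a condensed, hedged sketch of the pattern the new PIR tests follow: build the programs under paddle.pir_utils.IrGuard(), run the startup program so the initializer ops execute, then fetch the initialized weight from the main program. The input shape and channel counts below are illustrative.

import paddle

# Hedged sketch of the PIR test pattern: IrGuard() makes static-graph
# construction use the new PIR program representation inside the `with` block.
paddle.enable_static()
with paddle.pir_utils.IrGuard():
    start_prog = paddle.static.Program()
    main_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, start_prog):
        inp = paddle.rand([8, 3, 10, 10])  # assumed input shape
        conv = paddle.nn.Conv2D(
            3,
            3,
            kernel_size=3,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Dirac()
            ),
        )
        out = conv(inp)

    exe = paddle.static.Executor()
    exe.run(start_prog)  # runs the Dirac initialization ops
    weight, _ = exe.run(main_prog, fetch_list=[conv.weight, out])
    print(weight.shape)  # (3, 3, 3, 3)
paddle.disable_static()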
