From 855e51e45d60f0fcf32201507302fb03b5ba6e9a Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Thu, 30 Nov 2023 16:23:46 +0800
Subject: [PATCH] [AutoParallel] rm infershape for dist_embedding (#59526)

* [AutoParallel] rm infershape for dist_embedding

* [AutoParallel] rm infershape for dist_embedding

* Update dist_embedding.py
---
 .../static/operators/dist_embedding.py       | 105 +++---------------
 1 file changed, 18 insertions(+), 87 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py
index 2f7f8accf1f393..e4f9ff09fca20e 100644
--- a/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py
@@ -49,10 +49,10 @@
     ParallelMode,
     get_default_distributed_operator_impl,
     gradient_synchronization,
-    infer_shape,
     naive_copy_op_dist_attr_for_program,
     register_distributed_operator_impl,
     register_distributed_operator_impl_container,
+    set_comm_op_dist_attr_for_program,
     update_op_dims_mapping,
 )

@@ -188,6 +188,7 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
         [-1] + list(Ids_var_dist_attr.dims_mapping),
         Ids_var_dist_attr.process_mesh,
     )
+    src_op._rename_input(Ids_var.name, intermediate_var_0.name)
     op_dist_attr.del_input_dist_attr(Ids_var.name)
     op_dist_attr.set_input_dist_attr(
         intermediate_var_0.name, intermediate_var_0_dist_attr
@@ -496,48 +497,22 @@ def forward(ctx, *args, **kwargs):
         assert out_tensor_dist_attr is not None
         out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
         assert out_var_dist_attr is not None
-        ref_shape = infer_shape(
-            main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
-        )
-
-        intermediate_var_0 = main_block.create_var(
-            name=unique_name.generate_with_ignorable_key(
-                ".".join(["c_embedding", 'tmp'])
-            ),
-            dtype=Weight_var.dtype,
-            shape=Out_var.shape,
-            type=core.VarDesc.VarType.LOD_TENSOR,
-            persistable=False,
-            stop_gradient=Out_var.stop_gradient,
-        )
-        # set intermediate_var_0's dist_attr with Out_var's dist_attr
-        ctx.set_tensor_dist_attr_for_program(
-            intermediate_var_0, out_var_dist_attr
-        )
-
-        check_variable_and_dtype(
-            Out_var,
-            'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
-            'c_allreduce_sum',
-        )

-        c_embedding_op = main_block.append_op(
-            type='c_embedding',
-            inputs={'Ids': [Ids_var], 'W': [Weight_var]},
-            outputs={'Out': [intermediate_var_0]},
-            attrs={
-                "start_index": relative_idx,
-                OP_ROLE_KEY: src_op.attr('op_role'),
-            },
-        )
-        if intermediate_var_0.shape != ref_shape:
-            intermediate_var_0.desc.set_shape(ref_shape)
+        c_embedding_op_desc = main_block.append_op(type='nop').desc
+        c_embedding_op_desc.set_type("c_embedding")
+        c_embedding_op_desc.set_input('Ids', [Ids_var.name])
+        c_embedding_op_desc.set_input('W', [Weight_var.name])
+        c_embedding_op_desc.set_output('Out', [Out_var.name])
+        c_embedding_op_desc._set_attr('start_index', relative_idx)
+        c_embedding_op_desc._set_attr(OP_ROLE_KEY, src_op.attr('op_role'))
+        c_embedding_op = main_block.ops[-1]
+        assert c_embedding_op.type == "c_embedding"
+        naive_copy_op_dist_attr_for_program(c_embedding_op, src_op, ctx)

         # use_model_parallel
         c_allreduce_sum_op = main_block.append_op(
             type='c_allreduce_sum',
-            inputs={'X': [intermediate_var_0]},
+            inputs={'X': [Out_var]},
             outputs={'Out': [Out_var]},
             attrs={
                 'ring_id': group.id,
@@ -549,49 +524,12 @@ def forward(ctx, *args, **kwargs):
         c_allreduce_sum_op._set_attr(
             'op_namescope', '/' + ParallelMode.TensorParallel
         )
-        if Out_var.shape != ref_shape:
-            Out_var.desc.set_shape(ref_shape)
-
-        # set dist op's dist_attr with serial op's dist_attr
-        # matmulv2
-        embedding_op_dist_attr = OperatorDistAttr()
-        embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
-        embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
-        embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
-        for input_varname in c_embedding_op.desc.input_arg_names():
-            input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
-            assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}"
-            embedding_op_dist_attr.set_input_dist_attr(
-                input_varname, input_dist_attr
-            )
-        output_varname = c_embedding_op.desc.output_arg_names()[0]
-        output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
-        assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
-        embedding_op_dist_attr.set_output_dist_attr(
-            output_varname, output_dist_attr
-        )
-        ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)
-        # allreduce
-        allreduce_op_dist_attr = OperatorDistAttr()
-        allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
-        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
-        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
-        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block._var_recursive(input_varname)
-            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
-            assert tensor_dist_attr is not None
-            allreduce_op_dist_attr.set_input_dist_attr(
-                input_varname, tensor_dist_attr
-            )
-        for output_varname in c_allreduce_sum_op.desc.output_arg_names():
-            output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
-            assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
-            allreduce_op_dist_attr.set_output_dist_attr(
-                output_varname, output_dist_attr
-            )
-        ctx.set_op_dist_attr_for_program(
-            c_allreduce_sum_op, allreduce_op_dist_attr
+        set_comm_op_dist_attr_for_program(
+            c_allreduce_sum_op,
+            op_dist_attr.process_mesh,
+            out_var_dist_attr,
+            ctx,
         )

         # param initialization sync
@@ -699,13 +637,6 @@ def backward(ctx, *args, **kwargs):
         per_part_size = Weight_var.shape[0]
         relative_idx = relative_idx * per_part_size

-        check_variable_and_dtype(
-            Out_grad,
-            'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
-            '_c_identity',
-        )
-
         c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
         c_embedding_grad_op_desc.set_type("c_embedding_grad")
         c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
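
Note on the rewritten forward path: building the c_embedding desc in place
(append a `nop`, retype it, and bind `Out` directly to `Out_var`) removes the
temporary output variable, so the trailing c_allreduce_sum now runs in place
on `Out_var` and the old infer_shape/ref_shape fix-ups become unnecessary.
The in-place allreduce is sound because each rank's c_embedding yields only a
partial result: ids outside the rank's weight shard contribute zero rows, and
summing across ranks reconstructs the full lookup. Below is a minimal NumPy
stand-in for that per-rank semantics; it is an illustration rather than code
from this patch, the names `ids`, `w_shard`, and `start_index` are
hypothetical, and it assumes the usual vocab-parallel c_embedding contract
(out-of-range ids produce zero rows).

    import numpy as np

    def c_embedding_local(ids, w_shard, start_index):
        # Per-rank lookup over one row-shard of the embedding table.
        # Rows for ids this rank does not own stay zero; an allreduce-sum
        # over all ranks then reconstructs the full lookup, since every
        # id is owned by exactly one shard.
        per_part, hidden = w_shard.shape
        rel = ids - start_index                   # shift ids into shard-local rows
        owned = (rel >= 0) & (rel < per_part)     # mask of ids this shard owns
        out = np.zeros(ids.shape + (hidden,), dtype=w_shard.dtype)
        out[owned] = w_shard[rel[owned]]          # gather owned rows, rest stay zero
        return out

    # Two ranks, vocab size 4 split 2+2: the partial results sum to the
    # full lookup, mirroring c_embedding followed by c_allreduce_sum.
    w = np.arange(8, dtype=np.float32).reshape(4, 2)
    ids = np.array([[0, 3], [2, 1]])
    full = c_embedding_local(ids, w[:2], 0) + c_embedding_local(ids, w[2:], 2)
    assert (full == w[ids]).all()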