[AutoParallel] rm infershape for dist_embedding (#59526)
* [AutoParallel] rm infershape for dist_embedding

* [AutoParallel] rm infershape for dist_embedding

* Update dist_embedding.py
zhaoyinglia authored Nov 30, 2023
1 parent 09e0e45 commit 855e51e
Showing 1 changed file with 18 additions and 87 deletions.
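
Note: the core of this change, sketched below, is that the forward lowering no longer creates an intermediate `c_embedding` output, calls `infer_shape`, or rebuilds the dist_attrs by hand. Instead it retypes an appended `nop` desc into a `c_embedding` that writes directly to `Out_var`, copies the serial op's dist_attr with `naive_copy_op_dist_attr_for_program`, and annotates the trailing `c_allreduce_sum` with `set_comm_op_dist_attr_for_program`. The sketch is illustrative only: the wrapper function name is made up here, the real code runs inside the dist op's `forward(ctx, *args, **kwargs)` with `main_block`, `src_op`, `op_dist_attr`, `group`, etc. supplied by the AutoParallel pass, and the allreduce attrs are abbreviated.

```python
# Condensed illustration of the new forward path (not a standalone script);
# all names and objects come from the surrounding AutoParallel machinery.

def lower_c_embedding(main_block, src_op, ctx, op_dist_attr,
                      Ids_var, Weight_var, Out_var, out_var_dist_attr,
                      relative_idx, group):
    # Retype a placeholder `nop` desc into c_embedding instead of creating an
    # intermediate variable and re-inferring its shape.
    c_embedding_op_desc = main_block.append_op(type='nop').desc
    c_embedding_op_desc.set_type("c_embedding")
    c_embedding_op_desc.set_input('Ids', [Ids_var.name])
    c_embedding_op_desc.set_input('W', [Weight_var.name])
    c_embedding_op_desc.set_output('Out', [Out_var.name])
    c_embedding_op_desc._set_attr('start_index', relative_idx)
    c_embedding_op_desc._set_attr(OP_ROLE_KEY, src_op.attr('op_role'))

    c_embedding_op = main_block.ops[-1]
    # The serial op's dist_attr is copied wholesale rather than rebuilt
    # attribute by attribute.
    naive_copy_op_dist_attr_for_program(c_embedding_op, src_op, ctx)

    # The allreduce now reads and writes Out_var in place (no intermediate),
    # and a single helper sets its dist_attr. Other attrs (stream, op role)
    # are elided here.
    c_allreduce_sum_op = main_block.append_op(
        type='c_allreduce_sum',
        inputs={'X': [Out_var]},
        outputs={'Out': [Out_var]},
        attrs={'ring_id': group.id},
    )
    set_comm_op_dist_attr_for_program(
        c_allreduce_sum_op,
        op_dist_attr.process_mesh,
        out_var_dist_attr,
        ctx,
    )
```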
@@ -49,10 +49,10 @@
ParallelMode,
get_default_distributed_operator_impl,
gradient_synchronization,
infer_shape,
naive_copy_op_dist_attr_for_program,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
set_comm_op_dist_attr_for_program,
update_op_dims_mapping,
)

@@ -188,6 +188,7 @@ def adopt_lookup_table_v1(ctx, main_block, src_op, Ids_var):
[-1] + list(Ids_var_dist_attr.dims_mapping),
Ids_var_dist_attr.process_mesh,
)
src_op._rename_input(Ids_var.name, intermediate_var_0.name)
op_dist_attr.del_input_dist_attr(Ids_var.name)
op_dist_attr.set_input_dist_attr(
intermediate_var_0.name, intermediate_var_0_dist_attr
@@ -496,48 +497,22 @@ def forward(ctx, *args, **kwargs):
assert out_tensor_dist_attr is not None
out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert out_var_dist_attr is not None
ref_shape = infer_shape(
main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr
)

intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(["c_embedding", 'tmp'])
),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient,
)
# set intermediate_var_0's dist_attr with Out_var's dist_attr
ctx.set_tensor_dist_attr_for_program(
intermediate_var_0, out_var_dist_attr
)

check_variable_and_dtype(
Out_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum',
)

c_embedding_op = main_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var], 'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={
"start_index": relative_idx,
OP_ROLE_KEY: src_op.attr('op_role'),
},
)
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
c_embedding_op_desc = main_block.append_op(type='nop').desc
c_embedding_op_desc.set_type("c_embedding")
c_embedding_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_op_desc.set_input('W', [Weight_var.name])
c_embedding_op_desc.set_output('Out', [Out_var.name])
c_embedding_op_desc._set_attr('start_index', relative_idx)
c_embedding_op_desc._set_attr(OP_ROLE_KEY, src_op.attr('op_role'))
c_embedding_op = main_block.ops[-1]
assert c_embedding_op.type == "c_embedding"
naive_copy_op_dist_attr_for_program(c_embedding_op, src_op, ctx)

# use_model_parallel
c_allreduce_sum_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
inputs={'X': [Out_var]},
outputs={'Out': [Out_var]},
attrs={
'ring_id': group.id,
@@ -549,49 +524,12 @@ def forward(ctx, *args, **kwargs):
c_allreduce_sum_op._set_attr(
'op_namescope', '/' + ParallelMode.TensorParallel
)
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)

# set dist op's dist_attr with serial op's dist_attr
# matmulv2
embedding_op_dist_attr = OperatorDistAttr()
embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
assert input_dist_attr is not None, f"dist_attr is {op_dist_attr}"
embedding_op_dist_attr.set_input_dist_attr(
input_varname, input_dist_attr
)
output_varname = c_embedding_op.desc.output_arg_names()[0]
output_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name)
assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
embedding_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(c_embedding_op, embedding_op_dist_attr)

# allreduce
allreduce_op_dist_attr = OperatorDistAttr()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block._var_recursive(input_varname)
tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
assert tensor_dist_attr is not None
allreduce_op_dist_attr.set_input_dist_attr(
input_varname, tensor_dist_attr
)
for output_varname in c_allreduce_sum_op.desc.output_arg_names():
output_dist_attr = op_dist_attr.get_output_dist_attr(output_varname)
assert output_dist_attr is not None, f"dist_attr is {op_dist_attr}"
allreduce_op_dist_attr.set_output_dist_attr(
output_varname, output_dist_attr
)
ctx.set_op_dist_attr_for_program(
c_allreduce_sum_op, allreduce_op_dist_attr
set_comm_op_dist_attr_for_program(
c_allreduce_sum_op,
op_dist_attr.process_mesh,
out_var_dist_attr,
ctx,
)

# param initialization sync
@@ -699,13 +637,6 @@ def backward(ctx, *args, **kwargs):
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size

check_variable_and_dtype(
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)

c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
