
Commit

Fix VarType
co63oc committed Feb 4, 2024
1 parent 062b99e commit f07ab3c
Showing 5 changed files with 19 additions and 29 deletions.
24 changes: 8 additions & 16 deletions python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -56,17 +56,11 @@


 def set_op_dtype_to_fp16(op):
-    if (
-        op.has_attr('in_dtype')
-        and op.attr('in_dtype') == core.VarDesc.VarType.FP32
-    ):
+    if op.has_attr('in_dtype') and op.attr('in_dtype') == paddle.float32:
         op._set_attr('in_dtype', __target_dtype__)
-    if (
-        op.has_attr('out_dtype')
-        and op.attr('out_dtype') == core.VarDesc.VarType.FP32
-    ):
+    if op.has_attr('out_dtype') and op.attr('out_dtype') == paddle.float32:
         op._set_attr('out_dtype', __target_dtype__)
-    if op.has_attr('dtype') and op.attr('dtype') == core.VarDesc.VarType.FP32:
+    if op.has_attr('dtype') and op.attr('dtype') == paddle.float32:
         op._set_attr('dtype', __target_dtype__)


@@ -297,7 +291,7 @@ def set_var_to_fp16(self, var_name, block):
         ):
             return

-        if var.dtype == core.VarDesc.VarType.FP32:
+        if var.dtype == paddle.float32:
             var.desc.set_dtype(__target_dtype__)

     def resolute_cast_op(self, block):
@@ -445,9 +439,7 @@ def _insert_forward_cast_ops(
         num_cast_ops = 0

         for in_name in op.input_names:
-            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
-                op, in_name
-            ):
+            if src_dtype == paddle.float32 and _keep_fp32_input(op, in_name):
                 continue

             consume_op_attr = dist_context.get_op_dist_attr_for_program(op)
@@ -692,7 +684,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):

 def _split_grads(params_grads):
     grads = [g for _, g in params_grads]
-    fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
+    fp32_grads = [g for g in grads if g.dtype == paddle.float32]
     fp16_grads = [g for g in grads if g.dtype == __target_dtype__]
     assert len(fp32_grads) + len(fp16_grads) == len(
         grads
@@ -809,9 +801,9 @@ def is_initialization_op(op):
                     'dtype'
                 ), f"initialization op is supported to has dtype attribute but got {str(op)}."
                 out_var = startup_program.global_block().var(output_name)
-                if out_var.dtype == core.VarDesc.VarType.FP32:
+                if out_var.dtype == paddle.float32:
                     out_var.desc.set_dtype(__target_dtype__)
-                if op.attr('dtype') == core.VarDesc.VarType.FP32:
+                if op.attr('dtype') == paddle.float32:
                     op._set_attr('dtype', __target_dtype__)


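
Note on the rewrite above: in the Paddle 2.x API the public names paddle.float32, paddle.float16 and paddle.bfloat16 are aliases of the matching core.VarDesc.VarType values, so every comparison this commit touches keeps its old result. A minimal sketch of that assumption (not part of the commit; it presumes the paddle.base layout used on this branch):

import paddle
from paddle.base import core

# The public aliases and the framework enum compare equal, which is what makes
# the core.VarDesc.VarType.FP32 -> paddle.float32 rewrite behavior-preserving.
assert paddle.float32 == core.VarDesc.VarType.FP32
assert paddle.float16 == core.VarDesc.VarType.FP16
assert paddle.bfloat16 == core.VarDesc.VarType.BF16
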
7 changes: 4 additions & 3 deletions python/paddle/distributed/passes/auto_parallel_master_grad.py
@@ -17,6 +17,7 @@
 from collections import OrderedDict
 from typing import List, Tuple

+import paddle
 from paddle.base import Variable
 from paddle.distributed.auto_parallel.static.utils import (
     is_backward_op,
@@ -118,10 +119,10 @@ def _add_cast_op(self, cur_block, grad_names: List[str], dist_context):
         for grad_name, idx in reversed(grad_first_ids.items()):
             grad_var = cur_block.var(grad_name)
             if (
-                grad_var.dtype == core.VarDesc.VarType.FP16
-                or grad_var.dtype == core.VarDesc.VarType.BF16
+                grad_var.dtype == paddle.float16
+                or grad_var.dtype == paddle.bfloat16
             ):
-                is_fp16 = grad_var.dtype == core.VarDesc.VarType.FP16
+                is_fp16 = grad_var.dtype == paddle.float16
                 producer_op = cur_block.ops[idx]
                 producer_op_dist_attr = (
                     dist_context.get_op_dist_attr_for_program(producer_op)
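
For a standalone view of the branch rewritten above, here is a hypothetical helper (not part of the pass) that reproduces the same dtype test with the public aliases:

import paddle

def needs_master_grad(dtype):
    # Mirrors the check in _add_cast_op: only fp16/bf16 gradients get a
    # float32 master copy; fp32 gradients are left alone.
    return dtype in (paddle.float16, paddle.bfloat16)

assert needs_master_grad(paddle.float16)
assert needs_master_grad(paddle.bfloat16)
assert not needs_master_grad(paddle.float32)
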
5 changes: 1 addition & 4 deletions python/paddle/hapi/model.py
@@ -542,10 +542,7 @@ def _run(self, inputs, labels=None):
                 # train and test may take different arguments
                 if inputs[idx] is not None:
                     feed[n] = inputs[idx]
-                if (
-                    self._amp_level == 'O2'
-                    and input_dtypes[idx] == core.VarDesc.VarType.FP16
-                ):
+                if self._amp_level == 'O2' and input_dtypes[idx] == paddle.float16:
                     if isinstance(feed[n], core.LoDTensor):
                         feed[n] = feed[n]._as_type(core.VarDesc.VarType.FP16)
                     elif isinstance(feed[n], np.array):
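
The hapi change keeps the same AMP 'O2' feed handling, only spelled with paddle.float16. A rough, hypothetical reduction of that rule for numpy inputs (the LoDTensor branch is omitted):

import numpy as np
import paddle

def maybe_cast_feed(value, declared_dtype, amp_level):
    # Under AMP level 'O2', a feed whose declared input dtype is float16 is
    # cast to float16 before being handed to the executor.
    if (
        amp_level == 'O2'
        and declared_dtype == paddle.float16
        and isinstance(value, np.ndarray)
    ):
        return value.astype('float16')
    return value

x = np.ones((2, 3), dtype='float32')
print(maybe_cast_feed(x, paddle.float16, 'O2').dtype)  # float16
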
6 changes: 3 additions & 3 deletions python/paddle/incubate/distributed/models/moe/grad_clip.py
@@ -117,9 +117,9 @@ def get_l2_norm_pow(params_grads, sum_dtype=None):
             merge_grad = clip.merge_selected_rows(g)
             merge_grad = clip.get_tensor_from_selected_rows(merge_grad)
         sum_square = _squared_l2_norm(merge_grad)
-        if sum_square.dtype == core.VarDesc.VarType.FP16:
+        if sum_square.dtype == paddle.float16:
             sum_square_list_fp16.append(sum_square)
-        elif sum_square.dtype == core.VarDesc.VarType.FP32:
+        elif sum_square.dtype == paddle.float32:
             sum_square_list_fp32.append(sum_square)
         else:
             sum_square_list.append(sum_square)
@@ -222,7 +222,7 @@ def _dygraph_clip(self, params_grads):
             # TODO(wangxi): use inplace elementwise_mul
             clip_input = (
                 clip_var.astype('float16')
-                if g.dtype == core.VarDesc.VarType.FP16
+                if g.dtype == paddle.float16
                 else clip_var
             )
             new_grad = paddle.multiply(x=g, y=clip_input)
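
For context on the grad-clip hunks, a hypothetical sketch of the per-dtype bucketing that get_l2_norm_pow performs (SelectedRows handling and the norm computation are left out):

import paddle

def bucket_by_dtype(sum_squares):
    # Squared-norm tensors are grouped per dtype so fp16 partial sums can be
    # accumulated at higher precision before the final global norm.
    fp16_list, fp32_list, other = [], [], []
    for sq in sum_squares:
        if sq.dtype == paddle.float16:
            fp16_list.append(sq)
        elif sq.dtype == paddle.float32:
            fp32_list.append(sq)
        else:
            other.append(sq)
    return fp16_list, fp32_list, other

sums = [paddle.to_tensor([1.0], dtype='float16'), paddle.to_tensor([2.0])]
fp16_list, fp32_list, _ = bucket_by_dtype(sums)
print(len(fp16_list), len(fp32_list))  # 1 1
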
6 changes: 3 additions & 3 deletions python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -247,14 +247,14 @@ def _get_parameter(self, name, scope=None):
         assert master_param is not None

         master_param_t = scope.find_var(master_param).get_tensor()
-        assert master_param_t._dtype() == core.VarDesc.VarType.FP32
+        assert master_param_t._dtype() == paddle.float32

         param_t = scope.find_var(name).get_tensor()
-        if param_t._dtype() == core.VarDesc.VarType.FP32:
+        if param_t._dtype() == paddle.float32:
             assert param_t._ptr() == master_param_t._ptr()
             return param_t, None
         else:
-            assert param_t._dtype() == core.VarDesc.VarType.FP16
+            assert param_t._dtype() == paddle.float16
             assert param_t.shape() == master_param_t.shape()
             return param_t, master_param_t

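
The fused-LAMB hunk only renames the dtype constants in the assertions of _get_parameter; the invariant is unchanged: a float32 parameter shares storage with its master weight, while a float16 parameter is paired with a separate float32 master copy. An illustrative restatement (hypothetical function, not the optimizer's API):

import paddle

def classify_param(param_dtype, master_dtype):
    # Restates the asserts above: the master weight is always float32; an fp32
    # parameter is its own master, while an fp16 parameter keeps a separate copy.
    assert master_dtype == paddle.float32
    if param_dtype == paddle.float32:
        return 'shared'
    assert param_dtype == paddle.float16
    return 'separate'

print(classify_param(paddle.float32, paddle.float32))  # shared
print(classify_param(paddle.float16, paddle.float32))  # separate
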
