
[HybridParallel] Add ClipGradByGlobalNorm & check_finite_and_unscale in Dygraph #32354

Merged

merged 4 commits into PaddlePaddle:develop from clip_grad_amp on Apr 22, 2021

Conversation

@ForFishes (Member) commented Apr 19, 2021

PR types

New features

PR changes

Others

Describe

[HybridParallel] Add ClipGradByGlobalNorm & check_finite_and_unscale in Dygraph

Adds support for AMP and global gradient clipping under fleet. Introduces distributed_scaler to wrap the GradScaler. Example code:

    import paddle
    from paddle.distributed import fleet

    # The user's original single-card network.
    model = Bert(**config)
    optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
    if args.use_amp:
        # AMP requires a GradScaler.
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    # Add fleet support.
    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)
        # AMP under mp/pp needs the distributed_scaler wrapper, because
        # check_nan_inf must allreduce_max the found_inf flag across cards.
        if args.use_amp:
            scaler = fleet.distributed_scaler(scaler)

    # From here on, the code is identical to single-card training.
    loss = model(input)
    if args.use_amp:
        scaler.scale(loss).backward()
        scaler.minimize(optimizer, loss)
    else:
        loss.backward()
        optimizer.step()
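
For context, here is a minimal sketch (an illustration, not the PR's actual implementation) of why ClipGradByGlobalNorm needs a cross-rank reduction under model parallelism: each rank holds only a shard of the parameters, so the locally computed squared norms must be summed across the model-parallel group before taking the square root. The names `clip_grads_by_global_norm` and `mp_group` are hypothetical.

    # Hypothetical sketch: global-norm clipping across a model-parallel group.
    # `grads` is this rank's gradient shard; `mp_group` is an assumed
    # process-group handle, not the PR's exact API.
    import paddle
    import paddle.distributed as dist

    def clip_grads_by_global_norm(grads, clip_norm, mp_group=None):
        # Local sum of squared gradient entries over this rank's shard.
        local_sq = paddle.add_n([paddle.sum(paddle.square(g)) for g in grads])
        # Sum across model-parallel ranks so every rank computes the
        # same global norm.
        if mp_group is not None:
            dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=mp_group)
        global_norm = paddle.sqrt(local_sq)
        # Scale every gradient; equivalent to min(1, clip_norm / global_norm).
        coef = clip_norm / paddle.maximum(
            global_norm, paddle.to_tensor(clip_norm, dtype=global_norm.dtype))
        return [g * coef for g in grads]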

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.


@sandyhouse sandyhouse left a comment


LGTM

        self._found_inf)
    # allreduce_max found_inf in check_group
    if self._is_mp:
        self._found_inf = paddle.cast(self._found_inf, dtype="int64")


‘int8’ is enough.
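
To make the point under review concrete, here is a minimal sketch (hypothetical, assuming a `check_group` process-group handle and a `sync_found_inf` helper name) of the found_inf synchronization: a boolean flag cannot be all-reduced with MAX directly, so it is cast to an integer dtype, reduced across the group, and cast back. The diff casts to int64; the review suggests int8 is sufficient since the flag is only ever 0 or 1.

    # Hypothetical sketch of syncing found_inf across ranks; `check_group`
    # is an assumed process-group handle, not the PR's exact API.
    import paddle
    import paddle.distributed as dist

    def sync_found_inf(found_inf, check_group=None):
        # Cast bool -> int so allreduce MAX can run on it. The diff uses
        # int64; the reviewer notes int8 would be enough (values are 0/1).
        flag = paddle.cast(found_inf, dtype="int64")
        # MAX reduction: if any rank saw inf/nan, every rank ends up with 1.
        if check_group is not None:
            dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=check_group)
        return paddle.cast(flag, dtype="bool")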

@ForFishes ForFishes merged commit 7ea999f into PaddlePaddle:develop Apr 22, 2021
@ForFishes ForFishes deleted the clip_grad_amp branch April 22, 2021 05:49