Device mismatch when using AMP with Pytorch DataParallel #503

Open
michaelklachko opened this issue Sep 25, 2019 · 4 comments

Comments

@michaelklachko

I'm running the following on 4 GPUs:

```python
import torch
import torch.nn as nn
from apex import amp

model = Resnet50()  # the reporter's own ResNet-50 (models/resnet.py)
model = model.cuda()
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()
optimizer = torch.optim.SGD(model.parameters(), 0.001)
model, optimizer = amp.initialize(model, optimizer, opt_level='O3', keep_batchnorm_fp32=False)
model = torch.nn.DataParallel(model)
```

And I get the following error:

```
Selected optimization level O3:  Pure FP16 training.
Defaults for this optimization level are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
lr: 0.1 wd 0.0001
Traceback (most recent call last):
  File "main.py", line 559, in <module>
    main()
  File "main.py", line 555, in main
    train(train_loader, val_loader, model, criterion, optimizer, start_epoch, best_acc, args)
  File "main.py", line 409, in train
    output = model(input_var, epoch=epoch, i=i)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/apex/amp/_initialize.py", line 194, in new_fwd
    **applier(kwargs, input_caster))
  File "/home/michael/noisynet/models/resnet.py", line 161, in forward
    x = self.conv1(x)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
```

@jianchao-li

jianchao-li commented Oct 2, 2019

This is expected behavior: PyTorch DataParallel only works with O1 (which casts by patching torch functions rather than the model's forward method). See #269 for more details.

@vadimkantorov

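I found the root cause: forward must be patched after the DataParallel(...) call, because otherwise the patched method refers to the old model object rather than to the dynamically created replica. DataParallel builds each replica by shallow-copying the module's attributes, so the patched forward, which closes over the original module's bound forward, ends up running the device-0 weights on device-1 input. Some other patching approach may work with DP, but certainly not the straightforward one in https://github.com/NVIDIA/apex/blob/master/apex/amp/_initialize.py#L201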
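To make the root cause concrete, here is a minimal pure-Python sketch of the two mechanisms involved, under simplifying assumptions: a forward patched as an instance attribute that closes over the original bound method (the pattern amp's forward patching follows), and replication by shallow-copying the instance `__dict__` (roughly what DataParallel's replicate step does). `Module`, `patch_forward` and `replicate` are hypothetical toy names, not PyTorch or Apex APIs, and "devices" are plain integers.

```python
# Toy model of the failure. Module, patch_forward and replicate are hypothetical
# stand-ins, not torch/apex APIs. "Devices" are plain ints.


class Module:
    def __init__(self, device):
        self.device = device  # where this object's "weights" live

    def forward(self, x_device):
        # Mimics cudnn's check that input and weight are on the same device.
        if x_device != self.device:
            raise RuntimeError(
                f"input on device {x_device} does not equal weight device {self.device}"
            )
        return x_device

    def __call__(self, *args):
        # Like nn.Module.__call__: looks up self.forward, so an instance
        # attribute named "forward" shadows the class method.
        return self.forward(*args)


def patch_forward(module):
    old_fwd = module.forward  # bound to *this* object, not to future copies

    def new_fwd(*args):
        return old_fwd(*args)  # input/output casting would go here

    module.forward = new_fwd


def replicate(module, device):
    # Shallow-copy the instance __dict__ (including a patched "forward"),
    # then give the replica its own device.
    replica = type(module).__new__(type(module))
    replica.__dict__ = module.__dict__.copy()
    replica.device = device
    return replica


# Patch first, replicate second: the replica's forward still calls the original.
patched = Module(device=0)
patch_forward(patched)
try:
    replicate(patched, device=1)(1)
    error = None
except RuntimeError as exc:
    error = str(exc)
print(error)  # input on device 1 does not equal weight device 0

# An unpatched module replicates cleanly.
clean_replica = replicate(Module(device=0), device=1)
print(clean_replica(1))  # 1
```

The toy error has the same shape as the real one: replica 1 receives device-1 input but its copied `forward` still dispatches to the device-0 original.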

The workaround I found:

```python
import torch
import apex

model = apex.amp.initialize(torch.nn.Sequential(model), opt_level='O2')[0]
model = torch.nn.DataParallel(model, device_ids=args.devices)
model.forward = lambda *args, old_fwd=model.forward, input_caster=lambda tensor: tensor.to(apex.amp._amp_state.opt_properties.options['cast_model_type']), output_caster=lambda tensor: tensor.to(apex.amp._amp_state.opt_properties.options['cast_model_outputs'] if apex.amp._amp_state.opt_properties.options.get('cast_model_outputs') is not None else torch.float32), **kwargs: apex.amp._initialize.applier(old_fwd(*apex.amp._initialize.applier(args, input_caster), **apex.amp._initialize.applier(kwargs, input_caster)), output_caster)
```
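The workaround's ordering can be checked with a similar toy. The `torch.nn.Sequential(model)` / `[0]` wrapping appears to let amp cast the weights while patching only the throwaway Sequential's forward, so the inner model comes back with fp16 weights and an untouched forward; the lambda then re-adds the casting on the DataParallel wrapper. In the sketch below `Tensor`, `Module` and `DataParallel` are hypothetical stand-ins that only track device and dtype, and the casters play the roles of `input_caster`/`output_caster` from the one-liner, not Apex's actual implementation. Because the patch lives on the wrapper, each replica is built from an unpatched module and runs its own forward.

```python
from dataclasses import dataclass, replace

# Hypothetical toy types (not torch/apex APIs): tensors only track device and dtype.


@dataclass(frozen=True)
class Tensor:
    device: int
    dtype: str


class Module:
    def __init__(self, device, dtype="float16"):
        self.device, self.dtype = device, dtype  # weights already cast to fp16

    def forward(self, x):
        if x.device != self.device:
            raise RuntimeError(f"device {x.device} does not equal {self.device}")
        if x.dtype != self.dtype:
            raise RuntimeError(f"dtype {x.dtype} does not match {self.dtype}")
        return Tensor(x.device, self.dtype)

    def __call__(self, *args):
        return self.forward(*args)


class DataParallel(Module):
    """Toy wrapper: builds one replica per device on every call."""

    def __init__(self, module, device_ids):
        self.module, self.device_ids = module, device_ids

    def forward(self, x):
        outputs = []
        for dev in self.device_ids:
            replica = type(self.module).__new__(type(self.module))
            replica.__dict__ = self.module.__dict__.copy()
            replica.device = dev
            outputs.append(replica(replace(x, device=dev)))  # "scatter" x to dev
        return outputs


def input_caster(t):
    return replace(t, dtype="float16")


def output_caster(t):
    return replace(t, dtype="float32")


model = DataParallel(Module(device=0), device_ids=[0, 1])
old_fwd = model.forward  # the wrapper's own bound forward


def new_fwd(*inputs):
    # Cast inputs, let the wrapper replicate and run, then cast outputs back.
    return [output_caster(o) for o in old_fwd(*(input_caster(t) for t in inputs))]


model.forward = new_fwd  # patch the wrapper, after wrapping
outputs = model(Tensor(device=0, dtype="float32"))
print(outputs)
# [Tensor(device=0, dtype='float32'), Tensor(device=1, dtype='float32')]
```

Both toy replicas succeed because neither inherits a patched `forward`; casting happens once, outside the replication step.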


@mcarilli
Contributor

mcarilli commented Dec 18, 2019

This is still very useful information and I haven't been ignoring it, but to be honest I'm probably not going to implement a fix in Apex soon. My absolute top priority right now is getting automatic mixed precision into PyTorch natively, which will eliminate all extension-building and version-matching issues. I'm taking care to ensure the native integration will support DistributedDataParallel, DataParallel, and model-parallel usage. We are targeting the 1.5 release:
pytorch/pytorch#25081
Gradient scaling and autocasting will be independently-usable components.
The gradient scaling PR is mature, awaiting final documentation review:
pytorch/pytorch#26512
The autocasting PR is about 3/4 done in terms of op coverage:
pytorch/pytorch#29552
Autocasting will likely be exposed via a context manager that can be used to locally enable/disable mixed precision for any desired regions of the model.
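The gradient-scaling component described here is dynamic loss scaling: multiply the loss by a large factor so small fp16 gradients do not underflow, skip any step whose gradients overflow to inf/NaN and shrink the factor, and grow it again after a run of clean steps. The native version eventually shipped as `torch.cuda.amp.GradScaler` (PyTorch 1.6), whose defaults are `init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5` and `growth_interval=2000`. Below is a pure-Python sketch of that policy under loud simplifications: `DynamicLossScaler` and `to_fp16` are hypothetical names, float16 is modeled only by its overflow threshold, and the gradient function stands in for `backward()` on the scaled loss. It is not GradScaler's code.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value


def to_fp16(x):
    # Crude float16 model: only overflow to +/-inf is simulated
    # (rounding and underflow are ignored in this sketch).
    return x if abs(x) <= FP16_MAX else math.copysign(math.inf, x)


class DynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling; defaults match GradScaler's."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, params, scaled_grads_fn, lr):
        """scaled_grads_fn(scale) stands in for backward() on loss * scale."""
        scaled = scaled_grads_fn(self.scale)
        if not all(math.isfinite(g) for g in scaled):
            # Overflow: skip the update and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return params, False
        grads = [g / self.scale for g in scaled]  # unscale before updating
        params = [p - lr * g for p, g in zip(params, grads)]
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A run of finite steps: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return params, True


true_grads = [1.0, 3.0]
scaler = DynamicLossScaler(growth_interval=2)
params = [0.0, 0.0]
history = []
for _ in range(5):
    params, applied = scaler.step(
        params, lambda s: [to_fp16(g * s) for g in true_grads], lr=0.1)
    history.append((applied, scaler.scale))
print(history)
# [(False, 32768.0), (False, 16384.0), (True, 16384.0), (True, 32768.0), (False, 16384.0)]
```

With a true gradient of 3.0, scales of 65536 and 32768 push it past the fp16 range, so the first two steps are skipped; the scale settles at 16384, grows back to 32768 after `growth_interval` clean steps, and backs off again on the next overflow.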

If you are having problems with the current incarnation of Apex, my best advice is to wait for the PRs to be merged. Getting native mixed precision support as soon as possible is the best path forward for everyone IMO.

@vadimkantorov

vadimkantorov commented Dec 30, 2020

@mcarilli By the way, are O2/O3 supported in PyTorch autocast? Colleagues of mine mentioned that they saw no memory decrease when using PyTorch core autocast, as if activations were still stored in fp32.


4 participants