Device mismatch when using AMP with Pytorch DataParallel #503
This should be expected behavior. The PyTorch
I found the root cause: forward must be patched after the DataParallel(...) call (otherwise the patched method refers to the old model object rather than the dynamically created replica). Maybe some other patching approach would work fine with DP, but definitely not the straightforward way in https://github.com/NVIDIA/apex/blob/master/apex/amp/_initialize.py#L201. The workaround I found (assuming import torch and import apex.amp, with args.devices listing the GPU ids):

model = apex.amp.initialize(torch.nn.Sequential(model), opt_level='O2')[0]
model = torch.nn.DataParallel(model, device_ids=args.devices)
model.forward = (
    lambda *args,
           old_fwd=model.forward,
           input_caster=lambda tensor: tensor.to(apex.amp._amp_state.opt_properties.options['cast_model_type']),
           output_caster=lambda tensor: tensor.to(apex.amp._amp_state.opt_properties.options['cast_model_outputs'] if apex.amp._amp_state.opt_properties.options.get('cast_model_outputs') is not None else torch.float32),
           **kwargs:
        apex.amp._initialize.applier(
            old_fwd(*apex.amp._initialize.applier(args, input_caster),
                    **apex.amp._initialize.applier(kwargs, input_caster)),
            output_caster))
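For readability, the same idea can be written as a small helper that re-patches forward on the DataParallel wrapper after it has been constructed. This is only a sketch of the workaround above, not an Apex API: the helper name patch_forward_for_amp is made up, and it relies on the same private apex.amp._amp_state / apex.amp._initialize internals as the one-liner.

import torch
import apex.amp  # the private _amp_state and _initialize submodules are reached through this package


def patch_forward_for_amp(wrapper):
    # Hypothetical helper: re-apply amp's input/output casting on the DataParallel
    # wrapper itself, so the casting closure wraps the forward that creates the
    # per-GPU replicas instead of the original, pre-wrap module.
    options = apex.amp._amp_state.opt_properties.options
    input_type = options['cast_model_type']            # e.g. torch.float16 under O2
    output_type = options.get('cast_model_outputs')
    if output_type is None:
        output_type = torch.float32

    old_fwd = wrapper.forward
    input_caster = lambda t: t.to(input_type)
    output_caster = lambda t: t.to(output_type)

    def new_fwd(*fargs, **fkwargs):
        out = old_fwd(*apex.amp._initialize.applier(fargs, input_caster),
                      **apex.amp._initialize.applier(fkwargs, input_caster))
        return apex.amp._initialize.applier(out, output_caster)

    wrapper.forward = new_fwd
    return wrapper

# Usage sketch: initialize first, wrap with DataParallel, then patch.
# model = apex.amp.initialize(model, opt_level='O2')
# model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])
# model = patch_forward_for_amp(model)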
This is still very useful information and I haven't been ignoring it, but to be honest I'm probably not going to implement a fix in Apex soon. My absolute top priority right now is getting automatic mixed precision into PyTorch natively, which will eliminate all extension building/version matching issues. I'm taking care to ensure the native integration will support DistributedDataParallel, DataParallel, and model parallel usage. We are targeting the 1.5 release. If you are having problems with the current incarnation of Apex, my best advice is to wait for the PRs to be merged. Getting native mixed precision support as soon as possible is the best path forward for everyone IMO.
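For reference, a minimal sketch of the native usage pattern that eventually shipped as torch.cuda.amp (available from PyTorch 1.6). The MyModel class, the random data, and the hyperparameters below are placeholders, not taken from this issue.

import torch


class MyModel(torch.nn.Module):  # placeholder model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 4)

    # Decorating forward keeps autocast active inside the side threads that
    # nn.DataParallel spawns for each replica.
    @torch.cuda.amp.autocast()
    def forward(self, x):
        return self.fc(x)


model = torch.nn.DataParallel(MyModel().cuda())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # dynamic loss scaling, analogous to amp.scale_loss

for _ in range(10):  # placeholder training loop with random data
    inputs = torch.randn(8, 16, device='cuda')
    targets = torch.randint(0, 4, (8,), device='cuda')
    optimizer.zero_grad()
    outputs = model(inputs)  # forward runs in mixed precision via the decorator
    loss = torch.nn.functional.cross_entropy(outputs.float(), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Note that, unlike apex's O2/O3 levels, native autocast keeps the model weights in fp32 and only runs selected ops in fp16, which is relevant context for the memory question below.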
@mcarilli Btw, is O2/O3 supported in PyTorch autocast? Colleagues of mine mentioned that they saw no RAM decrease when using PyTorch core autocast, as if activations were still stored in fp32.
I'm running the following on 4 GPUs:
And I get the following error:
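(The actual script and traceback are not shown here. For illustration only, and not the reporter's code, the failing pattern named in the title, apex O2 initialization followed by nn.DataParallel, would look roughly like the sketch below; the model, data, and device ids are placeholders.)

import torch
import apex.amp

model = torch.nn.Linear(16, 4).cuda()                    # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# O2 casts the model's weights to fp16 and patches forward to cast inputs.
model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O2')

# Wrapping *after* initialize is where the reported device mismatch shows up:
# per the comment above, the patched forward closes over the original module
# on the default GPU, so the replicas DataParallel creates on the other GPUs
# are bypassed and inputs scattered to those GPUs hit GPU-0 weights.
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])

out = model(torch.randn(8, 16, device='cuda'))
loss = out.float().sum()
with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()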