Move optimizer creation after device placement for ddp backends. (#2904)
PhilJd authored Aug 12, 2020
1 parent 56396ab commit e3528af
Showing 3 changed files with 21 additions and 21 deletions.
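
The change itself is mechanical: in each of the three DDP backends, the optimizer/scheduler creation block moves from before device placement to after model.cuda(...). As a rough illustration of why the order matters, here is a minimal standalone sketch (not the Lightning code; the toy model and optimizer below are only illustrative). It follows the general PyTorch guidance to construct optimizers only after the model has been moved to its target device:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Old order (what the removed block effectively did): build the optimizer while
# the parameters still live on the CPU, then move the model afterwards.
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   model.cuda(0)

# New order (what this commit enforces): place the model on its device first,
# then create the optimizer, so its state refers to the parameters that will
# actually be trained.
if torch.cuda.is_available():
    model = model.cuda(0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

Doing this inside ddp_train, after torch.cuda.set_device and model.cuda, means each spawned process builds its optimizers against its own GPU copy of the model.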
14 changes: 7 additions & 7 deletions pytorch_lightning/accelerators/ddp2_backend.py
@@ -106,13 +106,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

-# CHOOSE OPTIMIZER
-# allow for lr schedulers as well
-optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
-self.trainer.optimizers = optimizers
-self.trainer.lr_schedulers = lr_schedulers
-self.trainer.optimizer_frequencies = optimizer_frequencies
-
# MODEL
# copy model to each gpu
if self.trainer.on_gpu:
@@ -130,6 +123,13 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

+# CHOOSE OPTIMIZER
+# allow for lr schedulers as well
+optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+self.trainer.optimizers = optimizers
+self.trainer.lr_schedulers = lr_schedulers
+self.trainer.optimizer_frequencies = optimizer_frequencies
+
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)

14 changes: 7 additions & 7 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -169,13 +169,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

-# CHOOSE OPTIMIZER
-# allow for lr schedulers as well
-optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
-self.trainer.optimizers = optimizers
-self.trainer.lr_schedulers = lr_schedulers
-self.trainer.optimizer_frequencies = optimizer_frequencies
-
# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)
@@ -197,6 +190,13 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

+# CHOOSE OPTIMIZER
+# allow for lr schedulers as well
+optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+self.trainer.optimizers = optimizers
+self.trainer.lr_schedulers = lr_schedulers
+self.trainer.optimizer_frequencies = optimizer_frequencies
+
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)

14 changes: 7 additions & 7 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -110,13 +110,6 @@ def ddp_train(self, process_idx, mp_queue, model):
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

-# CHOOSE OPTIMIZER
-# allow for lr schedulers as well
-optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
-self.trainer.optimizers = optimizers
-self.trainer.lr_schedulers = lr_schedulers
-self.trainer.optimizer_frequencies = optimizer_frequencies
-
# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)
@@ -129,6 +122,13 @@ def ddp_train(self, process_idx, mp_queue, model):
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

+# CHOOSE OPTIMIZER
+# allow for lr schedulers as well
+optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+self.trainer.optimizers = optimizers
+self.trainer.lr_schedulers = lr_schedulers
+self.trainer.optimizer_frequencies = optimizer_frequencies
+
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
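
Taken together, all three backends end up with the same ordering inside ddp_train. A condensed outline of the resulting flow, assembled from the diff context above (everything else in the real method is omitted, and the sketch is not runnable on its own since it belongs inside a Lightning accelerator class; the sync_batchnorm step is shown in the ddp and ddp_spawn hunks):

import torch

def ddp_train(self, process_idx, mp_queue, model):
    # ... DDP process-group registration happens above this point ...

    # convert BatchNorm layers before .cuda(), configure_apex and configure_ddp
    if self.trainer.sync_batchnorm:
        model = model.configure_sync_batchnorm(model)

    # device placement first
    if self.trainer.on_gpu:
        torch.cuda.set_device(self.trainer.root_gpu)
        model.cuda(self.trainer.root_gpu)

    # only then create optimizers and schedulers (the block this commit moves)
    optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
    self.trainer.optimizers = optimizers
    self.trainer.lr_schedulers = lr_schedulers
    self.trainer.optimizer_frequencies = optimizer_frequencies

    # set model properties before going into the wrapper
    self.trainer.copy_trainer_model_properties(model)

    # ... configure_apex / configure_ddp and the rest of training follow ...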
