
Training is interrupted without error with MultiGPU #5604

Closed
skull3r7 opened this issue Jan 21, 2021 · 24 comments · Fixed by #6438
Assignees
Labels
bug (Something isn't working), distributed (Generic distributed-related topic), help wanted (Open to be worked on), priority: 0 (High priority task), waiting on author (Waiting on user action, correction, or update)

Comments

@skull3r7 commented Jan 21, 2021

🐛 Bug

The training is interrupted randomly in the middle of an epoch without errors. The console only says: Terminated.
The error does not always occur; when it does, it is usually between epochs 2-4. It is noticeable that processes are still running after the termination and the graphics cards are still in use by Python processes.

We train the PyTorch version of the ImageGPT model with Hugging Face transformers. It could also be a problem with Hugging Face; we are not sure.

Epoch 1: 29%|█▍ | 9413/32393 [3:28:18<8:28:33, 1.33s/it, loss=3.23, v_num=9]Terminated

Please reproduce using the BoringModel

Can't reproduce with the BoringModel.

Code

import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer, loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

# ImageGPT2LMHeadModel is our ImageGPT port used with Hugging Face transformers
# (definition omitted here)

class ImageGPT(pl.LightningModule):

    def __init__(self, learning_rate=learning_rate):
        super().__init__()
        self.gpt2 = ImageGPT2LMHeadModel(config=...)
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.gpt2(x, past_key_values=None)

    # ... (remaining hooks omitted)


logger = pl_loggers.TensorBoardLogger(save_dir="logs", name=name)

checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    filepath='../models',
    prefix='ImageGPT'
)

trainer = Trainer(
    accelerator='ddp',
    max_epochs=10,
    max_steps=None,
    precision=32,
    accumulate_grad_batches=1,
    gpus=[0, 1, 2],
    callbacks=[checkpoint_callback],
    logger=logger,
    gradient_clip_val=0.6
)

trainer.fit(model=model, datamodule=datamodule)

Expected behavior

The training is fully completed across all epochs.

Environment

  • CUDA:
    • GPU:
      • TITAN RTX
      • TITAN RTX
      • TITAN RTX
    • available: True
    • version: 10.2
  • Packages:
    • numpy: 1.19.4
    • pyTorch_debug: False
    • pyTorch_version: 1.7.1
    • pytorch-lightning: 1.1.2
    • transformers: 3.5.1
    • tqdm: 4.55.0
  • System:
    • OS: Linux, 64bit
    • processor: x86_64
    • python: 3.7.4
    • version: 86-Ubuntu SMP Fri Jan 17 17:24:28 UTC 2020

Additional context

We have tried the following to solve the problem:

  • set num_workers of the dataloaders to 0 or 1 (instead of 32-64)
  • went back to 32-bit precision
  • tried different learning rates
  • added gradient clipping
  • used the AdamW implementation from Hugging Face
@skull3r7 added the bug and help wanted labels on Jan 21, 2021
@angadkalra

This happens to me too, but I don't get "Terminated" at the end of the progress bar; it just stops. When I check the system with "top i", I see 4 Python processes running at 100% and 4 GPUs at about 90% capacity, but nothing is changing. Sometimes, after an hour or two, it randomly restarts, but the s/it jumps from 2.5 to about 85.

@edenlightning added the with code, distributed, and priority: 1 labels on Feb 9, 2021
@XiaomoWu commented Feb 16, 2021

@angadkalra +1. For me, switching to DP solves the problem, but at the cost of speed.

Update:
I tested on PL v1.1.5-v1.2.0rc1, and the problem persists. I'm sorry I can't upload a reproducible example at the moment, but I will probably do that later.

@carmocca (Contributor)

Hi everybody! If any of you can provide a snippet, that would be awesome. Otherwise we are working blind trying to fix it.

@carmocca added the waiting on author label on Feb 17, 2021
@angadkalra

> Hi everybody! If any of you can provide a snippet, that would be awesome. Otherwise we are working blind trying to fix it.

Code snippet? Is the code in the original post good enough, or do you need a Colab?

@carmocca (Contributor)

Need something I can run. Preferably a colab 👍

@XiaomoWu commented Feb 24, 2021

@carmocca Sorry for the late reproducible example! Please see below for a self-contained example that uses two GPUs for DDP. For me the code gets stuck at epoch 13, while the two GPUs keep busy at 100%. Switching to DP solves the problem.

Some hints

  • Removing the dropout layer in the shared_step(self, batch) method also solves the problem. But how could such a harmless dropout layer ruin the model?
  • I can reproduce the bug on both lightning 1.1.8 and 1.2.0


My settings:

  • OS: Ubuntu 20.10
  • pytorch: 1.7.1 + CUDA 11.0
  • pytorch-lightning == 1.1.8
  • GPUs: 2080 Ti x2
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.optim as optim

from torch import nn
from torch.utils.data import Dataset, DataLoader

# cuDNN settings (the random seed is set further below with torch.manual_seed)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Define Dataset
class CCDataset(Dataset):
    
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        target = self.data[idx,0].clone().detach()
        features = self.data[idx,1:].clone().detach()

        return target, features


# then define DataModule
class CCDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        
    # Dataset
    def setup(self, stage=None):
        # read the train and test dataset
        # targets_train = feather.read_feather('targets_train.feather')
        # targets_val = feather.read_feather('targets_test.feather')
        targets_train = torch.rand(12157,16)
        targets_val = torch.rand(12157,16)
        
        self.train_dataset = CCDataset(targets_train)
        self.val_dataset = CCDataset(targets_val)

    # DataLoader
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64, 
                          shuffle=True, drop_last=False, num_workers=2,
                          pin_memory=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64,
                          num_workers=2, pin_memory=True,
                          drop_last=False)


# Define the Model

class Model(pl.LightningModule):
    '''Mainly define the `*_step_end` methods
    '''
    def __init__(self):
        super().__init__()
        
        # dropout layers
        self.dropout_1 = nn.Dropout(0.5)
        
        # fc layers
        self.fc_1 = nn.Linear(15, 16)
        self.fc_2 = nn.Linear(16, 1)
        
    def shared_step(self, batch):
        t, x = batch
        x = self.dropout_1(F.relu(self.fc_1(x)))
        y = self.fc_2(x) # (N, 1)    
        
        return y.squeeze(), t
        
    # train step
    def training_step(self, batch, idx):
        y, t = self.shared_step(batch)
        return {'y': y, 't': t}
        
    # validation step
    def validation_step(self, batch, idx):
        y, t = self.shared_step(batch)
        return {'y': y, 't': t}
        
    # loss
    def mse_loss(self, y, t):
        return F.mse_loss(y, t)
        
    # def training_step_end
    def training_step_end(self, outputs):
        y = outputs['y']
        t = outputs['t']
        loss = self.mse_loss(y, t)
        
        return {'loss':loss}
    
    # def validation_step_end
    def validation_step_end(self, outputs):
        y = outputs['y']
        t = outputs['t']
        
        return {'y': y, 't': t}
        
    # validation step
    def validation_epoch_end(self, outputs):
        y = torch.cat([x['y'] for x in outputs])
        t = torch.cat([x['t'] for x in outputs])
        
        loss = self.mse_loss(y, t)
        rmse = torch.sqrt(loss)
        self.log('val_rmse', rmse, on_step=False)
        
    # optimizer
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-4)
        return optimizer  


# Run

# checkpoint
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    verbose=True,
    mode='min',
    monitor='val_rmse',
    save_top_k=1)

# trainer
trainer = pl.Trainer(gpus=[0,1], 
                     checkpoint_callback=checkpoint_callback, 
                     accelerator='ddp',
                     min_epochs=10,
                     max_epochs=500)

# set the random seed
torch.manual_seed(42)

# init model
model = Model()

# create datamodule
datamodule = CCDataModule()
datamodule.setup()

# train the model
trainer.fit(model, datamodule)

@carmocca removed the waiting on author label on Feb 24, 2021
@marrrcin

I observe the same behaviour as @angadkalra:

> This happens to me too, but I don't get "Terminated" at the end of the progress bar; it just stops. When I check the system with "top i", I see 4 Python processes running at 100% and 4 GPUs at about 90% capacity, but nothing is changing. Sometimes, after an hour or two, it randomly restarts, but the s/it jumps from 2.5 to about 85.

This is a pretty severe issue, basically making training impossible for non-toy models. In my case, removing the ModelCheckpoint callback makes the training proceed, but what's the point?

@edenlightning added the priority: 0 label and removed the priority: 1 label on Feb 24, 2021
@angadkalra

> I observe the same behaviour as @angadkalra:
>
> This happens to me too, but I don't get "Terminated" at the end of the progress bar; it just stops. When I check the system with "top i", I see 4 Python processes running at 100% and 4 GPUs at about 90% capacity, but nothing is changing. Sometimes, after an hour or two, it randomly restarts, but the s/it jumps from 2.5 to about 85.
>
> This is a pretty severe issue, basically making training impossible for non-toy models. In my case, removing the ModelCheckpoint callback makes the training proceed, but what's the point?

For me, I turn my VM off and on and it'll train fine for many epochs...

@SeanNaren (Contributor)

Regarding the reproducible code above, I think I can confirm that this is due to validation_epoch_end returning different losses on each process, leading to the model checkpoint logic getting out of sync across processes (something that the Lightning metrics package handles: https://pytorch-lightning.readthedocs.io/en/latest/extensions/metrics.html)

I know this is incorrect metrically, but as a test I'd suggest changing self.log('val_rmse', rmse, on_step=False) to self.log('val_rmse', rmse, on_step=False, sync_dist=True) and seeing if this works.

If this does, two longer-term solutions would be to define your own pl.Metric (I can assist here, should be easy), which will handle distributed syncing and gathering of all the outputs for you, or to use the self.all_gather function within validation_epoch_end, which will be trickier.
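
For reference, a minimal sketch of such a custom metric, assuming the pl.metrics API shipped with Lightning 1.1/1.2 (the RMSE class and its state names are illustrative, not taken from this thread):

import torch
import pytorch_lightning as pl

class RMSE(pl.metrics.Metric):
    def __init__(self):
        super().__init__()
        # metric states are reduced across DDP processes with dist_reduce_fx
        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target):
        self.sum_squared_error += torch.sum((preds - target) ** 2)
        self.total += target.numel()

    def compute(self):
        return torch.sqrt(self.sum_squared_error / self.total)

You would create it once in __init__, update it with the predictions and targets from each validation step, and log it; since the states are synced before compute(), every process sees the same val_rmse and ModelCheckpoint stays in sync.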

@XiaomoWu

> I know this is incorrect metrically, but as a test I'd suggest changing self.log('val_rmse', rmse, on_step=False) to self.log('val_rmse', rmse, on_step=False, sync_dist=True) and seeing if this works.

Hi @SeanNaren, I can confirm that by adding sync_dist=True the problem is gone. Would you please help me set up a custom pl.Metric? Thank you!

@marrrcin

I'm only calling self.log in the *_step methods without any custom metrics (just the loss), and adding sync_dist=True does not resolve the issue.

@SeanNaren (Contributor) commented Feb 25, 2021

> I know this is incorrect metrically, but as a test I'd suggest changing self.log('val_rmse', rmse, on_step=False) to self.log('val_rmse', rmse, on_step=False, sync_dist=True) and seeing if this works.
>
> Hi @SeanNaren, I can confirm that by adding sync_dist=True the problem is gone. Would you please help me set up a custom pl.Metric? Thank you!

Sure, I'll try to find some time. In the meantime, I tested the all_gather function, which works fine. Just change your validation_epoch_end to look like this:

    # validation step
    def validation_epoch_end(self, outputs):
        y = torch.cat([x['y'] for x in outputs])
        t = torch.cat([x['t'] for x in outputs])
        y = self.all_gather(y)
        t = self.all_gather(t)

        loss = self.mse_loss(y, t)
        rmse = torch.sqrt(loss)
        self.log('val_rmse', rmse, on_step=False)

Let me give some context as to why this fix works. When we run validation across distributed processes, each GPU/process gets a different set of data batches. This means the score calculated on every GPU is different unless we do some form of synchronisation between the processes. This can be done in one of two ways:

  1. Use the pl.Metric library, which has a few helpful functions to synchronize the metric across processes; even though each GPU sees different batches of data, the metric result is synced across processes (efficient and clean). If you do this, you won't need to override the epoch-end hook.
  2. Overriding validation_epoch_end doesn't automatically sync the batches across processes. This is because in many cases you don't want to do this, e.g. if you're using a pl.Metric to handle it instead. If you do want to sync, you can sync tensors and some Python primitives using self.all_gather as suggested above!

@marrrcin maybe the explanation above is insightful? If you're running into a different error and can get a reproducible script, let me know and I can help resolve it.

@PyTorchLightning/core-contributors @justusschock this has come up a few times as a bug. How about throwing a warning if the model checkpoint receives different monitor scores from different processes? I don't see too many cases where we'd want different processes giving different results to the model checkpoint. The check would involve gathering the monitor score across processes, but considering this happens only at saving time, it might be worth it.

@skull3r7 (Author) commented Feb 25, 2021

> Need something I can run. Preferably a colab 👍

Sorry for the long wait, but here is a Google Colab for the code from the original post PyTorchImageGPT

> I observe the same behaviour as @angadkalra:
>
> This happens to me too, but I don't get "Terminated" at the end of the progress bar; it just stops. When I check the system with "top i", I see 4 Python processes running at 100% and 4 GPUs at about 90% capacity, but nothing is changing. Sometimes, after an hour or two, it randomly restarts, but the s/it jumps from 2.5 to about 85.
>
> This is a pretty severe issue, basically making training impossible for non-toy models. In my case, removing the ModelCheckpoint callback makes the training proceed, but what's the point?

Good to know, I will try it without ModelCheckpoint, and also with sync_dist=True.

> For me, I turn my VM off and on and it'll train fine for many epochs...

No VMs installed on the training system.

@SeanNaren (Contributor)

> Need something I can run. Preferably a colab 👍
>
> Sorry for the long wait, but here is a Google Colab for the code from the original post PyTorchImageGPT
>
> I observe the same behaviour as @angadkalra:
>
> This happens to me too, but I don't get "Terminated" at the end of the progress bar; it just stops. When I check the system with "top i", I see 4 Python processes running at 100% and 4 GPUs at about 90% capacity, but nothing is changing. Sometimes, after an hour or two, it randomly restarts, but the s/it jumps from 2.5 to about 85.
>
> This is a pretty severe issue, basically making training impossible for non-toy models. In my case, removing the ModelCheckpoint callback makes the training proceed, but what's the point?
>
> Good to know, I will try it without ModelCheckpoint, and also with sync_dist=True.
>
> For me, I turn my VM off and on and it'll train fine for many epochs...
>
> No VMs installed on the training system.

I see iGPT, which may require find_unused_parameters to be set to True. In PyTorch Lightning this is set to False by default; could you try something like the below when training?

from pytorch_lightning.plugins import DDPPlugin

trainer = pl.Trainer(
    gpus=2,
    checkpoint_callback=checkpoint_callback,
    accelerator='ddp',
    plugins=DDPPlugin(find_unused_parameters=True),
)

@angadkalra

Have you guys tried updating to v1.2? I'm using the Metrics API now instead of returning a batch dict and everything is working fine, using the ModelCheckpoint callback too. I haven't had any freezing so far.

@edenlightning (Contributor)

@skull3r7 @marrrcin @XiaomoWu please let us know if that resolves the issue!

@edenlightning added the waiting on author label on Feb 26, 2021
@XiaomoWu

@edenlightning @SeanNaren sorry for the late reply. Yes, the problem is solved, either by using pl.Metric or by overriding validation_epoch_end with all_gather. Thank you so much for the help!

@marrrcin commented Mar 1, 2021

DDPPlugin(find_unused_parameters=True) seems to fix the problem for me too.

@SeanNaren (Contributor)

Thanks @marrrcin for coming back to us! We've got a discussion here about turning this flag back on, or exposing it on the Trainer: #6219

If anyone has thoughts, please leave them there. At this point it seems we should make find_unused_parameters=True the default and note in the docs that if you would like to see a speed-up with traditional DDP, you can turn this flag off.
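
For illustration, a minimal sketch of how opting back out could look from the user's side, assuming the existing DDPPlugin interface shown earlier in this thread (not a committed API):

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin

# opt back into the faster, traditional DDP behaviour once you know that
# every parameter receives a gradient in each forward/backward pass
trainer = pl.Trainer(
    gpus=2,
    accelerator='ddp',
    plugins=DDPPlugin(find_unused_parameters=False),
)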

@edenlightning (Contributor)

Actually, keeping this open to track setting the default to True, or any other solution you come up with.

@tchaton assigned tchaton and unassigned justusschock on Mar 3, 2021
@ifsheldon (Contributor)

I suggest:

  • put a reminder in the Logging section of the docs that logging with distributed training needs special handling and care, and then refer to the section Multi-GPU training / Prepare your code / Synchronize validation and test logging
  • add more examples of properly overriding hooks like validation_epoch_end to the section Multi-GPU training / Prepare your code / Synchronize validation and test logging (for instance, something like the sketch below)
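
As a rough illustration of the kind of example that section could include (a minimal sketch assuming the standard self.log API; the metric name is made up):

import torch.nn.functional as F

# inside your LightningModule
def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = F.mse_loss(self(x), y)
    # sync_dist=True reduces the logged value across all DDP processes,
    # so callbacks such as ModelCheckpoint monitor the same number everywhere
    self.log("val_loss", loss, sync_dist=True)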

@rednag commented Mar 25, 2021

I have the same issue, but running on only one GPU.

@chris-boson (Contributor)

> I know this is incorrect metrically, but as a test I'd suggest changing self.log('val_rmse', rmse, on_step=False) to self.log('val_rmse', rmse, on_step=False, sync_dist=True) and seeing if this works.

This suggestion didn't work for me, but setting rank_zero_only=True did the trick (see the sketch below).
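
For reference, a minimal sketch of that change, assuming a Lightning version where self.log accepts rank_zero_only (mirroring the validation_epoch_end example earlier in this thread):

def validation_epoch_end(self, outputs):
    y = torch.cat([x['y'] for x in outputs])
    t = torch.cat([x['t'] for x in outputs])
    rmse = torch.sqrt(self.mse_loss(y, t))
    # only rank 0 logs the value, so the checkpoint monitor never sees
    # diverging values across processes
    self.log('val_rmse', rmse, on_step=False, rank_zero_only=True)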

@ahmadikalkhorani commented Nov 3, 2022

In my case, the training stops after epoch 0 (right before the end of validation). Setting drop_last=True in the DataLoader solves the problem (see the sketch below)!
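
A minimal sketch of that workaround, assuming a standard DataLoader setup (the dataset and batch size are illustrative):

from torch.utils.data import DataLoader

# drop_last=True drops the final incomplete batch, so every DDP process
# runs the same number of steps and none is left waiting at a barrier
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=2, drop_last=True)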
