
PSNR not working with multiple GPUs and dataparallel #266

Closed
amonod-gpfw opened this issue May 28, 2021 · 5 comments · Fixed by #267
Assignees: justusschock
Labels: bug / fix (Something isn't working) · help wanted (Extra attention is needed) · Priority: Critical task/issue · v0.3.x

Comments


amonod-gpfw commented May 28, 2021

🐛 Bug

(This is a follow-up to Lightning issue #7257 and torchmetrics bugfix #214.)

Hi folks,

I have a problem when using Lightning, DataParallel, and torchmetrics. When training a small denoising network on MNIST on two GPUs with DP, using torchmetrics to compute the training and validation PSNR, I get the following error:

RuntimeError: All input tensors must be on the same device. Received cuda:0 and cuda:1

Code for reproduction

import argparse
from typing import Optional
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from torchmetrics import PSNR


def add_gaussian_noise(cleanTensor, sigma=.1):
	# adds Gaussian noise of standard deviation sigma;
	# returns both the noise tensor and the noisy tensor
	noiseTensor = torch.normal(mean=torch.zeros_like(cleanTensor), std=sigma)
	noisyTensor = cleanTensor + noiseTensor
	return noiseTensor, noisyTensor

class LitConvAE(pl.LightningModule):
	def __init__(self, hparams):
		super().__init__()
		# network architecture
		self.encoder = nn.Sequential(
			nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1),
			nn.ReLU(True),
			nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1))
		self.decoder = nn.Sequential(
			nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1),
			nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
			nn.ReLU(True))
		# Model-specific parameters
		self.learning_rate = hparams.learning_rate
		self.noise_sigma = hparams.noise_sigma
		# save all hyperparameters to a .yaml file
		self.save_hyperparameters()
		# metrics from torchmetrics
		self.train_psnr = PSNR(data_range=1, dim=(-2, -1))
		self.val_psnr = PSNR(data_range=1, dim=(-2, -1))

	def forward(self, x):
		# typically defines inference behavior
		x = self.encoder(x)
		x = self.decoder(x)
		return x

	def training_step(self, batch, batch_idx):
		# training behavior can be different from that of inference
		clean_batch, _ = batch  # do not care about image class for denoising
		noise_batch, noisy_batch = add_gaussian_noise(clean_batch, self.noise_sigma)
		denoised_batch = self.decoder(self.encoder(noisy_batch))
		loss = nn.functional.mse_loss(denoised_batch, clean_batch, reduction='sum')  # squared l2 norm
		self.log('train_loss', loss)  # log at each step
		self.train_psnr(denoised_batch, clean_batch)
		self.log('train_psnr', self.train_psnr, on_step=False, on_epoch=True)  # log at each end of epoch
		return loss

	def validation_step(self, batch, batch_idx):
		# validation mirrors the training denoising pipeline, but only logs the metric
		clean_batch, _ = batch  # do not care about image class for denoising
		noise_batch, noisy_batch = add_gaussian_noise(clean_batch, self.noise_sigma)
		denoised_batch = self.decoder(self.encoder(noisy_batch))
		self.val_psnr(denoised_batch, clean_batch)
		self.log('validation_psnr', self.val_psnr, on_step=False, on_epoch=True)

	def configure_optimizers(self):
		optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
		return optimizer

	@staticmethod
	def add_model_specific_args(parent_parser):
		# Model-specific arguments
		parser = parent_parser.add_argument_group("LitConvAE")
		parser.add_argument('--noise_sigma', type=float, default=.2, help='noise standard deviation (between 0. and 1.)')
		parser.add_argument('--learning_rate', type=float, default=1e-3, help='learning rate')
		return parent_parser


class MNISTDataModule(pl.LightningDataModule):
	def __init__(self, batch_size=32, dataset_dir='./', data_transform=transforms.ToTensor(), num_workers=4):
		super().__init__()
		self.batch_size = batch_size
		self.dataset_dir = dataset_dir
		self.data_transform = data_transform
		self.num_workers = num_workers

	def prepare_data(self):
		# Use this method for work that writes to disk or that must run
		# from a single process in distributed settings.
		# download=False assumes MNIST is already present at dataset_dir.
		datasets.MNIST(root=self.dataset_dir, train=True, download=False)
		datasets.MNIST(root=self.dataset_dir, train=False, download=False)

	def setup(self, stage: Optional[str] = None):
		# data operations you might want to perform on every GPU
		if stage == 'fit' or stage is None:
			dataset_full = datasets.MNIST(self.dataset_dir, train=True, transform=self.data_transform)
			train_split = 11 * len(dataset_full) // 12
			print(f"\ntrain / val split: {[train_split, len(dataset_full) - train_split]} \n")
			self.dataset_train, self.dataset_val = random_split(dataset_full, [train_split, len(dataset_full) - train_split])

		# Assign test dataset for use in dataloader(s)
		if stage == 'test' or stage is None:
			self.dataset_test = datasets.MNIST(self.dataset_dir, train=False, transform=self.data_transform)

	def train_dataloader(self):
		return DataLoader(self.dataset_train, batch_size=self.batch_size, num_workers=self.num_workers)

	def val_dataloader(self):
		return DataLoader(self.dataset_val, batch_size=self.batch_size, num_workers=self.num_workers)

	def test_dataloader(self):
		return DataLoader(self.dataset_test, batch_size=self.batch_size, num_workers=self.num_workers)


if __name__ == "__main__":
	# Parse arguments
	parser = argparse.ArgumentParser(description="Denoise MNIST with a convolutional Autoencoder")
	# DataModule-specific arguments
	parser.add_argument('--batch_size', type=int, default=32, help='number of examples per batch')
	parser.add_argument('--num_workers', type=int, default=4, help='number of separate processes for the DataLoader (default: 4)')
	# Trainer arguments
	parser.add_argument('--gpus', type=int, default=2, help='how many gpus to select')
	parser.add_argument('--accelerator', type=str, default='dp', help="which multi-GPU backend you want to use (default: 'dp')")
	parser.add_argument('--max_epochs', type=int, default=10, help='number of epochs you want the model to train for')
	# Program-specific arguments
	parser.add_argument('--data_dir', type=str, default=r'path/to_mnist_dir', help='path to the parent directory of MNIST torchvision dataset')

	# add model specific args
	parser = LitConvAE.add_model_specific_args(parser)

	hyperparams = parser.parse_args()

	# initialize the neural network
	model = LitConvAE(hyperparams)

	dataModule = MNISTDataModule(batch_size=hyperparams.batch_size, dataset_dir=hyperparams.data_dir, data_transform=transforms.ToTensor())

	trainer = pl.Trainer.from_argparse_args(hyperparams)

	# the training and validation loops happen here
	trainer.fit(model, dataModule)

Expected behavior

Training should run without errors.

Training and validation work fine with a single GPU and DP (although there is not much point in that configuration).

Environment

  • PyTorch version: 1.8.1
  • PyTorch Lightning version: 1.3.2
  • torchmetrics version: 0.3.2
  • OS: Linux
  • How you installed PyTorch (conda, pip, source): everything installed with conda
  • Python version: 3.7.10
  • CUDA/cuDNN version: 11
  • GPU models and configuration: 2× Titan Xp
amonod-gpfw added the bug / fix and help wanted labels on May 28, 2021
@github-actions

Hi! Thanks for your contribution, great first issue!

@amonod-gpfw (Author)

Here is the full traceback:

Traceback (most recent call last):
  File "lightning_dp_psnr_issue.py", line 139, in <module>
    trainer.fit(model, dataModule)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
    self.dispatch()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
    self.accelerator.start_training(self)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
    return self.run_train()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 869, in run_train
    self.train_loop.run_training_epoch()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 489, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 729, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 432, in optimizer_step
    using_lbfgs=is_lbfgs,
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/optim/optimizer.py", line 89, in wrapper
    return func(*args, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 724, in train_step_and_backward_closure
    split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 812, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 280, in training_step
    training_step_output = self.trainer.accelerator.training_step(args)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/dp.py", line 98, in training_step
    return self.model(*args, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/overrides/data_parallel.py", line 77, in forward
    output = super().forward(*inputs, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 46, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "lightning_dp_psnr_issue.py", line 52, in training_step
    self.train_psnr(denoised_batch, clean_batch)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torchmetrics/metric.py", line 180, in forward
    self._forward_cache = self.compute()
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torchmetrics/metric.py", line 251, in wrapped_func
    self._computed = compute(*args, **kwargs)
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/torchmetrics/regression/psnr.py", line 147, in compute
    sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error])
RuntimeError: All input tensors must be on the same device. Received cuda:1 and cuda:0

Exception ignored in: <function tqdm.__del__ at 0x7f4b657960e0>
Traceback (most recent call last):
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/tqdm/std.py", line 1145, in __del__
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/tqdm/std.py", line 1299, in close
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/tqdm/std.py", line 1492, in display
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/tqdm/std.py", line 1148, in __str__
  File "/home/amonod/miniconda/envs/pytorch181/lib/python3.7/site-packages/tqdm/std.py", line 1450, in format_dict
TypeError: cannot unpack non-iterable NoneType object

@justusschock (Member)

This happens because we only reduce metric states between distributed processes.

That said, I am not sure how we would implement this correctly with DP: because of the internal states, we cannot easily copy the metrics to each replica. You also cannot simply gather the results, since they are added to the states inside the DP-wrapped module, which has no access to the DP scope.

I'll need to think about this.

To unblock you, can you use DDP instead (which is the recommended backend anyway)?
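
To illustrate the point (a minimal sketch, not from this thread, assuming two visible CUDA devices and a made-up Counter module): nn.DataParallel re-replicates the wrapped module on every forward pass, so state accumulated inside a replica is discarded afterwards.

import torch
from torch import nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("total", torch.tensor(0.0))

    def forward(self, x):
        self.total = self.total + x.sum()  # update lands on a short-lived replica
        return x

model = nn.DataParallel(Counter().cuda(), device_ids=[0, 1])
model(torch.ones(8, 2, device="cuda:0"))
print(model.module.total)  # tensor(0., device='cuda:0') -- the replica's update was discarded

torchmetrics states suffer from the same replication problem, which is consistent with PSNR's sum_squared_error holding tensors from both cuda:0 and cuda:1 by the time compute() tries to concatenate them in the traceback above.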

justusschock self-assigned this on May 28, 2021
justusschock added the Priority: Critical label on May 28, 2021
@amonod-gpfw (Author)

Thanks for the answer. I am using DDP in the meantime, but I might still be interested in DP: I am porting some previous PyTorch code that uses DP to Lightning, and I want to make sure things still work the same way (it wasn't using torchmetrics before, though).

@SkafteNicki (Member)

Hi @amonod-gpfw, we cannot support updating metrics in the training_step method when running in DP mode, because the internal states are destroyed after each forward pass. Instead, the update should happen in the training_step_end method, which receives the results gathered from all devices. Something like this should work:

    def training_step(self, batch, batch_idx):
        data, target = batch
        preds = self(data)
        ...
        return {'loss' : loss, 'preds' : preds, 'target' : target}

    def training_step_end(self, outputs):
        # update and log
        self.metric(outputs['preds'], outputs['target'])
        self.log('metric', self.metric)

I am also going to add a note to the documentation for future reference.
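
Applied to the repro script above, the suggested pattern would look roughly like this (a sketch, untested; note that under DP the gathered loss can arrive with one entry per device, hence the .mean()):

    def training_step(self, batch, batch_idx):
        clean_batch, _ = batch  # image class is irrelevant for denoising
        _, noisy_batch = add_gaussian_noise(clean_batch, self.noise_sigma)
        denoised_batch = self.decoder(self.encoder(noisy_batch))
        loss = nn.functional.mse_loss(denoised_batch, clean_batch, reduction='sum')
        self.log('train_loss', loss)
        # hand the tensors to training_step_end instead of updating the metric here
        return {'loss': loss, 'preds': denoised_batch, 'target': clean_batch}

    def training_step_end(self, outputs):
        # runs on a single device after DP has gathered the per-GPU outputs
        self.train_psnr(outputs['preds'], outputs['target'])
        self.log('train_psnr', self.train_psnr, on_step=False, on_epoch=True)
        return outputs['loss'].mean()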
