DeepSpeed still, still changes metric states from fp32 to fp16 #2314

Closed
nickcolossal opened this issue Jan 17, 2024 · 1 comment · Fixed by #2379
Labels: bug / fix, help wanted, v1.2.x, v1.3.x


nickcolossal commented Jan 17, 2024

🐛 Bug

SpearmanCorrCoef does not work with the DeepSpeed strategy when `precision=16`. I believe this is caused by an unexpected conversion of the metric states from float32 to float16. Spearman logging works as expected when precision is set to 32. This looks like nearly the same issue as #1561, except that I'm using SpearmanCorrCoef instead of PearsonCorrCoef.
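The kind of float32 to float16 cast described above can be illustrated with plain PyTorch. This sketch is not torchmetrics code: `ToyStates` is a hypothetical module that keeps metric-like state as ordinary buffers. It shows that `nn.Module.half()` converts floating-point tensors to float16 while leaving integer tensors (such as an int64 counter) untouched, and how PyTorch's type promotion picks the result dtype when a float32 state meets float16 inputs, which is what DeepSpeed FP16 produces according to the log line "Model parameters and inputs will be cast to `float16`".

```python
import torch


class ToyStates(torch.nn.Module):
    """Hypothetical module holding metric-like state as buffers (not torchmetrics)."""

    def __init__(self) -> None:
        super().__init__()
        # float32 running statistic, analogous to PearsonCorrCoef.mean_x
        self.register_buffer("mean_x", torch.zeros(1))
        # int64 counter, analogous to MeanAbsoluteError.total
        self.register_buffer("total", torch.zeros(1, dtype=torch.int64))


toy = ToyStates()
toy.half()  # Module.half() casts only floating-point parameters and buffers
print(toy.mean_x.dtype)  # torch.float16
print(toy.total.dtype)   # torch.int64 (integer tensors are left alone)

# Type promotion when a float32 state meets half-precision inputs
state32 = torch.zeros(1)                         # float32 state
inputs16 = torch.zeros(64, dtype=torch.float16)  # fp16 inputs, as under DeepSpeed FP16
# Combining the float32 state with fp16 inputs promotes to float32 ...
print(torch.result_type(state32, inputs16))   # torch.float32
# ... but a value derived only from fp16 inputs stays float16
print(torch.result_type(inputs16, inputs16))  # torch.float16
```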

To Reproduce

See the code example from #1561, modified below.

Code sample
import os
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torchmetrics import PearsonCorrCoef, MeanAbsoluteError, SpearmanCorrCoef


# Dataset for testing


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class PlDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 32)
        self.metric = PearsonCorrCoef()
        self.metric2 = SpearmanCorrCoef()
        self.mae = MeanAbsoluteError()
        print("Before DeepSpeed initialization")
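        # NOTE: labelled "metric2.preds", but these two lines actually print
        # PearsonCorrCoef's mean_x; the labels are kept so they match the output below.
        print("self.metric2.preds", self.metric.mean_x)
        print("self.metric2.preds.dtype", self.metric.mean_x.dtype)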
        print("self.metric.mean_x", self.metric.mean_x)
        print("self.metric.mean_x.dtype", self.metric.mean_x.dtype)
        print("self.metric.mean_y", self.metric.mean_y)
        print("self.metric.dtype", self.metric.mean_y.dtype)  # label aside, this is mean_y's dtype
        print("self.metric.var_x", self.metric.var_x)
        print("self.metric.var_x.dtype", self.metric.var_x.dtype)
        print("self.metric.var_y", self.metric.var_y)
        print("self.metric.var_y.dtype", self.metric.var_y.dtype)
        print("self.metric.corr_xy", self.metric.corr_xy)
        print("self.metric.corr_xy.dtype", self.metric.corr_xy.dtype)
        print("self.metric.n_total", self.metric.n_total)

        print("self.mae.sum_abs_error.dtype", self.mae.sum_abs_error.dtype)
        print("self.mae.total.dtype", self.mae.total.dtype)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
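        pred = self.forward(batch)
        loss = pred.sum()  # reuse the forward pass instead of running the layer twice
        # Only the PearsonCorrCoef metric is updated here; self.metric2
        # (SpearmanCorrCoef) and self.mae are never updated in this sample.
        self.metric.update(torch.flatten(pred), torch.flatten(batch))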

        print("After DeepSpeed initialization")
        # NOTE: as in __init__, these "metric2.preds" lines print PearsonCorrCoef's mean_x.
        print("self.metric2.preds", self.metric.mean_x)
        print("self.metric2.preds.dtype", self.metric.mean_x.dtype)
        print("self.metric.mean_x", self.metric.mean_x)
        print("self.metric.mean_x.dtype", self.metric.mean_x.dtype)
        print("self.metric.mean_y", self.metric.mean_y)
        print("self.metric.dtype", self.metric.mean_y.dtype)  # label aside, this is mean_y's dtype
        print("self.metric.var_x", self.metric.var_x)
        print("self.metric.var_x.dtype", self.metric.var_x.dtype)
        print("self.metric.var_y", self.metric.var_y)
        print("self.metric.var_y.dtype", self.metric.var_y.dtype)
        print("self.metric.corr_xy", self.metric.corr_xy)
        print("self.metric.corr_xy.dtype", self.metric.corr_xy.dtype)
        print("self.metric.n_total", self.metric.n_total)

        print("self.mae.sum_abs_error.dtype", self.mae.sum_abs_error.dtype)
        print("self.mae.total.dtype", self.mae.total.dtype)

        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=".",
        # default_root_dir=os.getcwd(),
        limit_train_batches=10,
        num_sanity_val_steps=0,
        max_epochs=1,
        strategy="deepspeed_stage_1",
        accelerator="gpu",
        precision=16
    )
    trainer.fit(model, datamodule=PlDataModule())


run()
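A possible user-side workaround, not part of the reproduction above and untested here, is to upcast the inputs to float32 before updating the metrics, since DeepSpeed FP16 casts inputs to `float16`. `update_in_fp32` is a hypothetical helper, not a torchmetrics or Lightning API; whether it fully avoids the float16 states in this setup is an assumption.

```python
import torch


def update_in_fp32(update_fn, preds: torch.Tensor, target: torch.Tensor) -> None:
    """Upcast half-precision inputs to float32, then call a metric's update.

    Hypothetical helper. Inside the training_step above it could be used as:
        update_in_fp32(self.metric2.update, torch.flatten(pred), torch.flatten(batch))
    """
    # .float() returns the same tensor when it is already float32
    update_fn(preds.float(), target.float())


# Usage with a stand-in update function that records the dtypes it receives
seen = []
update_in_fp32(
    lambda p, t: seen.append((p.dtype, t.dtype)),
    torch.zeros(4, dtype=torch.float16),
    torch.zeros(4, dtype=torch.float16),
)
print(seen)  # [(torch.float32, torch.float32)]
```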
My output
Before DeepSpeed initialization
self.metric2.preds tensor([0.])
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.])
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.])
self.metric.dtype torch.float32
self.metric.var_x tensor([0.])
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([0.])
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([0.])
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([0.])
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
/opt/conda/envs/py3.9/lib/python3.9/site-packages/lightning_fabric/connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
[2024-01-17 20:10:01,703] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Before DeepSpeed initialization
self.metric2.preds tensor([0.])
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.])
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.])
self.metric.dtype torch.float32
self.metric.var_x tensor([0.])
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([0.])
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([0.])
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([0.])
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Before DeepSpeed initialization
self.metric2.preds tensor([0.])
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.])
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.])
self.metric.dtype torch.float32
self.metric.var_x tensor([0.])
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([0.])
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([0.])
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([0.])
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Before DeepSpeed initialization
self.metric2.preds tensor([0.])
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.])
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.])
self.metric.dtype torch.float32
self.metric.var_x tensor([0.])
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([0.])
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([0.])
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([0.])
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
[2024-01-17 20:10:06,324] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-01-17 20:10:06,346] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-01-17 20:10:06,346] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
initializing deepspeed distributed: GLOBAL_RANK: 1, MEMBER: 2/4
initializing deepspeed distributed: GLOBAL_RANK: 3, MEMBER: 4/4
initializing deepspeed distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`.
You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
[2024-01-17 20:10:12,542] [WARNING] [engine.py:1163:_do_optimizer_sanity_check] **** You are using ZeRO with an untested optimizer, proceed with caution *****

  | Name    | Type              | Params
----------------------------------------------
0 | layer   | Linear            | 1.1 K 
1 | metric  | PearsonCorrCoef   | 0     
2 | metric2 | SpearmanCorrCoef  | 0     
3 | mae     | MeanAbsoluteError | 0     
----------------------------------------------
1.1 K     Trainable params
0         Non-trainable params
1.1 K     Total params
0.004     Total estimated model params size (MB)
/opt/conda/envs/py3.9/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/opt/conda/envs/py3.9/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0:   0%|          | 0/8 [00:00<?, ?it/s]
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor(-0.1016, device='cuda:1', dtype=torch.float16)
self.metric2.preds.dtype torch.float16
self.metric.mean_x tensor(-0.1016, device='cuda:1', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor(0.1820, device='cuda:1', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([20.1875], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric2.preds tensor(0.0733, device='cuda:2', dtype=torch.float16)
self.metric2.preds tensor(0.0815, device='cuda:3', dtype=torch.float16)
self.metric2.preds.dtype torch.float16
self.metric2.preds.dtype torch.float16
self.metric.mean_x tensor(0.0815, device='cuda:3', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.var_y tensor([61.3438], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.mean_y tensor(0.1897, device='cuda:3', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.corr_xy tensor([4.2008], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.var_x tensor([25.6094], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.n_total tensor([64.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_y tensor([81.1875], device='cuda:3')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([7.5516], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([64.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.mean_x tensor(0.0733, device='cuda:2', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor(0.1203, device='cuda:2', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([18.3750], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([47.2188], device='cuda:2')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([-3.2508], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric2.preds tensor(0.0458, device='cuda:0', dtype=torch.float16)
self.metric2.preds.dtype torch.float16
self.metric.n_total tensor([64.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.mean_x tensor(0.0458, device='cuda:0', dtype=torch.float16)
self.metric.mean_x.dtype torch.float16
self.metric.mean_y tensor(0.2625, device='cuda:0', dtype=torch.float16)
self.metric.dtype torch.float16
self.metric.var_x tensor([29.6719], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([59.7812], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([-0.3287], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([64.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
/opt/conda/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1947: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)
  overflow_gpu = get_accelerator().ByteTensor([overflow])
/opt/conda/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1947: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)
  overflow_gpu = get_accelerator().ByteTensor([overflow])
/opt/conda/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1947: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)
  overflow_gpu = get_accelerator().ByteTensor([overflow])
/opt/conda/envs/py3.9/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1947: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)
  overflow_gpu = get_accelerator().ByteTensor([overflow])
Epoch 0:  12%|█▎        | 1/8 [00:01<00:08,  0.80it/s, v_num=6]
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([-0.0121], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0121], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric2.preds tensor([-0.0125], device='cuda:1')
self.metric2.preds tensor([0.0485], device='cuda:2')
self.metric2.preds.dtype torch.float32
self.metric2.preds.dtype torch.float32
self.metric.mean_y tensor([0.0196], device='cuda:3')
self.metric.dtype torch.float32
self.metric.var_x tensor([45.3866], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.mean_x tensor([-0.0125], device='cuda:1')
self.metric.mean_x tensor([0.0485], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric.mean_x.dtype torch.float32
After DeepSpeed initialization
self.metric.var_y tensor([141.2409], device='cuda:3')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([21.5493], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.mean_y tensor([0.1020], device='cuda:2')
self.metric.mean_y tensor([0.1844], device='cuda:1')
self.metric.dtype torch.float32
self.metric.dtype torch.float32
self.metric.n_total tensor([128.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_x tensor([51.5392], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.var_x tensor([45.8725], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([123.9526], device='cuda:2')
self.metric.var_y tensor([119.6558], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([10.3950], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.corr_xy tensor([5.2495], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([128.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.n_total tensor([128.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([-0.0033], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0033], device='cuda:0')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.2276], device='cuda:0')
self.metric.dtype torch.float32
self.metric.var_x tensor([50.6512], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([127.1647], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([5.6895], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([128.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0:  25%|██▌       | 2/8 [00:01<00:04,  1.48it/s, v_num=6]
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([-0.0159], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric2.preds tensor([0.0102], device='cuda:2')
self.metric2.preds tensor([-0.0651], device='cuda:1')
self.metric2.preds.dtype torch.float32
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0159], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric.mean_x tensor([-0.0651], device='cuda:1')
self.metric.mean_x tensor([0.0102], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.0450], device='cuda:3')
self.metric.dtype torch.float32
self.metric.var_x tensor([66.0074], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.mean_y tensor([0.1518], device='cuda:1')
self.metric.mean_y tensor([0.0957], device='cuda:2')
self.metric.dtype torch.float32
self.metric.dtype torch.float32
self.metric.var_y tensor([197.2956], device='cuda:3')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([23.7070], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.var_x tensor([81.1348], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric.var_x tensor([66.0174], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.n_total tensor([192.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
After DeepSpeed initialization
self.metric.var_y tensor([210.6820], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.var_y tensor([176.6633], device='cuda:2')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([9.1227], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.corr_xy tensor([7.5847], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([192.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.n_total tensor([192.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([0.0152], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.0152], device='cuda:0')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.1851], device='cuda:0')
self.metric.dtype torch.float32
self.metric.var_x tensor([74.1469], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([183.5430], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([7.0658], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([192.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0:  38%|███▊      | 3/8 [00:01<00:02,  2.07it/s, v_num=6]
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([0.0101], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.0101], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric2.preds tensor([0.0068], device='cuda:2')
self.metric2.preds.dtype torch.float32
self.metric.mean_y tensor([0.0358], device='cuda:3')
self.metric.dtype torch.float32
self.metric.var_x tensor([85.8434], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.mean_x tensor([0.0068], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric.var_y tensor([264.4183], device='cuda:3')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([24.4432], device='cuda:3')
self.metric.mean_y tensor([0.0268], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.dtype torch.float32
self.metric.n_total tensor([256.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_x tensor([90.8689], device='cuda:2')
self.metric.var_x.dtype torch.float32
After DeepSpeed initialization
After DeepSpeed initialization
self.metric.var_y tensor([239.1074], device='cuda:2')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([9.4654], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([256.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([-0.0979], device='cuda:1')
self.metric2.preds.dtype torch.float32
self.metric2.preds tensor([0.0190], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0979], device='cuda:1')
self.metric.mean_x.dtype torch.float32
self.metric.mean_x tensor([0.0190], device='cuda:0')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.1603], device='cuda:1')
self.metric.dtype torch.float32
self.metric.mean_y tensor([0.1295], device='cuda:0')
self.metric.dtype torch.float32
self.metric.var_x tensor([112.7737], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric.var_x tensor([102.9188], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([247.7011], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([15.7675], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([256.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_y tensor([284.1293], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([14.9629], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([256.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0:  50%|█████     | 4/8 [00:01<00:01,  2.59it/s, v_num=6]
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([-0.0174], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric2.preds tensor([-0.0322], device='cuda:2')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0174], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric2.preds tensor([0.0461], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.mean_y tensor([0.0089], device='cuda:3')
self.metric.dtype torch.float32
self.metric.mean_x tensor([-0.0322], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric.var_x tensor([110.2827], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.mean_x tensor([0.0461], device='cuda:0')
self.metric.mean_x.dtype torch.float32
self.metric.var_y tensor([332.8673], device='cuda:3')
self.metric.var_y.dtype torch.float32
self.metric.mean_y tensor([-0.0115], device='cuda:2')
self.metric.dtype torch.float32
self.metric.corr_xy tensor([24.1581], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.mean_y tensor([0.1119], device='cuda:0')
self.metric.dtype torch.float32
self.metric.n_total tensor([320.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_x tensor([123.9075], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.var_x tensor([134.4380], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([323.3611], device='cuda:2')
self.metric.var_y.dtype torch.float32
After DeepSpeed initialization
self.metric.var_y tensor([333.9321], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([18.5053], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.corr_xy tensor([22.7242], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([320.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.n_total tensor([320.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([-0.0986], device='cuda:1')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0986], device='cuda:1')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.1119], device='cuda:1')
self.metric.dtype torch.float32
self.metric.var_x tensor([130.6392], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([354.5143], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([14.0262], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([320.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0:  62%|██████▎   | 5/8 [00:01<00:01,  2.99it/s, v_num=6]
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([-0.0136], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric2.preds tensor([-0.0371], device='cuda:2')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0136], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric.mean_x tensor([-0.0371], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric2.preds tensor([-0.0572], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.mean_y tensor([-0.0470], device='cuda:2')
self.metric.dtype torch.float32
self.metric.mean_y tensor([-0.0164], device='cuda:3')
self.metric.dtype torch.float32
self.metric.var_x tensor([136.4565], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.mean_x tensor([-0.0572], device='cuda:0')
self.metric.mean_x.dtype torch.float32
self.metric.var_x tensor([131.5274], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([370.1378], device='cuda:2')
self.metric.var_y.dtype torch.float32
self.metric.var_y tensor([401.3443], device='cuda:3')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([19.2386], device='cuda:2')
self.metric.mean_y tensor([0.0692], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.dtype torch.float32
self.metric.n_total tensor([384.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.corr_xy tensor([19.0301], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.var_x tensor([172.9264], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.n_total tensor([384.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
After DeepSpeed initialization
self.metric.var_y tensor([393.5345], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([24.9133], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([384.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([-0.1306], device='cuda:1')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.1306], device='cuda:1')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([0.0875], device='cuda:1')
self.metric.dtype torch.float32
self.metric.var_x tensor([151.7127], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([426.1085], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([21.1041], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([384.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0:  75%|███████▌  | 6/8 [00:01<00:00,  3.35it/s, v_num=6]
After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([0.0208], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([0.0208], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric2.preds tensor([-0.1341], device='cuda:1')
self.metric2.preds.dtype torch.float32
self.metric.mean_y tensor([-0.0137], device='cuda:3')
self.metric.dtype torch.float32
self.metric2.preds tensor([-0.0103], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.var_x tensor([172.7051], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.mean_x tensor([-0.1341], device='cuda:1')
self.metric.mean_x.dtype torch.float32
self.metric.var_y tensor([453.7365], device='cuda:3')
self.metric.mean_x tensor([-0.0103], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.mean_x.dtype torch.float32
self.metric.corr_xy tensor([36.4642], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.mean_y tensor([0.0664], device='cuda:1')
self.metric.dtype torch.float32
self.metric.mean_y tensor([0.0815], device='cuda:0')
self.metric.n_total tensor([448.], device='cuda:3')
self.metric.dtype torch.float32
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_x tensor([175.6417], device='cuda:1')
self.metric.var_x.dtype torch.float32
After DeepSpeed initialization
self.metric.var_x tensor([208.9821], device='cuda:0')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([478.1778], device='cuda:1')
self.metric.var_y tensor([450.9269], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([27.2243], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.corr_xy tensor([27.1260], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([448.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.n_total tensor([448.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([-0.0849], device='cuda:2')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0849], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([-0.0463], device='cuda:2')
self.metric.dtype torch.float32
self.metric.var_x tensor([160.5845], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([428.3007], device='cuda:2')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([16.4307], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([448.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                               | 7/8 [00:01<00:00,  3.69it/s, v_num=6]After DeepSpeed initialization
After DeepSpeed initialization
After DeepSpeed initialization
self.metric2.preds tensor([-0.0889], device='cuda:3')
self.metric2.preds.dtype torch.float32
self.metric2.preds tensor([-0.1270], device='cuda:1')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.0889], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([-0.0383], device='cuda:3')
self.metric.dtype torch.float32
self.metric2.preds tensor([-0.0621], device='cuda:0')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.1270], device='cuda:1')
self.metric.mean_x.dtype torch.float32
self.metric.var_x tensor([243.8180], device='cuda:3')
self.metric.var_x.dtype torch.float32
self.metric.mean_x tensor([-0.0621], device='cuda:0')
self.metric.var_y tensor([536.7351], device='cuda:3')
self.metric.mean_x.dtype torch.float32
self.metric.var_y.dtype torch.float32
self.metric.mean_y tensor([0.0673], device='cuda:1')
self.metric.dtype torch.float32
self.metric.corr_xy tensor([36.8554], device='cuda:3')
self.metric.corr_xy.dtype torch.float32
self.metric.mean_y tensor([0.0712], device='cuda:0')
self.metric.dtype torch.float32
self.metric.n_total tensor([512.], device='cuda:3')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.var_x tensor([210.7141], device='cuda:1')
self.metric.var_x.dtype torch.float32
self.metric.var_x tensor([258.7119], device='cuda:0')
self.metric.var_x.dtype torch.float32
After DeepSpeed initialization
self.metric.var_y tensor([556.6730], device='cuda:1')
self.metric.var_y.dtype torch.float32
self.metric.var_y tensor([508.7692], device='cuda:0')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([20.3403], device='cuda:1')
self.metric.corr_xy.dtype torch.float32
self.metric.corr_xy tensor([21.7713], device='cuda:0')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([512.], device='cuda:1')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric.n_total tensor([512.], device='cuda:0')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
self.metric2.preds tensor([-0.1383], device='cuda:2')
self.metric2.preds.dtype torch.float32
self.metric.mean_x tensor([-0.1383], device='cuda:2')
self.metric.mean_x.dtype torch.float32
self.metric.mean_y tensor([-0.0263], device='cuda:2')
self.metric.dtype torch.float32
self.metric.var_x tensor([211.3848], device='cuda:2')
self.metric.var_x.dtype torch.float32
self.metric.var_y tensor([502.9049], device='cuda:2')
self.metric.var_y.dtype torch.float32
self.metric.corr_xy tensor([17.6688], device='cuda:2')
self.metric.corr_xy.dtype torch.float32
self.metric.n_total tensor([512.], device='cuda:2')
self.mae.sum_abs_error.dtype torch.float32
self.mae.total.dtype torch.int64
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.99it/s, v_num=6]/opt/conda/envs/py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
/opt/conda/envs/py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
/opt/conda/envs/py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
/opt/conda/envs/py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1879: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
  warnings.warn(
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:02<00:00,  3.91it/s, v_num=6]


Expected behavior

I expect the metric states to remain float32 when precision is set to 16, and the logged metric values to be non-zero.
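Keeping metric states in float32 matters beyond dtype bookkeeping: accumulator states lose information in float16. The sketch below is a stdlib-only illustration and does not use torchmetrics or DeepSpeed; `to_fp16` and `accumulate_count` are hypothetical helpers, and the counter only mimics a state such as `n_total`. It relies on Python's `struct` half-precision format (`"e"`) to round each intermediate result to IEEE 754 half precision.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def accumulate_count(n_updates: int, half: bool) -> float:
    """Simulate a counter state (similar in spirit to n_total) incremented by 1.0 per update.

    If `half` is True, the running total is rounded to float16 after every update,
    as it would be if the state were stored as a float16 tensor.
    """
    total = 0.0
    for _ in range(n_updates):
        total = total + 1.0
        if half:
            total = to_fp16(total)
    return total


print(accumulate_count(4096, half=False))  # float32/float64-style counter
print(accumulate_count(4096, half=True))   # float16-style counter
```

In float16 the spacing between representable values becomes 2 above 2048, so 2048 + 1 rounds back to 2048 (round half to even) and the half-precision counter stops growing at 2048, while the full-precision counter reaches 4096.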

Environment

Output of torch.utils.collect_env
Collecting environment information...
PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:33:10)  [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.10.0-27-cloud-amd64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA L4
GPU 1: NVIDIA L4
GPU 2: NVIDIA L4
GPU 3: NVIDIA L4

Nvidia driver version: 535.86.10
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             48
On-line CPU(s) list:                0-47
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          1
Stepping:                           7
BogoMIPS:                           4400.36
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          768 KiB (24 instances)
L1i cache:                          768 KiB (24 instances)
L2 cache:                           24 MiB (24 instances)
L3 cache:                           38.5 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT Host state unknown

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-lightning==2.1.3
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.1.2+cu121
[pip3] torch-optimizer==0.3.0
[pip3] torchaudio==2.1.2+cu121
[pip3] torchmetrics==1.3.0
[pip3] torchvision==0.16.2+cu121
[pip3] triton==2.1.0
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] pytorch-lightning         2.1.3                    pypi_0    pypi
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] torch                     2.1.2+cu121              pypi_0    pypi
[conda] torch-optimizer           0.3.0                    pypi_0    pypi
[conda] torchaudio                2.1.2+cu121              pypi_0    pypi
[conda] torchmetrics              1.3.0                    pypi_0    pypi
[conda] torchvision               0.16.2+cu121             pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

Additional context

All tests were run within a VSCode devcontainer.

@nickcolossal added the "bug / fix" (Something isn't working) and "help wanted" (Extra attention is needed) labels on Jan 17, 2024

Hi! Thanks for your contribution, great first issue!

3 participants: @Borda, @nickcolossal, and others