Problem in training iterable dataset #6437

Open
21Timothy opened this issue Nov 20, 2023 · 5 comments

21Timothy commented Nov 20, 2023

Describe the bug

I am using PyTorch DDP (Distributed Data Parallel) to train my model. Since the data is too large to load into memory at once, I am using load_dataset to read the data as an iterable dataset. I have used datasets.distributed.split_dataset_by_node to distribute the dataset. However, I have noticed that this distribution results in different processes having different amounts of data to train on. As a result, when the earliest process finishes training and starts predicting on the test set, other processes are still training, causing the overall training speed to be very slow.

Steps to reproduce the bug

import datetime
import time

import torch
import datasets.distributed
from datasets import load_dataset


def train(args, model, device, train_loader, optimizer, criterion, epoch, length):
    model.train()
    idx_length = 0
    for batch_idx, data in enumerate(train_loader):
        s_time = time.time()  # per-batch timing (shows up as "coststime" in the log below)
        X = data['X']
        target = data['y'].reshape(-1, 28)
        X, target = X.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        idx_length += 1
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} Batch_idx: {} Process: {} [{}/{} ({:.0f}%)]\t'.format(
                epoch, batch_idx, torch.distributed.get_rank(),
                batch_idx * len(X), length / torch.distributed.get_world_size(),
                100. * batch_idx * len(X) * torch.distributed.get_world_size() / length))
            if args.dry_run:
                break
    print('Process %s length: %s time: %s' % (torch.distributed.get_rank(), idx_length, datetime.datetime.now()))

train_iterable_dataset = load_dataset("parquet", data_files=data_files, split="train", streaming=True)
test_iterable_dataset = load_dataset("parquet", data_files=data_files, split="test", streaming=True)
train_iterable_dataset = train_iterable_dataset.map(process_fn)
test_iterable_dataset = test_iterable_dataset.map(process_fn)
train_iterable_dataset = train_iterable_dataset.map(scale)
test_iterable_dataset = test_iterable_dataset.map(scale)

# Shard the streaming datasets across DDP ranks.
train_iterable_dataset = datasets.distributed.split_dataset_by_node(
    train_iterable_dataset, world_size=world_size, rank=local_rank).shuffle(seed=1234)
test_iterable_dataset = datasets.distributed.split_dataset_by_node(
    test_iterable_dataset, world_size=world_size, rank=local_rank).shuffle(seed=1234)
print(torch.distributed.get_rank(), train_iterable_dataset.n_shards, test_iterable_dataset.n_shards)

train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
if use_cuda:
    cuda_kwargs = {'num_workers': 3,  # ngpus_per_node
                   'pin_memory': True,
                   'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

# No DistributedSampler here: sharding across ranks is already handled by split_dataset_by_node.
train_loader = torch.utils.data.DataLoader(train_iterable_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_iterable_dataset, **test_kwargs)

for epoch in range(1, args.epochs + 1):
    start_time = time.time()
    train_iterable_dataset.set_epoch(epoch)
    test_iterable_dataset.set_epoch(epoch)
    train(args, model, device, train_loader, optimizer, criterion, epoch, train_len)
    test(args, model, device, criterion2, test_loader)

And here is part of the output:

Train Epoch: 1 Batch_idx: 5000 Process: 0 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5000 Process: 1 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5000 Process: 2 [320000/4710975.0 (7%)]	
Train Epoch: 1 Batch_idx: 5862 Process: 3 Data_length: 12 coststime: 0.04095172882080078
Train Epoch: 1 Batch_idx: 5862 Process: 0 Data_length: 3 coststime: 0.0751960277557373
Train Epoch: 1 Batch_idx: 5867 Process: 3 Data_length: 49 coststime: 0.0032558441162109375
Train Epoch: 1 Batch_idx: 5872 Process: 1 Data_length: 2 coststime: 0.022842884063720703
Train Epoch: 1 Batch_idx: 5876 Process: 3 Data_length: 63 coststime: 0.002694845199584961
Process 3 length: 5877 time: 2023-11-17 17:03:26.582317
Train epoch 1 costTime: 241.72063446044922s . Process 3 Start to test.
3 0 tensor(45508.8516, device='cuda:3')
3 100 tensor(45309.0469, device='cuda:3')
3 200 tensor(45675.3047, device='cuda:3')
3 300 tensor(45263.0273, device='cuda:3')
Process 3 Reduce metrics.
Train Epoch: 2 Batch_idx: 0 Process: 3 [0/4710975.0 (0%)]	
Train Epoch: 1 Batch_idx: 5882 Process: 1 Data_length: 63 coststime: 0.05185818672180176
Train Epoch: 1 Batch_idx: 5887 Process: 1 Data_length: 12 coststime: 0.006895303726196289
Process 1 length: 5888 time: 2023-11-17 17:20:48.578204
Train epoch 1 costTime: 1285.7279663085938s . Process 1 Start to test.
1 0 tensor(45265.9141, device='cuda:1')

Expected behavior

Each process should receive roughly the same number of examples, so that all ranks finish an epoch at about the same time. I'd like to know how to fix this problem.

Environment info

torch==2.0
datasets==2.14.0

21Timothy (Author) commented

Has anyone ever encountered this problem before?

lhoestq commented Nov 29, 2023

split_dataset_by_node doesn't give exactly the same number of examples to each node in the case of iterable datasets, though it tries to be as even as possible. In particular, if your dataset is sharded and the number of shards is a multiple of the number of workers, the shards are distributed evenly among the workers. If the shards don't contain the same number of examples, some workers may end up with more examples than others.

However, if you use a Dataset you'll end up with the same amount of data on every node, because we know the length of the dataset and can split it exactly where we want. Dataset objects also don't load the full dataset in memory; instead they memory-map Arrow files from disk.
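
For readers who want to see the shard arithmetic concretely, here is a minimal sketch (the parquet file pattern and the world_size/rank values are placeholders, not taken from this issue). It only inspects n_shards per rank: when the total shard count is a multiple of world_size, each rank iterates over whole shards, so ranks only see the same number of examples if the shards themselves are balanced.

import datasets.distributed
from datasets import load_dataset

world_size, rank = 4, 0  # in real code, take these from torch.distributed

ds = load_dataset("parquet", data_files={"train": "data/train-*.parquet"},
                  split="train", streaming=True)
print("total shards:", ds.n_shards)

per_rank = datasets.distributed.split_dataset_by_node(ds, rank=rank, world_size=world_size)
print("shards assigned to this rank:", per_rank.n_shards)

# Counting examples per rank is slow, but it makes the imbalance described in this issue visible:
# n_examples = sum(1 for _ in per_rank)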

21Timothy (Author) commented

> split_dataset_by_node doesn't give exactly the same number of examples to each node in the case of iterable datasets, though it tries to be as even as possible. In particular, if your dataset is sharded and the number of shards is a multiple of the number of workers, the shards are distributed evenly among the workers. If the shards don't contain the same number of examples, some workers may end up with more examples than others.
>
> However, if you use a Dataset you'll end up with the same amount of data on every node, because we know the length of the dataset and can split it exactly where we want. Dataset objects also don't load the full dataset in memory; instead they memory-map Arrow files from disk.

Thanks for your answer! I finally solved it by using torch.distributed.algorithms.join.Join. I think other beginners like me might run into the same question sooner or later.
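
For anyone landing here with the same problem, here is a minimal sketch of the kind of loop this refers to (not the exact code from this issue; the model is assumed to already be wrapped in DistributedDataParallel, which implements the Joinable interface that Join expects). Ranks that run out of data early stay inside the Join context and "shadow" the collective operations of the ranks still training, so DDP does not hang on uneven per-rank dataset sizes.

from torch.distributed.algorithms.join import Join

def train_one_epoch(model, train_loader, optimizer, criterion, device):
    # `model` must be a torch.nn.parallel.DistributedDataParallel instance (DDP is Joinable).
    model.train()
    with Join([model]):  # a plain optimizer does not go in this list
        for data in train_loader:
            X, target = data['X'].to(device), data['y'].to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), target)
            loss.backward()   # ranks that finished their shards shadow this allreduce
            optimizer.step()
    # when the context exits, all ranks have completed the epoch together

Only Joinable objects (such as the DDP-wrapped model, or a ZeroRedundancyOptimizer) are passed to Join; a regular optimizer simply keeps its step() inside the loop as usual.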

lhoestq commented Apr 29, 2024

Great! It might be worth having an example we can include in the docs for other people. Did you need anything other than the Join context manager used with the model and optimizer?

21Timothy (Author) commented

> Great! It might be worth having an example we can include in the docs for other people. Did you need anything other than the Join context manager used with the model and optimizer?

I don't think so. I had tried barrier() to solve the problem, but it didn't work. Maybe it's a tool for other situations.
