Remove skip_first_batches support for StatefulDataloader and fix all the tests #3068

Merged 6 commits on Sep 2, 2024
19 changes: 6 additions & 13 deletions src/accelerate/data_loader.py
@@ -1160,11 +1160,11 @@ def prepare_data_loader(
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
Should not be used if the original dataloader is a `StatefulDataLoader`.
"""

def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.sampler = batch_sampler.sampler
self.skip_batches = skip_batches

def __iter__(self):
@@ -1182,7 +1182,8 @@ def __len__(self):

class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.
Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
`skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.

Args:
dataset (`torch.utils.data.dataset.Dataset`):
@@ -1211,11 +1212,9 @@ def __iter__(self):

def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
the original dataloader is a `StatefulDataLoader`.
"""
if is_torchdata_stateful_dataloader_available():
from torchdata.stateful_dataloader import StatefulDataLoader

state = PartialState()
if state.distributed_type == DistributedType.XLA:
device = dataloader.device
@@ -1259,7 +1258,6 @@ def skip_first_batches(dataloader, num_batches=0):
split_batches=dataloader.split_batches,
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
@@ -1276,17 +1274,12 @@ def skip_first_batches(dataloader, num_batches=0):
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(
dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs
)
elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

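For context, the path that remains supported after this change is calling `skip_first_batches` on a regular, non-stateful dataloader. A minimal sketch, not part of this PR (the toy dataset and expected batches mirror the test removed from `tests/test_data_loader.py` below):

```python
# Minimal sketch (not part of this PR): skip_first_batches on a plain
# PyTorch DataLoader, the case that remains supported after this change.
from torch.utils.data import DataLoader

from accelerate.data_loader import skip_first_batches

dataloader = DataLoader(list(range(16)), batch_size=4)

# Resume as if the first 2 batches had already been consumed.
resumed = skip_first_batches(dataloader, num_batches=2)
print([t.tolist() for t in resumed])  # [[8, 9, 10, 11], [12, 13, 14, 15]]
```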
106 changes: 46 additions & 60 deletions src/accelerate/test_utils/scripts/external_deps/test_pippy.py
@@ -12,23 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models import resnet34
from transformers import (
BertConfig,
BertForMaskedLM,
GPT2Config,
GPT2ForSequenceClassification,
T5Config,
T5ForConditionalGeneration,
)

from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.utils import DistributedType, send_to_device, set_seed
from accelerate.utils import DistributedType, set_seed


model_to_config = {
"t5": (T5ForConditionalGeneration, T5Config, 1024),
"bert": (BertForMaskedLM, BertConfig, 512),
"gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}
@@ -42,23 +38,19 @@ def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
# config_args["pad_token_id"] = 0
model_config = config(**config_args)
model = initializer(model_config)
return model, torch.randint(
low=0,
high=model_config.vocab_size,
size=(num_processes, seq_len),
device=device,
dtype=torch.int64,
requires_grad=False,
)
kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)
trace_input = torch.randint(size=(1, seq_len), **kwargs)
inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)
return model, trace_input, inference_inputs


def test_gpt2(batch_size: int = 2):
def test_bert(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = inputs.to("cuda")
inputs = inference_inputs.to("cuda")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
@@ -68,63 +60,57 @@ def test_gpt2(batch_size: int = 2):
assert output is not None, "Output was not generated in the last process!"


def test_t5(batch_size: int = 2):
def test_gpt2(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
model = prepare_pippy(
model,
no_split_module_classes=model._no_split_modules,
example_kwargs=example_inputs,
)
model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
# For inference args need to be a tuple
inputs = send_to_device(example_inputs, "cuda:0")
inputs = inference_inputs.to("cuda")
with torch.no_grad():
output = model(*inputs.values())
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"


def test_resnet(batch_size: int = 2):
set_seed(42)
state = PartialState()
model = resnet34()
input_tensor = torch.rand(batch_size, 3, 224, 224)
model = prepare_pippy(
model,
example_args=(input_tensor,),
)
inputs = send_to_device(input_tensor, "cuda:0")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was not generated on just the last process!"
else:
assert output is not None, "Output was not generated in the last process!"
# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
# def test_resnet(batch_size: int = 2):
# set_seed(42)
# state = PartialState()
# model = resnet34()
# input_tensor = torch.rand(1, 3, 224, 224)
# model = prepare_pippy(
# model,
# example_args=(input_tensor,),
# )
# inference_inputs = torch.rand(batch_size, 3, 224, 224)
# inputs = send_to_device(inference_inputs, "cuda:0")
# with torch.no_grad():
# output = model(inputs)
# # Zach: Check that we just grab the real outputs we need at the end
# if not state.is_last_process:
# assert output is None, "Output was not generated on just the last process!"
# else:
# assert output is not None, "Output was not generated in the last process!"


if __name__ == "__main__":
state = PartialState()
state.print("Testing pippy integration...")
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing T5...")
test_t5()
test_t5(1)
test_t5(3)
state.print("Testing CV model...")
test_resnet()
test_resnet(3)
try:
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing BERT...")
test_bert()
else:
print("Less than two GPUs found, not running tests!")
finally:
state.destroy_process_group()
else:
print("Less than two GPUs found, not running tests!")
7 changes: 0 additions & 7 deletions tests/test_data_loader.py
@@ -469,13 +469,6 @@ def test_skip_data_loader(self):
assert isinstance(dataloader, StatefulDataLoader)
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]

@require_torchdata_stateful_dataloader
def test_skip_first_batches(self):
dataloader = StatefulDataLoader(list(range(16)), batch_size=4)
new_dataloader = skip_first_batches(dataloader, num_batches=2)
assert isinstance(new_dataloader, StatefulDataLoader)
assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]

@require_torchdata_stateful_dataloader
def test_end_of_dataloader(self):
dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)
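The deleted test exercised `skip_first_batches` on a `StatefulDataLoader`, which this PR no longer supports. The intended replacement for stateful loaders is their own `state_dict`/`load_state_dict` checkpointing. A minimal sketch, assuming `torchdata` is installed (not part of this PR):

```python
# Minimal sketch: resume a StatefulDataLoader from its own saved state
# instead of wrapping it with skip_first_batches.
from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(list(range(16)), batch_size=4)
it = iter(dataloader)
next(it)  # [0, 1, 2, 3]
next(it)  # [4, 5, 6, 7]
state = dataloader.state_dict()  # checkpoint mid-epoch

# After a restart: rebuild the loader and restore the iteration state.
resumed = StatefulDataLoader(list(range(16)), batch_size=4)
resumed.load_state_dict(state)
print([t.tolist() for t in resumed])  # [[8, 9, 10, 11], [12, 13, 14, 15]]
```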
5 changes: 4 additions & 1 deletion tests/test_examples.py
@@ -19,7 +19,7 @@
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from unittest import mock, skip

import torch

@@ -261,6 +261,9 @@ def test_ddp_comm_hook(self):
testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
run_command(self.launch_args + testargs)

@skip(
reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
)
@require_multi_device
def test_distributed_inference_examples_stable_diffusion(self):
testargs = ["examples/inference/distributed/stable_diffusion.py"]