
mock and reuse
v-chen_data committed Dec 1, 2024
1 parent 12955d1 commit e0a87af
Showing 1 changed file with 18 additions and 69 deletions.
87 changes: 18 additions & 69 deletions tests/test_events.py
@@ -63,62 +63,6 @@ def event_counter_callback():
     return EventCounterCallback()
 
 
-@pytest.fixture
-def trainer(
-    model,
-    optimizer,
-    train_dataset,
-    evaluator1,
-    evaluator2,
-    event_counter_callback,
-    request,
-):
-    # extract parameters from the test function
-    params = request.param
-    precision = params.get('precision', 'fp32')
-    max_duration = params.get('max_duration', '1ep')
-    save_interval = params.get('save_interval', '1ep')
-    device = params.get('device', 'cpu')
-    deepspeed_zero_stage = params.get('deepspeed_zero_stage', None)
-    use_fsdp = params.get('use_fsdp', False)
-
-    deepspeed_config = None
-    if deepspeed_zero_stage:
-        deepspeed_config = {'zero_optimization': {'stage': deepspeed_zero_stage}}
-
-    parallelism_config = None
-    if use_fsdp:
-        parallelism_config = {
-            'fsdp': {
-                'sharding_strategy': 'FULL_SHARD',
-                'mixed_precision': 'PURE',
-                'backward_prefetch': 'BACKWARD_PRE',
-            },
-        }
-
-    return Trainer(
-        model=model,
-        train_dataloader=DataLoader(
-            dataset=train_dataset,
-            batch_size=4,
-            sampler=dist.get_sampler(train_dataset),
-            num_workers=0,
-        ),
-        eval_dataloader=(evaluator1, evaluator2),
-        device_train_microbatch_size=2,
-        precision=precision,
-        train_subset_num_batches=1,
-        eval_subset_num_batches=1,
-        max_duration=max_duration,
-        save_interval=save_interval,
-        optimizers=optimizer,
-        callbacks=[event_counter_callback],
-        device=device,
-        deepspeed_config=deepspeed_config,
-        parallelism_config=parallelism_config,
-    )
-
-
 @pytest.mark.parametrize('event', list(Event))
 def test_event_values(event: Event):
     assert event.name.lower() == event.value
@@ -177,9 +121,22 @@ def test_event_calls(
     event_counter_callback,
 ):
     with patch.object(Trainer, 'save_checkpoint', return_value=None):
-        # mock forward and backward to speed up
-        with patch.object(model, 'forward', return_value=torch.tensor(0.0)) as mock_forward, \
-                patch.object(model, 'backward', return_value=None) as mock_backward:
+        # mock forward method
+        with patch.object(model, 'forward', return_value=torch.tensor(0.0)):
+            # initialize the Trainer with the current parameters
+            deepspeed_config = None
+            if deepspeed_zero_stage:
+                deepspeed_config = {'zero_optimization': {'stage': deepspeed_zero_stage}}
+
+            parallelism_config = None
+            if use_fsdp:
+                parallelism_config = {
+                    'fsdp': {
+                        'sharding_strategy': 'FULL_SHARD',
+                        'mixed_precision': 'PURE',
+                        'backward_prefetch': 'BACKWARD_PRE',
+                    },
+                }
 
             trainer_instance = Trainer(
                 model=model,
@@ -199,16 +156,8 @@
                 optimizers=optimizer,
                 callbacks=[event_counter_callback],
                 device=device,
-                deepspeed_config={'zero_optimization': {
-                    'stage': deepspeed_zero_stage,
-                }} if deepspeed_zero_stage else None,
-                parallelism_config={
-                    'fsdp': {
-                        'sharding_strategy': 'FULL_SHARD',
-                        'mixed_precision': 'PURE',
-                        'backward_prefetch': 'BACKWARD_PRE',
-                    },
-                } if use_fsdp else None,
+                deepspeed_config=deepspeed_config,
+                parallelism_config=parallelism_config,
             )
 
             trainer_instance.fit()
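
The new test body relies on unittest.mock.patch.object to stub out model.forward, so Trainer.fit() can drive every event without paying for a real forward pass. A minimal, self-contained sketch of that pattern (the ToyModel class is a hypothetical stand-in, not part of this commit):

from unittest.mock import patch

import torch


class ToyModel:

    def forward(self, batch):
        # Stand-in for an expensive real forward pass.
        raise RuntimeError('the real forward should never run under the mock')


model = ToyModel()

# patch.object swaps the attribute for the duration of the with-block and
# restores the original afterwards; every call returns the canned tensor.
with patch.object(model, 'forward', return_value=torch.tensor(0.0)) as mock_forward:
    loss = model.forward('dummy-batch')
    assert loss.item() == 0.0
    mock_forward.assert_called_once()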
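
The "reuse" half of the commit replaces the inline conditional expressions in the Trainer call with deepspeed_config and parallelism_config variables built once up front. A hedged sketch of that pattern factored into a helper (the build_configs name and the asserts are illustrative assumptions, not code from the repository):

from typing import Any, Optional


def build_configs(
    deepspeed_zero_stage: Optional[int],
    use_fsdp: bool,
) -> tuple[Optional[dict[str, Any]], Optional[dict[str, Any]]]:
    # Build each config once; None leaves the corresponding feature disabled.
    deepspeed_config = None
    if deepspeed_zero_stage:
        deepspeed_config = {'zero_optimization': {'stage': deepspeed_zero_stage}}

    parallelism_config = None
    if use_fsdp:
        parallelism_config = {
            'fsdp': {
                'sharding_strategy': 'FULL_SHARD',
                'mixed_precision': 'PURE',
                'backward_prefetch': 'BACKWARD_PRE',
            },
        }
    return deepspeed_config, parallelism_config


assert build_configs(2, False) == ({'zero_optimization': {'stage': 2}}, None)
assert build_configs(None, True)[0] is None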
