make test_pipe more stable (microsoft#683)
Shaden Smith authored Jan 20, 2021
1 parent 7b0bee0 commit e59ba12
Showing 1 changed file with 40 additions and 22 deletions.
tests/unit/test_pipe.py (62 changes: 40 additions & 22 deletions)
@@ -1,4 +1,5 @@
 import os
+import copy
 
 import torch
 import torch.nn as nn
@@ -13,8 +14,7 @@
 
 from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology
 PipeTopo = PipeDataParallelTopology
-import deepspeed.runtime.pipe.module as PipelineModule
-from deepspeed.runtime.pipe.module import LayerSpec
+from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
 
 from common import distributed_test
 
@@ -74,7 +74,13 @@ def forward(self, x, y):
         return self.loss_fn(x, y)
 
 
-class AlexNetPipe(PipelineModule.PipelineModule):
+class AlexNetPipe(AlexNet):
+    def to_layers(self):
+        layers = [*self.features, lambda x: x.view(x.size(0), -1), self.classifier]
+        return layers
+
+
+class AlexNetPipeSpec(PipelineModule):
     def __init__(self, num_classes=10, **kwargs):
         self.num_classes = num_classes
         specs = [
@@ -135,6 +141,9 @@ def train_cifar(model, args, num_steps=400, average_dp_losses=True, fp16=True, s
     with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
         ds_utils.set_random_seed(seed)
 
+        # disable dropout
+        model.eval()
+
         trainset = cifar_trainset(fp16=fp16)
         args.local_rank = dist.get_rank()
 
@@ -148,7 +157,7 @@ def train_cifar(model, args, num_steps=400, average_dp_losses=True, fp16=True, s
         for step in range(num_steps):
             loss = engine.train_batch()
             losses.append(loss.item())
-            if step % 50 == 0:
+            if step % 50 == 0 and dist.get_rank() == 0:
                 print(f'STEP={step} LOSS={loss.item()}')
 
         if average_dp_losses:
@@ -160,18 +169,16 @@ def train_cifar(model, args, num_steps=400, average_dp_losses=True, fp16=True, s
     return losses
 
 
-@pytest.mark.parametrize('base_topo,test_topo',
+@pytest.mark.parametrize('topo',
                          [
-                             (PipeTopo(num_pp=1,
-                                       num_dp=4),
-                              PipeTopo(num_pp=2,
-                                       num_dp=2)),
-                             (PipeTopo(num_pp=1,
-                                       num_dp=4),
-                              PipeTopo(num_pp=4,
-                                       num_dp=1)),
+                             PipeTopo(num_pp=1,
+                                      num_dp=4),
+                             PipeTopo(num_pp=2,
+                                      num_dp=2),
+                             PipeTopo(num_pp=4,
+                                      num_dp=1),
                          ])
-def test_pipe_cifar10_seedlayers(base_topo, test_topo, tmpdir):
+def test_pipe_cifar10(topo, tmpdir):
     config_dict = {
         "train_batch_size": 16,
         "train_micro_batch_size_per_gpu": 4,
@@ -199,21 +206,32 @@ def test_pipe_cifar10_seedlayers(base_topo, test_topo, tmpdir):
     }
     args = args_from_dict(tmpdir, config_dict)
 
+    # Allocate model for consistent initial weights.
+    init_net = AlexNetPipe()
+
     @distributed_test(world_size=4)
-    def _helper(base_topo, test_topo, tmpdir, steps=500):
+    def _helper(topo, tmpdir, steps=500):
         assert steps >= 100
 
-        base_model = AlexNetPipe(num_classes=10,
-                                 topology=base_topo,
-                                 seed_layers=config_dict['pipeline']['seed_layers'])
+        base_net = copy.deepcopy(init_net)
+        base_model = PipelineModule(layers=base_net.to_layers(),
+                                    num_stages=1,
+                                    loss_fn=nn.CrossEntropyLoss())
+
+        # Train with just data parallelism
         base_losses = train_cifar(base_model,
                                   args,
                                   num_steps=steps,
                                   fp16=config_dict['fp16']['enabled'])
 
-        test_model = AlexNetPipe(num_classes=10,
-                                 topology=test_topo,
-                                 seed_layers=config_dict['pipeline']['seed_layers'])
+        test_net = copy.deepcopy(init_net)
+        test_model = PipelineModule(layers=test_net.to_layers(),
+                                    topology=topo,
+                                    loss_fn=nn.CrossEntropyLoss())
+
+        #test_model = AlexNetPipe(num_classes=10,
+        #                         topology=test_topo,
+        #                         seed_layers=config_dict['pipeline']['seed_layers'])
         test_losses = train_cifar(test_model,
                                   args,
                                   num_steps=steps,
@@ -246,4 +264,4 @@ def _helper(base_topo, test_topo, tmpdir, steps=500):
         test_avg = sum(test) / len(test)
         assert rel_diff(base_avg, test_avg) < 0.03
 
-    _helper(base_topo, test_topo, tmpdir)
+    _helper(topo, tmpdir)
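
For context, a minimal sketch of the comparison the revised test performs, assuming the helpers defined in tests/unit/test_pipe.py (AlexNetPipe, train_cifar, and the args namespace from args_from_dict); the topology and step count below are illustrative, not taken from the commit. One network is allocated once, deep-copied, and wrapped in PipelineModule twice, so the data-parallel baseline and the pipeline-parallel run begin from identical weights before their loss curves are compared.

# Sketch only -- not the test itself. AlexNetPipe, train_cifar, and args are
# assumed to come from tests/unit/test_pipe.py; the topology below is one of
# the parametrized cases (num_pp=2, num_dp=2 on a 4-GPU world).
import copy

import torch.nn as nn

from deepspeed.runtime.pipe.module import PipelineModule
from deepspeed.runtime.pipe.topology import PipeDataParallelTopology

init_net = AlexNetPipe()  # allocated once so both runs share initial weights

base_net = copy.deepcopy(init_net)
base_model = PipelineModule(layers=base_net.to_layers(),
                            num_stages=1,  # one stage means pure data parallelism
                            loss_fn=nn.CrossEntropyLoss())

test_net = copy.deepcopy(init_net)
test_model = PipelineModule(layers=test_net.to_layers(),
                            topology=PipeDataParallelTopology(num_pp=2, num_dp=2),
                            loss_fn=nn.CrossEntropyLoss())

# train_cifar() returns per-step losses; the test then compares trailing
# averages of the two loss curves and asserts they agree to within 3%.
base_losses = train_cifar(base_model, args, num_steps=500, fp16=True)
test_losses = train_cifar(test_model, args, num_steps=500, fp16=True)

Deep-copying both models from a single init_net, together with the model.eval() call added to train_cifar to disable dropout, removes two sources of run-to-run variance, which is how the commit makes the test more stable.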
