test_mlp.py
import unittest
import os
import jax
import jax.numpy as jnp
import optax
import ray
from alpa import init, parallelize, PipeshardParallel
from alpa.model.model_util import TrainState
from alpa.parallel_method import LocalPipelineParallel
from alpa.pipeline_parallel.layer_construction import manual_layer_construction
from alpa.testing import MLPModel, assert_allclose
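
# This test builds a 4-layer MLP whose layers carry manual pipeline markers,
# computes gradients once with plain jax.grad and once through an alpa-
# parallelized version of the same train step, and checks that the two agree.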


class PipelineMLPTest(unittest.TestCase):

    def setUp(self):
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
        init(cluster="ray")
    def train_2_layer_mlp(self, method):

        def train_step(state, batch):

            @manual_layer_construction
            def loss_func(params, x, y):
                out = state.apply_fn(params, x)
                loss = jnp.mean((out - y)**2)
                return loss

            # Note: we can only use jax.grad in this test case.
            # TODO: Fix https://github.com/alpa-projects/alpa/issues/560
            grads = jax.grad(loss_func)(state.params, batch["x"], batch["y"])
            return grads
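
        # The pipeline stages here come from the markers MLPModel inserts when
        # it is built with add_manual_pipeline_marker=True; as I understand it,
        # manual_layer_construction uses those markers to decide where
        # loss_func is split into stages for the pipeline-parallel methods.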

        batch_size = 64
        hidden_size = 1024

        x = jnp.ones((batch_size, hidden_size))
        y = jnp.ones((batch_size, hidden_size))

        # Init model and optimizer
        model = MLPModel(num_layers=4,
                         hidden_size=hidden_size,
                         add_manual_pipeline_marker=True)
        rngkey = jax.random.PRNGKey(0)
        params = model.init(rngkey, x)
        tx = optax.sgd(learning_rate=1e-2)
        state = TrainState.create(apply_fn=model.apply,
                                  params=params,
                                  tx=tx,
                                  dynamic_scale=None)

        # Train step
        batch = {"x": x, "y": y}
        gradients = train_step(state, batch)

        p_train_step = parallelize(train_step, donate_argnums=(), method=method)
        gradients_with_pipeline = p_train_step(state, batch)

        # Check results
        assert_allclose(gradients, gradients_with_pipeline)

        # Check debug utilities
        if isinstance(method, PipeshardParallel):
            executable = p_train_step.get_last_executable()
            executable.dump_debug_info("tmp")

    def test_2_layer_mlp_local_pipeline_parallel(self):
        self.train_2_layer_mlp(LocalPipelineParallel())

    def test_2_layer_mlp_pipeshard_parallel(self):
        self.train_2_layer_mlp(PipeshardParallel(layer_option="manual"))


def suite():
    suite = unittest.TestSuite()
    suite.addTest(PipelineMLPTest("test_2_layer_mlp_local_pipeline_parallel"))
    suite.addTest(PipelineMLPTest("test_2_layer_mlp_pipeshard_parallel"))
    return suite


if __name__ == '__main__':
    runner = unittest.TextTestRunner()
    runner.run(suite())
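
# A minimal sketch of running a single case directly (assumptions: this file
# is importable as test_mlp from the current directory, and a Ray cluster is
# already running, e.g. started with `ray start --head`, since setUp calls
# init(cluster="ray")):
#
#   python -m unittest test_mlp.PipelineMLPTest.test_2_layer_mlp_pipeshard_parallel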