Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Oct 25, 2023
1 parent 6562f94 commit f2e05c9
Showing 1 changed file with 128 additions and 49 deletions.
177 changes: 128 additions & 49 deletions onnxruntime/test/python/onnxruntime_test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,71 +158,75 @@ def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64):

np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8)

def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101(self):
    """Two-axis fusion: [3, 64, 7, 7] -> [192, 7, 7].

    data:  sharded RS[0]RR on mesh [0, 1] (local shard (3, 32, 7, 7))
    shape: replicated on mesh [0, 1]
    expected result spec: S[0]RR on mesh [0, 1, 0, 1, 0, 1]
    """
    source_mesh = np.array([0, 1])
    fused_mesh = np.array([0, 1, 0, 1, 0, 1])
    self._check_distributed_reshape(
        shape=(3, 64, 7, 7),
        target_shape=(192, 7, 7),
        input_device_meshs=[source_mesh] * 2,
        input_shard_specs=("RS[0]RR", "R"),
        output_device_meshs=[fused_mesh],
        output_shard_specs=("S[0]RR",),
    )

def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_64_rsrr_01(self):
    """Two-axis decomposition: [192, 7, 64] -> [3, 64, 7, 64].

    data:  sharded S[0]RR on mesh [0, 1, 0, 1, 0, 1] (local shard (96, 7, 64))
    shape: replicated on mesh [0, 1]
    expected result spec: RS[0]RR on mesh [0, 1]
    """
    source_mesh = np.array([0, 1, 0, 1, 0, 1])
    result_mesh = np.array([0, 1])
    self._check_distributed_reshape(
        shape=(192, 7, 64),
        target_shape=(3, 64, 7, 64),
        input_device_meshs=[source_mesh] * 2,
        input_shard_specs=("S[0]RR", "R"),
        output_device_meshs=[result_mesh],
        output_shard_specs=("RS[0]RR",),
    )

def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self):
    """Two-axis fusion.

    S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape=[6], device_mesh=[0, 1]
    """
    # Fix: the scraped diff left both the old def line and the old
    # single-line `shape=(2, 3)` interleaved with the new ones; keep only
    # the post-edit method.
    self._check_distributed_reshape(
        shape=(2, 3),
        target_shape=(6,),
        input_device_meshs=[np.array([0, 1])] * 2,
        input_shard_specs=("S[0]R", "R"),
        output_device_meshs=[np.array([0, 1])],
        output_shard_specs=("S[0]",),
    )

def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self):
    """Two-axis fusion.

    RS[0], shape=[2, 4], device_mesh=[0, 1] -> S[0], shape=[8], device_mesh=[0, 1, 0, 1]
    """
    # Fix: drop the stale pre-edit def line and the duplicated
    # `shape=(2, 4)` left over from the diff interleave; keep only the
    # post-edit method.
    self._check_distributed_reshape(
        shape=(2, 4),
        target_shape=(8,),
        input_device_meshs=[np.array([0, 1])] * 2,
        input_shard_specs=("RS[0]", "R"),
        output_device_meshs=[np.array([0, 1, 0, 1])],
        output_shard_specs=("S[0]",),
    )

def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self):
    """Two-axis fusion.

    S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape=[2, 15], device_mesh=[0, 1]
    """
    device_mesh = np.array([0, 1])
    self._check_distributed_reshape(
        shape=(2, 3, 5),
        target_shape=(2, 15),
        input_device_meshs=[device_mesh] * 2,
        input_shard_specs=("S[0]RR", "R"),
        output_device_meshs=[device_mesh],
        output_shard_specs=("S[0]R",),
    )

def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self):
    """Two-axis fusion.

    RS[0]R, shape=[2, 4, 5], device_mesh=[0, 1] -> RS[0], shape=[2, 20], device_mesh=[0, 1]
    """
    # Fix: removed the stale comment claiming shape=[2, 3, 5] -> [2, 15];
    # the code below uses (2, 4, 5) -> (2, 20).
    # NOTE(review): the method name still says 2_3_5 / 2_15 — it predates
    # the shape change; kept unchanged to preserve the test's identity.
    self._check_distributed_reshape(
        shape=(2, 4, 5),
        target_shape=(2, 20),
        input_device_meshs=[np.array([0, 1])] * 2,
        input_shard_specs=("RS[0]R", "R"),
        output_device_meshs=[np.array([0, 1])],
        output_shard_specs=("RS[0]",),
    )

def test_reshape_two_axis_fusion_rrs01_rs010101(self):
def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self):
# Two axis fusion.
# RRS[0], shape=[2, 3, 6], device_mesh=[0, 1] -> RS[0], shape = [2, 18], device_mesh=[0, 1, 0, 1, 0, 1]
self._check_distributed_reshape(
Expand All @@ -246,7 +250,7 @@ def test_reshape_two_axis_fusion_rrs01_rs010101(self):
# Two axis fusion.
# RS[0]R, shape=[2, 8, 3], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1]

def test_reshape_two_axis_decomposition_s01_s01r(self):
def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self):
# Two axis decomposition
# S[0], shape=[6], device_mesh=[0, 1] -> S[0]R, shape=[2, 3], device_mesh=[0, 1]
self._check_distributed_reshape(
Expand All @@ -261,10 +265,22 @@ def test_reshape_two_axis_decomposition_s01_s01r(self):
output_shard_specs=("S[0]R",),
)

def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self):
    """Two-axis decomposition.

    S[0], shape=[16], device_mesh=[0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1]
    """
    mesh = np.array([0, 1])
    self._check_distributed_reshape(
        shape=(16,),
        target_shape=(1, 16),
        input_device_meshs=[mesh] * 2,
        input_shard_specs=("S[0]", "R"),
        output_device_meshs=[mesh],
        output_shard_specs=("RS[0]",),
    )

def test_reshape_two_axis_decomposition_shape_16_s01_shape_2_8_s01r(self):
def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self):
# Two axis decomposition
# S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[2, 8], device_mesh=[0, 1]
self._check_distributed_reshape(
Expand All @@ -279,12 +295,24 @@ def test_reshape_two_axis_decomposition_shape_16_s01_shape_2_8_s01r(self):
output_shard_specs=("S[0]R",),
)

def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self):
    """Two-axis decomposition.

    S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1]
    """
    mesh = np.array([0, 1])
    self._check_distributed_reshape(
        shape=(16,),
        target_shape=(4, 4),
        input_device_meshs=[mesh] * 2,
        input_shard_specs=("S[0]", "R"),
        output_device_meshs=[mesh],
        output_shard_specs=("S[0]R",),
    )

def test_reshape_two_axis_decomposition_shape_16_s01_shape_8_2_s01r(self):
# // Two axis decomposition
# // S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1]
def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self):
# Two axis decomposition
# S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1]
self._check_distributed_reshape(
shape=(16,),
target_shape=(
Expand All @@ -297,13 +325,38 @@ def test_reshape_two_axis_decomposition_shape_16_s01_shape_8_2_s01r(self):
output_shard_specs=("S[0]R",),
)

def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self):
    """Two-axis decomposition.

    S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1]
    """
    mesh = np.array([0, 1])
    self._check_distributed_reshape(
        shape=(16,),
        target_shape=(16, 1),
        input_device_meshs=[mesh] * 2,
        input_shard_specs=("S[0]", "R"),
        output_device_meshs=[mesh],
        output_shard_specs=("S[0]R",),
    )

def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self):
    """Two-axis decomposition.

    S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1, 0, 1]
    """
    # Fix: removed the stale pre-edit def line
    # (test_reshape_two_axis_decomposition_shape_16_s0101_shape_2_8_rs01)
    # that the diff scrape left interleaved before the call.
    self._check_distributed_reshape(
        shape=(16,),
        target_shape=(1, 16),
        input_device_meshs=[np.array([0, 1, 0, 1])] * 2,
        input_shard_specs=("S[0]", "R"),
        output_device_meshs=[np.array([0, 1, 0, 1])],
        output_shard_specs=("RS[0]",),
    )

def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self):
# Two axis decomposition
# repeats=2 8 = repeats * [unique IDs]
# S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[2, 8], device_mesh=[0, 1]
Expand All @@ -319,7 +372,7 @@ def test_reshape_two_axis_decomposition_shape_16_s0101_shape_2_8_rs01(self):
output_shard_specs=("RS[0]",),
)

def test_reshape_two_axis_decomposition_shape_16_s0101_shape_4_4_s0101r(self):
def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self):
# Two axis decomposition
# S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1, 0, 1]
self._check_distributed_reshape(
Expand All @@ -334,9 +387,35 @@ def test_reshape_two_axis_decomposition_shape_16_s0101_shape_4_4_s0101r(self):
output_shard_specs=("S[0]R",),
)

def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self):
    """Two-axis decomposition.

    S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1, 0, 1]
    """
    mesh = np.array([0, 1, 0, 1])
    self._check_distributed_reshape(
        shape=(16,),
        target_shape=(8, 2),
        input_device_meshs=[mesh] * 2,
        input_shard_specs=("S[0]", "R"),
        output_device_meshs=[mesh],
        output_shard_specs=("S[0]R",),
    )

def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self):
    """Two-axis decomposition.

    S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1, 0, 1]
    """
    mesh = np.array([0, 1, 0, 1])
    self._check_distributed_reshape(
        shape=(16,),
        target_shape=(16, 1),
        input_device_meshs=[mesh] * 2,
        input_shard_specs=("S[0]", "R"),
        output_device_meshs=[mesh],
        output_shard_specs=("S[0]R",),
    )

# llama case:
# [1, 7] -> [-1, 7]
Expand All @@ -353,7 +432,7 @@ def test_reshape_two_axis_decomposition_shape_16_s0101_shape_4_4_s0101r(self):
# -----------------------------------
# new reshaped: None, (RR, [0, 1])

def test_reshape_two_axis_decomposition_shape_21_4096_s01_shape_3_7_4096_rrs01(self):
def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01(self):
# Two axis decomposition
# [21, 4096] -> [3, 7, 4096]
# data: (21, 2048), (RS, [0, 1])
Expand All @@ -378,7 +457,7 @@ def test_reshape_two_axis_decomposition_shape_21_4096_s01_shape_3_7_4096_rrs01(s
output_shard_specs=("RRS[0]",),
)

def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs01_shape_3_7_64_64_rrs01r(self):
def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rrsr_01(self):
# Two axis decomposition
# [3, 7, 4096] -> [3, 7, 64, 64]
# data: (3, 7, 2048), (RRS, [0, 1])
Expand Down

0 comments on commit f2e05c9

Please sign in to comment.