Modify existing fuser tests to suit TE fuser (pytorch#232)
bertmaher authored Mar 3, 2020
1 parent 292d4a2 commit 197bbc4
Showing 2 changed files with 33 additions and 24 deletions.
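
The main change below replaces the hard-coded 'prim::FusionGroup' node kind with a FUSION_GROUP constant set to 'tensorexpr::Group', so the same assertions work against the subgraph nodes the TE fuser emits. A minimal sketch of the pattern the tests rely on (count_fusion_groups is a hypothetical helper, not part of the commit):

import torch

# Mirrors the constant added in test_jit_fuser.py; the legacy fuser's
# node kind would be 'prim::FusionGroup' instead.
FUSION_GROUP = 'tensorexpr::Group'

def count_fusion_groups(graph):
    # Count fusion-group nodes in a JIT graph (hypothetical helper).
    return sum(1 for node in graph.nodes() if node.kind() == FUSION_GROUP)

@torch.jit.script
def f(x, y):
    return x * y + y

x = torch.randn(4, 4)
y = torch.randn(4, 4)
f(x, y); f(x, y)  # warm up so the profiling executor can specialize and fuse
# May print 0 if fusion did not trigger or the group sits inside a
# DifferentiableGraph subgraph; 1 if the TE fuser grouped the ops at top level.
print(count_fusion_groups(f.graph_for(x, y)))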
48 changes: 24 additions & 24 deletions test/test_jit_fuser.py
@@ -22,6 +22,7 @@
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)

FUSION_GROUP = 'tensorexpr::Group'

def strip_profiling_nodes(nodes):
profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut'])
@@ -57,11 +58,11 @@ def assertAllFused(self, graph, except_for=()):
self.assertEqual(len(diff_graphs), 1)
graph = diff_graphs[0].g('Subgraph')

allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::BailoutTemplate',
allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
'prim::BailOut', 'prim::TupleConstruct'} | set(except_for)
self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
'got {}'.format(graph))
self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
self.assertTrue([node.kind() for node in graph.nodes()].count(FUSION_GROUP) == 1)

def _test_fused_abs(self, device='cpu'):
def func(x):
@@ -72,25 +73,31 @@ def func(x):
self.assertAllFused(scripted.graph_for(a))

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@enable_cpu_fuser
def test_abs_cpu(self):
self._test_fused_abs()

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_abs_cuda(self):
self._test_fused_abs(device="cuda")

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_zero_element_tensors(self):
def _test_zero_element_tensors(self, device="cpu"):
def decode(sin_t, cos_t):
theta = torch.atan2(sin_t.float(), cos_t.float())
return theta

sin = torch.zeros(0, device="cuda")
cos = torch.zeros(0, device="cuda")
sin = torch.zeros(0, device=device)
cos = torch.zeros(0, device=device)
inputs = [sin, cos]
ge = self.checkScript(decode, inputs)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_zero_element_tensors_cuda(self):
self._test_zero_element_tensors(device="cuda")

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
def test_zero_element_tensors_cpu(self):
self._test_zero_element_tensors(device="cpu")

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_arg_configurations_smoke_cuda(self):
# A smoke test to make sure we won't use the same kernel for contiguous
@@ -216,7 +223,6 @@ def chunk_4_last(x):
self.checkScript(fn, [tensor])

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@enable_cpu_fuser
def test_chunk_correctness(self):
return self._test_chunk_correctness(self, 'cpu')

@@ -235,7 +241,7 @@ def f(x, y):

ge = self.checkTrace(f, (x, y))
graph = ge.graph_for(x, y)
FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_') \
FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
.check_count('ConstantChunk', 2, exactly=True).run(str(graph))

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -256,7 +262,7 @@ def func2(x):
for func in [func1, func2]:
module = self.checkScript(func, inputs)
forward_graph = module.graph_for(*inputs)
self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1)
fusion_group = list(forward_graph.nodes())[-1]
self.assertEqual(len(list(fusion_group.inputs())), 1)

@@ -498,7 +504,7 @@ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph):
self.assertNotIn(node_not_in_graph, rep)
self.assertIn(node_not_in_graph, rep_noopt)

fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
fusion_groups = [node for node in graph.nodes() if node.kind() == FUSION_GROUP]
self.assertEqual(len(fusion_groups), 1)
fused_graph = str(fusion_groups[0].g('Subgraph'))
for node_in_fusegraph in in_fusegraph:
@@ -549,7 +555,6 @@ def fn_test_scalar_arg_requires_grad(x, p):

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@unittest.skip("deduplicating introduces aliasing in backward graph's outputs")
@enable_cpu_fuser
def test_fuser_deduplication(self):
# See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
# see the discussion in PR #14957.
@@ -571,7 +576,6 @@ def f(x, y):
self.assertEqual(ga2.data_ptr(), gb2.data_ptr())

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@enable_cpu_fuser
@unittest.skip("temporarily disabled because fusion was restricted in fixing #22833")
def test_fuser_iou(self):
# This checks if most of Intersection over Union is fused.
@@ -615,7 +619,6 @@ def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
@enable_cpu_fuser
def test_fusion_reuse_multi_gpu(self):
def fn(x, y):
return x * y * x * y
@@ -635,7 +638,6 @@ def fn(x, y):

@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
@enable_cpu_fuser
def test_kernel_cache_multi_gpu(self):
def not_fusible(x):
return x
@@ -658,10 +660,11 @@ def fn(x, y, z):
# should reuse the same KernelSpec in the KernelSpec cache.
ge = self.checkScript(fn, inputs)
self.assertGraphContainsExactly(
ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
ge.graph_for(*inputs), FUSION_GROUP, 3, True)
new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
# XXX: This assumes that the same kernel isn't already used by another test
self.assertEqual(new_cache_size - prev_cache_size, 1)
# FIXME: Use the TE fuser's way of querying the cache.
# self.assertEqual(new_cache_size - prev_cache_size, 1)

@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
def test_nonzero_device_cuda(self):
@@ -682,7 +685,7 @@ def test_lstm_cuda(self):
return
forward_graph = module.graph_for(*inputs)
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
# Everything is differentiable but TupleConstruct return
FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
@@ -722,7 +725,7 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
inputs = get_lstm_inputs('cuda', training=False)
self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
forward_graph = cu.cell.graph_for(*inputs)
self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
self.assertGraphContainsExactly(forward_graph, FUSION_GROUP, 1)

# TODO: Fuser doesn't work at all when inputs require grad. Fix that
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -737,7 +740,6 @@ def test_lstm_traced_cuda(self):

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
@enable_cpu_fuser
def test_lstm_traced_cpu(self):
inputs = get_lstm_inputs('cpu')
try:
@@ -759,7 +761,7 @@ def test_milstm_cuda(self):
module = self.checkScript(MiLSTMCell, inputs)
forward_graph = module.graph_for(*inputs)
self.assertGraphContainsExactly(
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
.check_next("return").check("FusionGroup").run(str(forward_graph))
hy, cy = module(*inputs)
@@ -836,7 +838,6 @@ def fn_test_rand(x, y):
self.assertEqual(out[0], out[1])

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@enable_cpu_fuser
def test_scalar(self):
def fn(x, y):
return 2 * x + y
@@ -879,10 +880,9 @@ def should_not_fuse(x, z):
]
ge = self.checkScript(should_not_fuse, inputs)
self.assertGraphContainsExactly(
ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
ge.graph_for(*inputs), FUSION_GROUP, 0, consider_subgraphs=True)

@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
@enable_cpu_fuser
def test_where_and_typing(self):
def f(x, y):
mask = x > y
9 changes: 9 additions & 0 deletions test/test_jit_fuser_te.py
@@ -0,0 +1,9 @@
import torch
from test_jit_fuser import *


if __name__ == "__main__":
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_set_texpr_fuser_enabled(True)
run_tests()
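
A note on the design (my reading of the diff, not stated in the commit): rather than duplicating tests, the new file imports everything from test_jit_fuser and flips the fuser toggles before run_tests(), so the whole existing suite runs again under the TE fuser. The same private torch._C toggles can be used ad hoc; a minimal sketch, assuming these bindings keep the signatures shown above:

import torch

# Disable the legacy fusers and enable the tensorexpr fuser, as the new test file does.
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_set_texpr_fuser_enabled(True)

@torch.jit.script
def fn(x, y):
    return torch.sin(x) + torch.cos(y)

x = torch.randn(8)
y = torch.randn(8)
fn(x, y); fn(x, y)  # warm-up runs for the profiling executor
# If fusion kicked in, the optimized graph contains a 'tensorexpr::Group' node.
print(fn.graph_for(x, y))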
