From 70f81fb13a09a16e55775ab281cf5a5f94d3d414 Mon Sep 17 00:00:00 2001
From: Xiaoqiang Zheng
Date: Thu, 13 Feb 2020 01:22:08 -0800
Subject: [PATCH] Add elementwise benchmarks and comparisons. (#155)

---
 benchmarks/tensorexpr/elementwise.py       | 122 ++++++++++++++++++---
 benchmarks/tensorexpr/framework.py         |   3 +-
 test/test_tensorexpr.py                    |   4 +-
 torch/csrc/jit/tensorexpr/cuda_codegen.cpp |  27 +++++
 torch/csrc/jit/tensorexpr/cuda_codegen.h   |   1 +
 5 files changed, 140 insertions(+), 17 deletions(-)

diff --git a/benchmarks/tensorexpr/elementwise.py b/benchmarks/tensorexpr/elementwise.py
index 28b133b5339213..616435351c2ea8 100644
--- a/benchmarks/tensorexpr/elementwise.py
+++ b/benchmarks/tensorexpr/elementwise.py
@@ -1,7 +1,18 @@
 import framework
+import itertools
+import numpy as np
+import torch
 
-
-class ElementMulBench(framework.Benchmark):
+# A template class for elementwise operations.
+# A derived class overrides these class variables to customize its behavior.
+class ElementBench(framework.Benchmark):
+    # List of customization class variables.
+    op_str = None
+    binary_op_pt_func = None
+    binary_op_np_func = None
+    unary_op_pt_func = None
+    unary_op_np_func = None
+    split_input = True
     def __init__(self, mode, device, N):
         super().__init__(mode, device)
         self.N = N
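The class variables above are the entire customization surface: `register_element_ops()` further down stamps out one subclass per (op, split/shared) combination with `type()`. Written out by hand, one of those generated classes would look roughly like the sketch below (illustrative only; the import path and class name are assumptions, not part of the patch):

```python
import framework  # the benchmark harness module in benchmarks/tensorexpr
from elementwise import ElementBench  # assumes benchmarks/tensorexpr is on sys.path

class ElementBenchSplitMul(ElementBench):
    """Hand-written equivalent of the dynamically generated 'split_mul' class."""
    op_str = 'split_mul'
    # forward()/reference() fetch these via self.__class__, i.e. as plain
    # class attributes, so no staticmethod wrapper is needed.
    binary_op_pt_func = lambda a, b: a * b  # torch path
    binary_op_np_func = lambda a, b: a * b  # numpy reference path
    split_input = True

framework.register_benchmark_class(ElementBenchSplitMul)
```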
+ ["div", lambda a, b: a / (b + 1e-4)], + ["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)], # no fuson triggered + ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)], + ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)], + ] + + unary_op_list = [ + ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)], + ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)], + ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)], + ] + + for split_input, binary_op in itertools.product([True, False], binary_op_list): + # Make a copy of ElementBench + if len(binary_op) == 2: + [op_str, op_pt_func] = binary_op + op_np_func = op_pt_func + elif len(binary_op) == 3: + [op_str, op_pt_func, op_np_func] = binary_op + split_str = 'split' if split_input else 'shared' + op_str = split_str + '_' + op_str + bm_cls = type('ElementBench_' + op_str, (ElementBench,), {}) + bm_cls.op_str = op_str + bm_cls.binary_op_pt_func = op_pt_func + bm_cls.binary_op_np_func = op_np_func + bm_cls.split_input = split_input + framework.register_benchmark_class(bm_cls) + + for split_input, unary_op in itertools.product([True, False], unary_op_list): + # Make a copy of ElementBench + if len(unary_op) == 2: + [op_str, op_pt_func] = unary_op + op_np_func = op_pt_func + elif len(unary_op) == 3: + [op_str, op_pt_func, op_np_func] = unary_op + split_str = 'split' if split_input else 'shared' + op_str = split_str + '_' + op_str + bm_cls = type('ElementBench_' + op_str, (ElementBench,), {}) + bm_cls.op_str = op_str + bm_cls.unary_op_pt_func = op_pt_func + bm_cls.unary_op_np_func = op_np_func + bm_cls.split_input = split_input + framework.register_benchmark_class(bm_cls) + + +#framework.register_benchmark_class(ElementMulBench) +register_element_ops() + diff --git a/benchmarks/tensorexpr/framework.py b/benchmarks/tensorexpr/framework.py index f70eb03eb8b279..6ad917eb386b4a 100644 --- a/benchmarks/tensorexpr/framework.py +++ b/benchmarks/tensorexpr/framework.py @@ -24,7 +24,7 @@ def forward(self): def check(self): np.testing.assert_allclose( - self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-7) + self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-2) def config(self): '''returns an array for the current benchmark configs @@ -81,7 +81,6 @@ def __init__(self, mode, device): method_engine = getattr(self.engine, method) setattr(self, method, method_engine) - def rand(self, shape, device=None, requires_grad=False): v = self.engine.rand(shape, device=device, requires_grad=requires_grad) if requires_grad: diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 60dbd2ba878c43..9689e9ffeab37d 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -98,7 +98,7 @@ def run_addcmul(x, y, z, w): x = traced(rand_a, rand_b, rand_c, rand_d) y = run_addcmul(rand_a, rand_b, rand_c, rand_d) - np.testing.assert_allclose(x.numpy(), y.numpy()) + np.testing.assert_allclose(x.numpy(), y.numpy(), atol=1e-6) def test_three_arg_cuda(): @@ -678,7 +678,7 @@ def test_relu(x, y): traced = torch.jit.trace(torch_fn, (ins, ins)) x = traced(rand_a, rand_b) y = torch_fn(rand_a, rand_b) - np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) # nans traced = torch.jit.trace(torch_fn, (ins, ins)) x = traced(nans, rand_b) diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index 2a36458c4b350d..8131bdc4c9166b 100644 --- 
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index 2a36458c4b350d..8131bdc4c9166b 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -124,6 +124,33 @@ void CudaPrinter::visit(const For* v) {
   }
 }
 
+void CudaPrinter::visit(const Intrinsics* v) {
+  std::string func_name;
+  // TODO: handle other data types.
+  switch (v->op_type()) {
+    case IntrinsicsOp::kSin:
+      func_name = "sinf";
+      break;
+    case IntrinsicsOp::kCos:
+      func_name = "cosf";
+      break;
+    case IntrinsicsOp::kExp:
+      func_name = "expf";
+      break;
+    default:
+      IRPrinter::visit(v);
+      return;
+  }
+  os() << func_name << "(";
+  for (int i = 0; i < v->nparams(); i++) {
+    if (i > 0) {
+      os() << ", ";
+    }
+    os() << v->param(i);
+  }
+  os() << ")";
+}
+
 void CudaPrinter::visit(const Load* v) {
   // TODO: find a better metric in using ldg or not. Support different dtypes.
   os() << "__ldg(" << v->base_handle() << " + " << v->index() << ")";
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h
index e7d133013b3b2a..39275ea5d72572 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.h
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h
@@ -36,6 +36,7 @@ class CudaPrinter : public IRPrinter {
     os() << ")";
   }
 
+  void visit(const Intrinsics* v);
   void visit(const For* v);
   void visit(const Load* v);
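With the `Intrinsics` visitor above, `sin`, `cos`, and `exp` nodes print as the CUDA `sinf`/`cosf`/`expf` calls rather than falling back to the generic `IRPrinter`. A hypothetical smoke test in the style of `test_tensorexpr.py` (not part of this patch; whether the trace actually fuses depends on the tensorexpr fuser being enabled in the build):

```python
# Hypothetical follow-up test: exercises the three intrinsics the CUDA
# printer now lowers to sinf/cosf/expf, comparing traced vs. eager results.
import numpy as np
import torch

def test_unary_intrinsics_cuda():
    def apply_ops(x):
        return torch.sin(x) + torch.cos(x) + torch.exp(x)

    rand_a = torch.rand(1024, device="cuda")
    traced = torch.jit.trace(apply_ops, (rand_a,))
    x = traced(rand_a)
    y = apply_ops(rand_a)
    # Device intrinsics may differ from eager kernels by a few ulp.
    np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6)
```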