Add elementwise benchmarks and comparisons. (pytorch#155)
zheng-xq authored and Mikhail Zolotukhin committed Mar 3, 2020
1 parent 4998cc6 commit 9642333
Showing 5 changed files with 140 additions and 17 deletions.
122 changes: 109 additions & 13 deletions benchmarks/tensorexpr/elementwise.py
@@ -1,7 +1,18 @@
import framework
import itertools
import numpy as np
import torch


class ElementMulBench(framework.Benchmark):
# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class ElementBench(framework.Benchmark):
# List of customization class variables.
op_str = None
binary_op_pt_func = None
binary_op_np_func = None
unary_op_pt_func = None
unary_op_np_func = None
split_input = True
def __init__(self, mode, device, N):
super().__init__(mode, device)
self.N = N
@@ -11,27 +22,60 @@ def __init__(self, mode, device, N):
self.d4 = self.rand([N], device=device, requires_grad=self.requires_grad)
self.inputs = [self.d1, self.d2, self.d3, self.d4]

def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
if not binary_op:
binary_op = lambda x, y: x + y
if not unary_op:
unary_op = lambda x: x
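# With split_input, the four operands remain distinct tensors; otherwise
# d2-d4 are derived from d1, so only a single input buffer is read.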
if self.split_input:
d1 = unary_op(d1)
d2 = unary_op(d2)
d3 = unary_op(d3)
d4 = unary_op(d4)
else:
d2 = unary_op(d1 + 0.001)
d3 = unary_op(d1 + 0.002)
d4 = unary_op(d1 + 0.003)
d1 = unary_op(d1)
a = binary_op(d1, d2)
b = binary_op(d3, d4)
c = a + b
return c

def forward(self, d1, d2, d3, d4):
y = d1 * d2 + d3 * d4
return y
binary_op = self.__class__.binary_op_pt_func
unary_op = self.__class__.unary_op_pt_func
return self._eval(d1, d2, d3, d4, binary_op, unary_op)

def reference(self):
return self.numpy(self.d1) * self.numpy(self.d2) + self.numpy(self.d3) * self.numpy(self.d4)
binary_op = self.__class__.binary_op_np_func
unary_op = self.__class__.unary_op_np_func
[d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
return self._eval(d1, d2, d3, d4, binary_op, unary_op)

def config(self):
return [self.N]

@staticmethod
def module():
return 'element_mul'
@classmethod
def module(cls):
return 'element_' + cls.op_str

def memory_workload(self):
input_count = len(self.inputs)
if self.mode == 'fwd':
sol_count = 4 + 1
algorithmic_count = 3 + 1
if self.split_input:
sol_count = input_count + 1
algorithmic_count = input_count + 1
else:
sol_count = 1 + 1
algorithmic_count = 1 + 1
else:
sol_count = (4 + 1) + (1 + 4)
algorithmic_count = (4 + 1) + ((2 + 1) * 4)
if self.split_input:
sol_count = (input_count + 1) + (1 + input_count)
algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
else:
sol_count = 1 + 1
algorithmic_count = 1 + 1

buffer_size = self.N * 4
return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count}
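
As a rough sanity check on these counts: for the default config (N = 1 << 27) with float32 tensors in forward mode, the formulas above work out to about 2.5 GiB of traffic in split-input mode versus about 1 GiB in shared-input mode. The short sketch below just re-evaluates those formulas; it is illustrative only and not part of the benchmark harness.

# Illustrative arithmetic mirroring memory_workload() for mode='fwd'.
N = 1 << 27                    # default config from default_configs()
buffer_size = N * 4            # bytes per float32 tensor
input_count = 4                # d1..d4

split_sol = buffer_size * (input_count + 1)   # 4 reads + 1 write
shared_sol = buffer_size * (1 + 1)            # 1 read + 1 write
print(split_sol / 2**30, shared_sol / 2**30)  # ~2.5 GiB vs ~1.0 GiB
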
@@ -41,4 +85,56 @@ def default_configs():
return [[1 << 27]]


framework.register_benchmark_class(ElementMulBench)
def register_element_ops():
binary_op_list = [
["mul", lambda a, b: a * b],
["add", lambda a, b: a + b],
["sub", lambda a, b: a - b],
["div", lambda a, b: a / (b + 1e-4)],
["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)], # no fuson triggered
["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)],
["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)],
]

unary_op_list = [
["exp", lambda x: torch.exp(x), lambda x: np.exp(x)],
["sin", lambda x: torch.sin(x), lambda x: np.sin(x)],
["cos", lambda x: torch.cos(x), lambda x: np.cos(x)],
]

for split_input, binary_op in itertools.product([True, False], binary_op_list):
# Make a copy of ElementBench
if len(binary_op) == 2:
[op_str, op_pt_func] = binary_op
op_np_func = op_pt_func
elif len(binary_op) == 3:
[op_str, op_pt_func, op_np_func] = binary_op
split_str = 'split' if split_input else 'shared'
op_str = split_str + '_' + op_str
bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
bm_cls.op_str = op_str
bm_cls.binary_op_pt_func = op_pt_func
bm_cls.binary_op_np_func = op_np_func
bm_cls.split_input = split_input
framework.register_benchmark_class(bm_cls)

for split_input, unary_op in itertools.product([True, False], unary_op_list):
# Make a copy of ElementBench
if len(unary_op) == 2:
[op_str, op_pt_func] = unary_op
op_np_func = op_pt_func
elif len(unary_op) == 3:
[op_str, op_pt_func, op_np_func] = unary_op
split_str = 'split' if split_input else 'shared'
op_str = split_str + '_' + op_str
bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
bm_cls.op_str = op_str
bm_cls.unary_op_pt_func = op_pt_func
bm_cls.unary_op_np_func = op_np_func
bm_cls.split_input = split_input
framework.register_benchmark_class(bm_cls)


#framework.register_benchmark_class(ElementMulBench)
register_element_ops()
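
For readers less familiar with the type(...) idiom above, each loop iteration builds and registers a subclass that is roughly equivalent to writing it out by hand. A hedged sketch for the split "mul" case (the class name matches what the type() call constructs; the rest is illustrative and not part of the committed file):

# Approximately what register_element_ops() produces for split_input=True, "mul".
class ElementBench_split_mul(ElementBench):
    op_str = 'split_mul'          # so module() reports 'element_split_mul'
    binary_op_pt_func = staticmethod(lambda a, b: a * b)
    binary_op_np_func = staticmethod(lambda a, b: a * b)
    split_input = True

framework.register_benchmark_class(ElementBench_split_mul)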

3 changes: 1 addition & 2 deletions benchmarks/tensorexpr/framework.py
@@ -24,7 +24,7 @@ def forward(self):

def check(self):
np.testing.assert_allclose(
self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-7)
self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-2)

def config(self):
'''returns an array for the current benchmark configs
@@ -81,7 +81,6 @@ def __init__(self, mode, device):
method_engine = getattr(self.engine, method)
setattr(self, method, method_engine)


def rand(self, shape, device=None, requires_grad=False):
v = self.engine.rand(shape, device=device, requires_grad=requires_grad)
if requires_grad:
4 changes: 2 additions & 2 deletions test/test_tensorexpr.py
Expand Up @@ -98,7 +98,7 @@ def run_addcmul(x, y, z, w):

x = traced(rand_a, rand_b, rand_c, rand_d)
y = run_addcmul(rand_a, rand_b, rand_c, rand_d)
np.testing.assert_allclose(x.numpy(), y.numpy())
np.testing.assert_allclose(x.numpy(), y.numpy(), atol=1e-6)


def test_three_arg_cuda():
@@ -678,7 +678,7 @@ def test_relu(x, y):
traced = torch.jit.trace(torch_fn, (ins, ins))
x = traced(rand_a, rand_b)
y = torch_fn(rand_a, rand_b)
np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3)
# nans
traced = torch.jit.trace(torch_fn, (ins, ins))
x = traced(nans, rand_b)
27 changes: 27 additions & 0 deletions torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -124,6 +124,33 @@ void CudaPrinter::visit(const For* v) {
}
}

void CudaPrinter::visit(const Intrinsics* v) {
std::string func_name;
// TODO: handle other data types.
switch (v->op_type()) {
case IntrinsicsOp::kSin:
func_name = "sinf";
break;
case IntrinsicsOp::kCos:
func_name = "cosf";
break;
case IntrinsicsOp::kExp:
func_name = "expf";
break;
default:
IRPrinter::visit(v);
return;
}
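// Print a C math-library call of the form func_name(param0, param1, ...).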
os() << func_name << "(";
for (int i = 0; i < v->nparams(); i++) {
if (i > 0) {
os() << ", ";
}
os() << v->param(i);
}
os() << ")";
}

void CudaPrinter::visit(const Load* v) {
// TODO: find a better metric in using ldg or not. Support different dtypes.
os() << "__ldg(" << v->base_handle() << " + " << v->index() << ")";
1 change: 1 addition & 0 deletions torch/csrc/jit/tensorexpr/cuda_codegen.h
@@ -36,6 +36,7 @@ class CudaPrinter : public IRPrinter {
os() << ")";
}

void visit(const Intrinsics* v);
void visit(const For* v);

void visit(const Load* v);
