Skip to content

Commit

Permalink
Make Broadcast work (pytorch#147)
Browse files Browse the repository at this point in the history
$ python benchmarks/tensorexpr/benchmark.py broadcast_3args --device gpu --mode fwd --jit_mode trace
  • Loading branch information
zheng-xq authored Feb 12, 2020
1 parent f3990d7 commit 21a599c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 8 deletions.
7 changes: 6 additions & 1 deletion benchmarks/tensorexpr/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def main():
parser.add_argument('--mode', type=str, default='fwd,both',
help='a comma separated list of running modes')
parser.add_argument('--engine', type=str, default='pt',
help='the underlying tensor engine. one of pt or tf')
help='the underlying tensor engine. only pt for now')
parser.add_argument('--jit_mode', type=str, default='trace',
help='the jit mode to use: one of {trace, none}')

args = parser.parse_args()

def set_global_threads(num_threads):
Expand Down Expand Up @@ -64,6 +67,7 @@ def set_global_threads(num_threads):
def run_default_configs(bench_cls, allow_skip=True):
for mode, device, config in itertools.product(modes, devices, bench_cls.default_configs()):
benchmark = bench_cls(mode, device, *config)
benchmark.jit_mode = args.jit_mode
if not benchmark.is_supported():
if allow_skip:
continue
Expand Down Expand Up @@ -111,6 +115,7 @@ def run_default_configs(bench_cls, allow_skip=True):
except ValueError:
pass
benchmark = bench_cls(*config)
benchmark.jit_mode = args.jit_mode
framework.run_benchmark(benchmark)

if not match_class_name:
Expand Down
51 changes: 49 additions & 2 deletions benchmarks/tensorexpr/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ def __init__(self, mode, device, case, M, N, K):
else:
raise ValueError('invalid case: %s' % (case))

def forward(self):
y = self.d1 + self.d2
self.inputs = [self.d1, self.d2]

def forward(self, d1, d2):
y = d1 + d2
return y

def reference(self):
Expand Down Expand Up @@ -74,6 +76,51 @@ def module():
return 'broadcast_col'


class BroadcastThreeArgs(framework.Benchmark):
def __init__(self, mode, device, M, N, K, L):
super().__init__(mode, device)
self.M = M
self.N = N
self.K = K
self.L = L

self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad)
self.d2 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad)
self.d3 = self.rand([L, K, 1, 1], device=device, requires_grad=self.requires_grad)

self.inputs = [self.d1, self.d2, self.d3]

def forward(self, d1, d2, d3):
y = d1 + d2 + d3
return y

def reference(self):
return self.numpy(self.d1) + self.numpy(self.d2) + self.numpy(self.d3)

def config(self):
return [self.M, self.N, self.K, self.L]

@staticmethod
def default_configs():
return [[32, 16, 64, 128]]

def memory_workload(self):
if self.mode == 'fwd':
sol_count = 1
algorithmic_count = 1
else:
sol_count = (1) + (1)
algorithmic_count = 1 + (1 + 1 + 1)

buffer_size = self.M * self.N * self.K * self.L * 4
return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count}

@staticmethod
def module():
return 'broadcast_3args'


framework.register_benchmark_class(BroadcastRowBench)
framework.register_benchmark_class(BroadcastMidBench)
framework.register_benchmark_class(BroadcastColBench)
framework.register_benchmark_class(BroadcastThreeArgs)
14 changes: 10 additions & 4 deletions benchmarks/tensorexpr/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import time
import tensor_engine

import torch

class BenchmarkBase(object):
def __init__(self, mode, device):
Expand All @@ -24,7 +24,7 @@ def forward(self):

def check(self):
np.testing.assert_allclose(
self.reference(), self.numpy(self.forward()), atol=1e-7)
self.reference(), self.numpy(self.forward(*self.inputs)), atol=1e-7)

def config(self):
'''returns an array for the current benchmark configs
Expand Down Expand Up @@ -107,14 +107,20 @@ def run_benchmark(benchmark):
benchmark.check()
else:
print(f"Warning: no reference result for {benchmark.module()}")


bm_jit = None
for i in range(warmups + iters):
if i == warmups:
if benchmark.device == 'cuda':
engine.sync_cuda()
time_start = time.time()

z = benchmark.forward()
if i == 0 and benchmark.jit_mode == 'trace':
bm_jit = torch.jit.trace(benchmark.forward, example_inputs=benchmark.inputs)
if bm_jit:
z = bm_jit(*benchmark.inputs)
else:
z = benchmark.forward(*benchmark.inputs)
if benchmark.mode == 'both':
if benchmark.result_grad is None:
benchmark.result_grad = engine.rand_like(z)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/tensorexpr/pt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def rand_like(self, v):
return torch.rand_like(v)

def numpy(self, t):
return t.numpy()
return t.cpu().numpy()

def mul(self, t1, t2):
return t1 * t2
Expand Down

0 comments on commit 21a599c

Please sign in to comment.