Add more operator support and tests (pytorch#140)
* Add more operator support and tests

rm log

add more cuda tests

clean up debug

relu is already added

fix the frac/relu support

* rm the extra relu

* redundant op

* rm frac
lly-zero-one authored and Mikhail Zolotukhin committed Feb 18, 2020
1 parent fad5348 commit fcc16c2
Showing 4 changed files with 79 additions and 34 deletions.
82 changes: 51 additions & 31 deletions test/test_tensorexpr.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 
 class ExecutionCounter(object):
@@ -423,11 +424,13 @@ def easy(x, y):
         c = torch.lt(x, y)
         return c
 
-    traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
-    a = torch.ones(1024, dtype=torch.int32)
-    b = torch.zeros(1024, dtype=torch.int32)
-    x = traced(a, b)
-    np.testing.assert_allclose(np.zeros(1024), x.numpy())
+    device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+    for dev in device_options:
+        traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
+        a = torch.ones(1024, dtype=torch.int32, device=dev)
+        b = torch.zeros(1024, dtype=torch.int32, device=dev)
+        x = traced(a, b)
+        np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy())
 
 
 def test_min_max():
@@ -446,10 +449,24 @@ def test_clamp():
     def test(x):
         return torch.clamp(x + 3.0, 0.0, 6.0)
 
-    traced = torch.jit.trace(test, (torch.zeros(1024)))
-    a = 20.0 * torch.rand(1024) - 10.0
-    an = a.numpy()
-    np.testing.assert_allclose(traced(a), np.clip(an + 3.0, 0.0, 6.0))
+    device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+    for dev in device_options:
+        traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
+        a = 20.0 * torch.rand(1024, device=dev) - 10.0
+        an = a.cpu().numpy()
+        np.testing.assert_allclose(traced(a).cpu(), np.clip(an + 3.0, 0.0, 6.0))
+
+def test_relu():
+    def test(x):
+        return torch.clamp(F.relu(x), 0, 0.5)
+
+    device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+    for dev in device_options:
+        traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
+        a = 20.0 * torch.rand(1024, device=dev) - 10.0
+        an = a.cpu().numpy()
+        np.testing.assert_allclose(traced(a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5))
 
 
 def test_reps():
@@ -487,8 +504,15 @@ def test(x, y, z):
     res = traced(x, y, z)
     np.testing.assert_allclose(xn * yn * zn, res.numpy())
 
+def test_binary_ops():
+    pass
 
 def test_unary_ops():
+
+    def test_round(x, y):
+        c = torch.round(torch.add(x, y))
+        return c
+
     def test_sin(x, y):
         c = torch.sin(torch.add(x, y))
         return c
@@ -610,6 +634,7 @@ def test_relu(x, y):
         return c
 
     fns = {
+        test_round,
         test_sin,
         test_asin,
         test_sinh,
@@ -640,30 +665,25 @@ def test_relu(x, y):
         test_neg,
         test_relu,
     }
-    rand_a = torch.rand(1024, dtype=torch.float)
-    rand_b = torch.rand(1024, dtype=torch.float)
-    zeros = torch.zeros(1024, dtype=torch.float)
-    cc = np.array(1024, dtype=float)
-    cc.fill(np.nan)
-    nans = torch.from_numpy(cc)
+    device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu']
 
     for torch_fn in fns:
-        # random floats
-        traced = torch.jit.trace(
-            torch_fn,
-            (
-                torch.zeros(1024, dtype=torch.float),
-                torch.zeros(1024, dtype=torch.float),
-            ),
-        )
-        x = traced(rand_a, rand_b)
-        y = torch_fn(rand_a, rand_b)
-        np.testing.assert_allclose(x.numpy(), y.numpy(), 1e-7, 1e-6)
-        # nans
-        traced = torch.jit.trace(torch_fn, (torch.zeros(1024), torch.zeros(1024)))
-        x = traced(nans, rand_b)
-        y = torch_fn(nans, rand_b)
-        np.testing.assert_allclose(x.numpy(), y.numpy())
+        for dev in device_options:
+            rand_a = torch.rand(1024, device=dev)
+            rand_b = torch.rand(1024, device=dev)
+            ins = 20 * torch.rand(1024, device=dev)
+            cc = np.array(1024, dtype=float)
+            cc.fill(np.nan)
+            nans = torch.from_numpy(cc).to(dev)
+            traced = torch.jit.trace(torch_fn, (ins, ins))
+            x = traced(rand_a, rand_b)
+            y = torch_fn(rand_a, rand_b)
+            np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
+            # nans
+            traced = torch.jit.trace(torch_fn, (ins, ins))
+            x = traced(nans, rand_b)
+            y = torch_fn(nans, rand_b)
+            np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())
 
 
 def test_nans():
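The new tests all follow the same device-sweep pattern: build the inputs on each available device, trace the function once per device, and compare the traced output against a NumPy reference on the host. A minimal standalone sketch of that pattern (the check_traced helper name and its tolerances are illustrative, not part of this commit):

import numpy as np
import torch


def check_traced(fn, reference, *example_inputs):
    # Trace fn on every available device and compare against a NumPy reference.
    devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
    for dev in devices:
        inputs = [t.to(dev) for t in example_inputs]
        traced = torch.jit.trace(fn, tuple(torch.zeros_like(t) for t in inputs))
        out = traced(*inputs)
        expected = reference(*(t.cpu().numpy() for t in inputs))
        np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-6, atol=1e-6)


# For example, the clamp test above could be expressed as:
check_traced(
    lambda x: torch.clamp(x + 3.0, 0.0, 6.0),
    lambda xn: np.clip(xn + 3.0, 0.0, 6.0),
    20.0 * torch.rand(1024) - 10.0,
)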
24 changes: 24 additions & 0 deletions torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -129,6 +129,30 @@ void CudaPrinter::visit(const Load* v) {
   os() << "__ldg(" << v->base_handle() << " + " << v->index() << ")";
 }
 
+void CudaPrinter::visit(const Max* v) {
+  auto dtype = v->dtype();
+  if (dtype == kFloat32) {
+    os() << "fmaxf";
+  }
+  os() << "(";
+  v->lhs().accept(this);
+  os() << ",";
+  v->rhs().accept(this);
+  os() << ")";
+}
+
+void CudaPrinter::visit(const Min* v) {
+  auto dtype = v->dtype();
+  if (dtype == kFloat32) {
+    os() << "fminf";
+  }
+  os() << "(";
+  v->lhs().accept(this);
+  os() << ",";
+  v->rhs().accept(this);
+  os() << ")";
+}
+
 void CudaCodeGen::Initialize() {
   printer_.reset(new CudaPrinter(&oss_));
   // TODO: handle multiple kernels.
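These printer overrides emit fmaxf/fminf for float Max and Min expressions in the generated CUDA source, which the Max-based relu lowering in kernel.cpp (below) presumably relies on when targeting GPU. A hedged end-to-end check in the style of the tests in this commit (illustrative only, not part of the commit):

import numpy as np
import torch


def min_max(x, y):
    return torch.max(x, y), torch.min(x, y)


devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
for dev in devices:
    traced = torch.jit.trace(min_max, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev)))
    a = torch.rand(1024, device=dev)
    b = torch.rand(1024, device=dev)
    hi, lo = traced(a, b)
    np.testing.assert_allclose(hi.cpu().numpy(), np.maximum(a.cpu().numpy(), b.cpu().numpy()))
    np.testing.assert_allclose(lo.cpu().numpy(), np.minimum(a.cpu().numpy(), b.cpu().numpy()))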
2 changes: 2 additions & 0 deletions torch/csrc/jit/tensorexpr/cuda_codegen.h
@@ -39,6 +39,8 @@ class CudaPrinter : public IRPrinter {
   void visit(const For* v);
 
   void visit(const Load* v);
+  void visit(const Max* v);
+  void visit(const Min* v);
 
   const std::vector<Expr>& gpu_block_extents() const {
     return gpu_block_extents_;
5 changes: 2 additions & 3 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -314,8 +314,7 @@ Tensor TensorExprKernel::ComputeValue(const torch::jit::Value* v) {
 
     case aten::relu: {
      return ComputeOneOperand("aten_relu", v, [](const Expr& a) {
-        Expr zero_cond = CompareSelect::make(a, Expr(0.0f), kLT);
-        return ifThenElse(zero_cond, Expr(0.0f), a);
+        return Max::make(a, 0, false);
      });
    } break;
 
@@ -509,7 +508,7 @@ Tensor TensorExprKernel::ComputeValue(const torch::jit::Value* v) {
 
    case aten::frac: {
      return ComputeOneOperand(
-          "aten_frac", v, [](const Expr& a) { return frac(a); });
+          "aten_frac", v, [](const Expr& a) { return a - floor(a); });
    } break;
 
    case aten::lgamma: {
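Both kernel.cpp changes swap a custom formulation for simpler primitives: relu becomes Max(a, 0) instead of a compare-and-select, and frac becomes a - floor(a) instead of the frac(a) helper. A quick NumPy/PyTorch sanity check of the two identities (illustrative only; torch.frac keeps the sign of its argument, so the floor-based form is checked here on non-negative inputs):

import numpy as np
import torch

a = torch.randn(1024)

# relu: Max(a, 0) agrees pointwise with the old compare-and-select lowering.
np.testing.assert_array_equal(
    torch.max(a, torch.zeros_like(a)).numpy(),
    torch.where(a < 0, torch.zeros_like(a), a).numpy(),
)

# frac: x - floor(x) matches torch.frac(x) for non-negative x.
x = 10.0 * torch.rand(1024)
np.testing.assert_allclose((x - torch.floor(x)).numpy(), torch.frac(x).numpy(), rtol=1e-6, atol=1e-6)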
