Skip to content

Commit

Permalink
Use correct tensor type for fuser output (pytorch#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
bertmaher authored Jan 28, 2020
1 parent d7fe47f commit dfe8396
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/test_tensorexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,13 @@ def test(x):
x = torch.rand(4)
y = traced(x)
np.testing.assert_allclose(x.numpy() + 3.0, y.numpy())

def test_int_output():
def test(x, y, z):
return x * y * z
xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)]
x, y, z = xs
xn, yn, zn = [t.numpy() for t in xs]
traced = torch.jit.trace(test, (x, y, z))
res = traced(x, y, z)
np.testing.assert_allclose(xn * yn * zn, res.numpy())
13 changes: 12 additions & 1 deletion torch/csrc/jit/passes/tensorexpr_fuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@ Dtype texprType(const c10::optional<at::ScalarType>& st) {
}
}

at::ScalarType tensorType(const Tensor& t) {
auto const& stype = t.dtype().scalar_type();
if (stype == kInt32) {
return at::ScalarType::Int;
} else if (stype == kFloat32) {
return at::ScalarType::Float;
}
LOG(FATAL) << "Unhandled datatype";
return at::ScalarType::Float;
}

std::vector<Expr> texprSizes(const c10::VaryingShape& shape) {
std::vector<Expr> dims;
for (size_t i = 0; i < *shape.size(); i++) {
Expand Down Expand Up @@ -562,7 +573,7 @@ struct TensorExprKernel {
codegen->bind(buffer_args[i], inputs[i].toTensor().data_ptr());
}
at::Tensor output =
at::empty(bufferSizes(*tensor_output), at::ScalarType::Float);
at::empty(bufferSizes(*tensor_output), tensorType(*tensor_output));
codegen->bind(*tensor_output, output.data_ptr());

// Call the kernel.
Expand Down

0 comments on commit dfe8396

Please sign in to comment.