Skip to content

Commit

Permalink
Fixed CudaCodeGen output streams. Switch to __ldg by default (pytorch…
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-xq authored and Mikhail Zolotukhin committed Mar 3, 2020
1 parent 3ab9ba6 commit 81f0847
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
21 changes: 13 additions & 8 deletions torch/csrc/jit/tensorexpr/cuda_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ScopedVarName {
const std::string& name)
: ScopedVarName(&manager->unique_name_mapping_, var, name) {}

~ScopedVarName() {
~ScopedVarName() noexcept(false) {
auto iter = mapping_->find(var_);
TORCH_CHECK(iter != mapping_->end(), "Invalid var entry");
mapping_->erase(var_);
Expand Down Expand Up @@ -124,29 +124,34 @@ void CudaPrinter::visit(const For* v) {
}
}

// Prints a Load node as a CUDA read-only load: `__ldg(<base_handle> + <index>)`.
// TODO: __ldg is currently emitted for every load; find a better metric for
// deciding when to use it, and support different dtypes.
void CudaPrinter::visit(const Load* v) {
  std::ostream& out = os();
  out << "__ldg(";
  out << v->base_handle();
  out << " + ";
  out << v->index();
  out << ")";
}

void CudaCodeGen::Initialize() {
printer_.reset(new CudaPrinter(&oss_));
// TODO: handle multiple kernels.
// TODO: handle dynamic dimension.
// TODO: call nvrtc.
oss_ << "extern \"C\" __global__" << std::endl << "void f(";
os() << "extern \"C\" __global__" << std::endl << "void f(";
const std::vector<BufferArg> buffer_args = this->buffer_args();
for (int i = 0; i < buffer_args.size(); i++) {
if (i > 0) {
oss_ << ", ";
os() << ", ";
}
const BufferArg& buffer_arg = buffer_args[i];
const Var& var = buffer_arg.var();
Dtype dtype = buffer_arg.dtype();
oss_ << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ")
os() << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ")
<< name_manager()->get_unique_name(var);
}
oss_ << ") {";
os() << ") {";

oss_ << std::endl;
os() << std::endl;
stmt().accept(printer_.get());
oss_ << std::endl;
oss_ << "}";
os() << std::endl;
os() << "}";

// Check that all block extents have been set.
const std::vector<Expr>& gpu_block_extents = printer_->gpu_block_extents();
Expand Down
11 changes: 6 additions & 5 deletions torch/csrc/jit/tensorexpr/cuda_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace tensorexpr {
// A class that overrides the underlying IRPrinter to produce Cuda C.
class CudaPrinter : public IRPrinter {
public:
explicit CudaPrinter(std::ostream* os) : IRPrinter(*os), os_(os) {}
explicit CudaPrinter(std::ostream* os) : IRPrinter(*os) {}

void visit(const Cast* v) {
auto dtype = v->dtype();
Expand All @@ -38,9 +38,7 @@ class CudaPrinter : public IRPrinter {

void visit(const For* v);

std::ostream& os() {
return *os_;
}
void visit(const Load* v);

const std::vector<Expr>& gpu_block_extents() const {
return gpu_block_extents_;
Expand All @@ -53,7 +51,6 @@ class CudaPrinter : public IRPrinter {
using IRPrinter::name_manager;

private:
std::ostream* os_ = nullptr;
std::vector<Expr> gpu_block_extents_;
std::vector<Expr> gpu_thread_extents_;
};
Expand Down Expand Up @@ -94,6 +91,10 @@ class TORCH_API CudaCodeGen : public CodeGen {
return printer_->name_manager();
}

std::ostream& os() {
return printer_->os();
}

std::ostringstream oss_;
std::unique_ptr<CudaPrinter> printer_;
CUfunction function_;
Expand Down
4 changes: 0 additions & 4 deletions torch/csrc/jit/tensorexpr/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ class TORCH_API IRPrinter : public IRVisitor {
}

private:
std::ostream& raw_os() {
return printer_os_;
}

PrinterStream printer_os_;
UniqueNameManager name_manager_;
};
Expand Down

0 comments on commit 81f0847

Please sign in to comment.