Skip to content

Commit

Permalink
[ir] [autodiff] Initialize ADStack with a zero (#1791)
Browse files Browse the repository at this point in the history
* [ir] [autodiff] Initialize ADStack with a zero

* format_all

* [skip ci] Apply suggestions from code review

Co-authored-by: 彭于斌 <[email protected]>
  • Loading branch information
yuanming-hu and archibate authored Aug 28, 2020
1 parent e72553c commit dec1591
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ENV CC=/usr/bin/clang-8
ENV CXX=/usr/bin/clang++-8
RUN wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/llvm-10.0.0.src.tar.xz
RUN tar xvJf llvm-10.0.0.src.tar.xz
RUN cd llvm-10.0.0.src && mkdir build
RUN cd llvm-10.0.0.src && mkdir build
WORKDIR /llvm-10.0.0.src/build
RUN cmake .. -DLLVM_ENABLE_RTTI:BOOL=ON -DBUILD_SHARED_LIBS:BOOL=OFF -DCMAKE_BUILD_TYPE=Release -DLLVM_TARGETS_TO_BUILD="X86;NVPTX" -DLLVM_ENABLE_ASSERTIONS=ON
RUN make -j 8
Expand Down
1 change: 0 additions & 1 deletion cmake/PythonNumpyPybind11.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,3 @@ else ()
endif ()

include_directories(${PYBIND11_INCLUDE_DIR})

15 changes: 12 additions & 3 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,18 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor {
})
.empty();
if (!load_only) {
alloc->replace_with(Stmt::make<StackAllocaStmt>(
alloc->ret_type.data_type,
get_current_program().config.ad_stack_size));
auto dtype = alloc->ret_type.data_type;
auto stack_alloca = Stmt::make<StackAllocaStmt>(
dtype, alloc->get_kernel()->program.config.ad_stack_size);
auto stack_alloca_ptr = stack_alloca.get();

alloc->replace_with(std::move(stack_alloca));

// Note that unlike AllocaStmt, StackAllocaStmt does NOT have an 0 as
// initial value. Therefore here we push an initial 0 value.
auto zero = stack_alloca_ptr->insert_after_me(
Stmt::make<ConstStmt>(TypedConstant(dtype, 0)));
zero->insert_after_me(Stmt::make<StackPushStmt>(stack_alloca_ptr, zero));
}
}

Expand Down

0 comments on commit dec1591

Please sign in to comment.