Skip to content

Commit

Permalink
[aot] Pass LaunchContextBuilder to CompiledGraph::init_runtime_context
Browse files Browse the repository at this point in the history
ghstack-source-id: 42b91cc4aaee38b88909a84f31a07a7daad66eed
Pull Request resolved: #7610
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Mar 21, 2023
1 parent 37fcdcc commit 9140f65
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
16 changes: 6 additions & 10 deletions taichi/aot/graph_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ void CompiledGraph::run(
for (const auto &dispatch : dispatches) {
TI_ASSERT(dispatch.compiled_kernel);
LaunchContextBuilder launch_ctx(dispatch.compiled_kernel);
init_runtime_context(dispatch.symbolic_args, args,
launch_ctx.get_context());
init_runtime_context(dispatch.symbolic_args, args, launch_ctx);
// Run cgraph loaded from AOT module
dispatch.compiled_kernel->launch(launch_ctx);
}
Expand All @@ -27,8 +26,7 @@ void CompiledGraph::jit_run(
for (const auto &dispatch : dispatches) {
TI_ASSERT(dispatch.ti_kernel);
LaunchContextBuilder launch_ctx(dispatch.ti_kernel);
init_runtime_context(dispatch.symbolic_args, args,
launch_ctx.get_context());
init_runtime_context(dispatch.symbolic_args, args, launch_ctx);
// Compile & Run (JIT): The compilation result will be cached, so don't
// worry that the kernels dispatched by this cgraph will be compiled
// repeatedly.
Expand All @@ -40,7 +38,7 @@ void CompiledGraph::jit_run(
void CompiledGraph::init_runtime_context(
const std::vector<Arg> &paramter_list,
const std::unordered_map<std::string, IValue> &args,
RuntimeContext &ctx) {
LaunchContextBuilder &ctx) {
for (int i = 0; i < paramter_list.size(); ++i) {
auto &symbolic_arg = paramter_list[i];
auto found = args.find(symbolic_arg.name);
Expand Down Expand Up @@ -87,8 +85,7 @@ void CompiledGraph::init_runtime_context(
"dtype={} but got an ndarray with dtype={}",
symbolic_arg.name, symbolic_arg_primitive_dtype.to_string(),
arr_primitive_dtype.to_string());
ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
arr->shape);
ctx.set_arg_ndarray(i, *arr);
} else if (symbolic_arg.tag == aot::ArgKind::kScalar ||
symbolic_arg.tag == aot::ArgKind::kMatrix) {
TI_ASSERT(ival.tag == aot::ArgKind::kScalar);
Expand All @@ -97,12 +94,11 @@ void CompiledGraph::init_runtime_context(
} else if (symbolic_arg.tag == aot::ArgKind::kTexture) {
TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
Texture *tex = reinterpret_cast<Texture *>(ival.val);
ctx.set_arg_texture(i, tex->get_device_allocation_ptr_as_int());
ctx.set_arg_texture(i, *tex);
} else if (symbolic_arg.tag == aot::ArgKind::kRWTexture) {
TI_ASSERT(ival.tag == aot::ArgKind::kTexture);
Texture *tex = reinterpret_cast<Texture *>(ival.val);
ctx.set_arg_rw_texture(i, tex->get_device_allocation_ptr_as_int(),
tex->get_size());
ctx.set_arg_rw_texture(i, *tex);
} else {
TI_ERROR("Error in compiled graph: unknown tag {}", ival.tag);
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct TI_DLL_EXPORT CompiledGraph {
static void init_runtime_context(
const std::vector<Arg> &paramter_list,
const std::unordered_map<std::string, IValue> &args,
RuntimeContext &ctx);
LaunchContextBuilder &ctx);
};

} // namespace aot
Expand Down

0 comments on commit 9140f65

Please sign in to comment.