Skip to content

Commit

Permalink
[refactor] Let LaunchContextBuilder be the argument of the compiled k…
Browse files Browse the repository at this point in the history
…ernel function on host

ghstack-source-id: f98d8749890464dcbd2941f1a19ea1a34532f110
Pull Request resolved: #7550
  • Loading branch information
lin-hitonami committed Mar 14, 2023
1 parent b8121a7 commit af02304
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
11 changes: 11 additions & 0 deletions taichi/inc/data_type_with_c_type.inc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Doesn't contain f16 and u1.
PER_C_TYPE(f32, float32)
PER_C_TYPE(f64, float64)
PER_C_TYPE(i8, int8)
PER_C_TYPE(i16, int16)
PER_C_TYPE(i32, int32)
PER_C_TYPE(i64, int64)
PER_C_TYPE(u8, uint8)
PER_C_TYPE(u16, uint16)
PER_C_TYPE(u32, uint32)
PER_C_TYPE(u64, uint64)
17 changes: 17 additions & 0 deletions taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,23 @@ void LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {
ctx_->extra_args[i][j] = d;
}

template <typename T>
void LaunchContextBuilder::set_struct_arg(std::vector<int> index, T v) {
if (ctx_->arg_buffer_size == 0) {
// Currently arg_buffer_size is always zero on non-LLVM-based backends,
// and this function is no-op for these backends.
return;
}
int offset = ctx_->args_type->get_element_offset(index);
*(T *)(ctx_->arg_buffer + offset) = v;
}

#define PER_C_TYPE(x, ctype) \
template void LaunchContextBuilder::set_struct_arg<ctype>( \
std::vector<int> index, ctype v);
#include "taichi/inc/data_type_with_c_type.inc.h"
#undef PER_C_TYPE

void LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
Expand Down
8 changes: 8 additions & 0 deletions taichi/program/launch_context_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ class LaunchContextBuilder {

void set_arg(int arg_id, TypedConstant d);

template <typename T>
void set_arg_u8(int arg_id, T d) {
auto dt = kernel_->parameter_list[arg_id].get_dtype();
TI_ASSERT(dt->is_primitive(PrimitiveTypeID::u8));
}
template <typename T>
void set_struct_arg(std::vector<int> index, T v);

void set_extra_arg_int(int i, int j, int32 d);

void set_arg_external_array_with_shape(int arg_id,
Expand Down

0 comments on commit af02304

Please sign in to comment.