diff --git a/taichi/inc/data_type_with_c_type.inc.h b/taichi/inc/data_type_with_c_type.inc.h new file mode 100644 index 00000000000000..2b12f83cbbd986 --- /dev/null +++ b/taichi/inc/data_type_with_c_type.inc.h @@ -0,0 +1,11 @@ +// Doesn't contain f16 and u1. +PER_C_TYPE(f32, float32) +PER_C_TYPE(f64, float64) +PER_C_TYPE(i8, int8) +PER_C_TYPE(i16, int16) +PER_C_TYPE(i32, int32) +PER_C_TYPE(i64, int64) +PER_C_TYPE(u8, uint8) +PER_C_TYPE(u16, uint16) +PER_C_TYPE(u32, uint32) +PER_C_TYPE(u64, uint64) diff --git a/taichi/program/launch_context_builder.cpp b/taichi/program/launch_context_builder.cpp index faf1cbeda6fa5f..91ecc68ef13e78 100644 --- a/taichi/program/launch_context_builder.cpp +++ b/taichi/program/launch_context_builder.cpp @@ -144,6 +144,23 @@ void LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) { ctx_->extra_args[i][j] = d; } +template +void LaunchContextBuilder::set_struct_arg(std::vector index, T v) { + if (ctx_->arg_buffer_size == 0) { + // Currently arg_buffer_size is always zero on non-LLVM-based backends, + // and this function is no-op for these backends. + return; + } + int offset = ctx_->args_type->get_element_offset(index); + *(T *)(ctx_->arg_buffer + offset) = v; +} + +#define PER_C_TYPE(x, ctype) \ + template void LaunchContextBuilder::set_struct_arg( \ + std::vector index, ctype v); +#include "taichi/inc/data_type_with_c_type.inc.h" +#undef PER_C_TYPE + void LaunchContextBuilder::set_arg_external_array_with_shape( int arg_id, uintptr_t ptr, diff --git a/taichi/program/launch_context_builder.h b/taichi/program/launch_context_builder.h index 8a8ceca7f2e358..dd7f1d0ca34158 100644 --- a/taichi/program/launch_context_builder.h +++ b/taichi/program/launch_context_builder.h @@ -23,6 +23,14 @@ class LaunchContextBuilder { void set_arg(int arg_id, TypedConstant d); + template + void set_arg_u8(int arg_id, T d) { + auto dt = kernel_->parameter_list[arg_id].get_dtype(); + TI_ASSERT(dt->is_primitive(PrimitiveTypeID::u8)); + } + template + void set_struct_arg(std::vector index, T v); + void set_extra_arg_int(int i, int j, int32 d); void set_arg_external_array_with_shape(int arg_id,