Skip to content

Commit

Permalink
[refactor] Migrate {set, get}_{arg, ret} functions from RuntimeContext
Browse files Browse the repository at this point in the history
to LaunchContextBuilder

ghstack-source-id: b21aef43298ffef0a3251f6e0c2b3150e8acb34e
Pull Request resolved: #7550
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Mar 21, 2023
1 parent 97260fc commit f82242b
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 36 deletions.
11 changes: 11 additions & 0 deletions taichi/inc/data_type_with_c_type.inc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Doesn't contain f16 and u1.
PER_C_TYPE(f32, float32)
PER_C_TYPE(f64, float64)
PER_C_TYPE(i8, int8)
PER_C_TYPE(i16, int16)
PER_C_TYPE(i32, int32)
PER_C_TYPE(i64, int64)
PER_C_TYPE(u8, uint8)
PER_C_TYPE(u16, uint16)
PER_C_TYPE(u32, uint32)
PER_C_TYPE(u64, uint64)
193 changes: 158 additions & 35 deletions taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,33 +60,33 @@ void LaunchContextBuilder::set_arg_float(int arg_id, float64 d) {

auto dt = kernel_->parameter_list[arg_id].get_dtype();
if (dt->is_primitive(PrimitiveTypeID::f32)) {
ctx_->set_arg(arg_id, (float32)d);
set_arg(arg_id, (float32)d);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
ctx_->set_arg(arg_id, (float64)d);
set_arg(arg_id, (float64)d);
} else if (dt->is_primitive(PrimitiveTypeID::i32)) {
ctx_->set_arg(arg_id, (int32)d);
set_arg(arg_id, (int32)d);
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
ctx_->set_arg(arg_id, (int64)d);
set_arg(arg_id, (int64)d);
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
ctx_->set_arg(arg_id, (int8)d);
set_arg(arg_id, (int8)d);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
ctx_->set_arg(arg_id, (int16)d);
set_arg(arg_id, (int16)d);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
ctx_->set_arg(arg_id, (uint8)d);
set_arg(arg_id, (uint8)d);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
ctx_->set_arg(arg_id, (uint16)d);
set_arg(arg_id, (uint16)d);
} else if (dt->is_primitive(PrimitiveTypeID::u32)) {
ctx_->set_arg(arg_id, (uint32)d);
set_arg(arg_id, (uint32)d);
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
ctx_->set_arg(arg_id, (uint64)d);
set_arg(arg_id, (uint64)d);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
if (!arch_uses_llvm(kernel_->arch)) {
// TODO: remove this once we refactored the SPIR-V based backends
ctx_->set_arg(arg_id, (float32)d);
set_arg(arg_id, (float32)d);
return;
}
uint16 half = fp16_ieee_from_fp32_value((float32)d);
ctx_->set_arg(arg_id, half);
set_arg(arg_id, half);
} else {
TI_NOT_IMPLEMENTED
}
Expand All @@ -104,21 +104,21 @@ void LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {

auto dt = kernel_->parameter_list[arg_id].get_dtype();
if (dt->is_primitive(PrimitiveTypeID::i32)) {
ctx_->set_arg(arg_id, (int32)d);
set_arg(arg_id, (int32)d);
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
ctx_->set_arg(arg_id, (int64)d);
set_arg(arg_id, (int64)d);
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
ctx_->set_arg(arg_id, (int8)d);
set_arg(arg_id, (int8)d);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
ctx_->set_arg(arg_id, (int16)d);
set_arg(arg_id, (int16)d);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
ctx_->set_arg(arg_id, (uint8)d);
set_arg(arg_id, (uint8)d);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
ctx_->set_arg(arg_id, (uint16)d);
set_arg(arg_id, (uint16)d);
} else if (dt->is_primitive(PrimitiveTypeID::u32)) {
ctx_->set_arg(arg_id, (uint32)d);
set_arg(arg_id, (uint32)d);
} else if (dt->is_primitive(PrimitiveTypeID::u64)) {
ctx_->set_arg(arg_id, (uint64)d);
set_arg(arg_id, (uint64)d);
} else {
TI_INFO(dt->to_string());
TI_NOT_IMPLEMENTED
Expand All @@ -129,14 +129,15 @@ void LaunchContextBuilder::set_arg_uint(int arg_id, uint64 d) {
set_arg_int(arg_id, d);
}

void LaunchContextBuilder::set_arg(int arg_id, TypedConstant d) {
template <>
void LaunchContextBuilder::set_arg<TypedConstant>(int i, TypedConstant d) {
if (is_real(d.dt)) {
set_arg_float(arg_id, d.val_float());
set_arg_float(i, d.val_float());
} else {
if (is_signed(d.dt)) {
set_arg_int(arg_id, d.val_int());
set_arg_int(i, d.val_int());
} else {
set_arg_uint(arg_id, d.val_uint());
set_arg_uint(i, d.val_uint());
}
}
}
Expand All @@ -145,6 +146,77 @@ void LaunchContextBuilder::set_extra_arg_int(int i, int j, int32 d) {
ctx_->extra_args[i][j] = d;
}

template <typename T>
void LaunchContextBuilder::set_struct_arg(std::vector<int> index, T v) {
if (ctx_->arg_buffer_size == 0) {
// Currently arg_buffer_size is always zero on non-LLVM-based backends,
// and this function is no-op for these backends.
return;
}
int offset = ctx_->args_type->get_element_offset(index);
*(T *)(ctx_->arg_buffer + offset) = v;
}

template <typename T>
T LaunchContextBuilder::get_arg(int i) {
if (ctx_->arg_buffer_size > 0) {
// Currently arg_buffer_size is always zero on non-LLVM-based backends
return get_struct_arg<T>({i});
}
return taichi_union_cast_with_different_sizes<T>(ctx_->args[i]);
}

template <typename T>
T LaunchContextBuilder::get_struct_arg(std::vector<int> index) {
int offset = ctx_->args_type->get_element_offset(index);
return *(T *)(ctx_->arg_buffer + offset);
}

template <typename T>
T LaunchContextBuilder::get_grad_arg(int i) {
return taichi_union_cast_with_different_sizes<T>(ctx_->grad_args[i]);
}

template <typename T>
void LaunchContextBuilder::set_arg(int i, T v) {
set_struct_arg({i}, v);
ctx_->args[i] = taichi_union_cast_with_different_sizes<uint64>(v);
set_array_device_allocation_type(i, DevAllocType::kNone);
}

template <typename T>
void LaunchContextBuilder::set_grad_arg(int i, T v) {
ctx_->grad_args[i] = taichi_union_cast_with_different_sizes<uint64>(v);
}

template <typename T>
T LaunchContextBuilder::get_ret(int i) {
return taichi_union_cast_with_different_sizes<T>(ctx_->result_buffer[i]);
}

#define PER_C_TYPE(type, ctype) \
template void LaunchContextBuilder::set_struct_arg(std::vector<int> index, \
ctype v); \
template ctype LaunchContextBuilder::get_arg(int i); \
template ctype LaunchContextBuilder::get_struct_arg(std::vector<int> index); \
template ctype LaunchContextBuilder::get_grad_arg(int i); \
template void LaunchContextBuilder::set_arg(int i, ctype v); \
template void LaunchContextBuilder::set_grad_arg(int i, ctype v); \
template ctype LaunchContextBuilder::get_ret(int i);
#include "taichi/inc/data_type_with_c_type.inc.h"
PER_C_TYPE(gen, void *) // Register void* as a valid type
#undef PER_C_TYPE

void LaunchContextBuilder::set_array_runtime_size(int i, uint64 size) {
ctx_->array_runtime_sizes[i] = size;
}

void LaunchContextBuilder::set_array_device_allocation_type(
int i,
DevAllocType usage) {
ctx_->device_allocation_type[i] = (RuntimeContext::DevAllocType)usage;
}

void LaunchContextBuilder::set_arg_external_array_with_shape(
int arg_id,
uintptr_t ptr,
Expand All @@ -154,22 +226,21 @@ void LaunchContextBuilder::set_arg_external_array_with_shape(
kernel_->parameter_list[arg_id].is_array,
"Assigning external (numpy) array to scalar argument is not allowed.");

ActionRecorder::get_instance().record(
"set_kernel_arg_ext_ptr",
{ActionArg("kernel_name", kernel_->name), ActionArg("arg_id", arg_id),
ActionArg("address", fmt::format("0x{:x}", ptr)),
ActionArg("array_size_in_bytes", (int64)size)});

TI_ASSERT_INFO(shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
ctx_->set_arg_external_array(arg_id, ptr, size, shape);
set_arg(arg_id, ptr);
set_array_runtime_size(arg_id, size);
set_array_device_allocation_type(arg_id, DevAllocType::kNone);
for (uint64 i = 0; i < shape.size(); ++i) {
ctx_->extra_args[arg_id][i] = shape[i];
}
}

void LaunchContextBuilder::set_arg_ndarray(int arg_id, const Ndarray &arr) {
intptr_t ptr = arr.get_device_allocation_ptr_as_int();
TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
ctx_->set_arg_ndarray(arg_id, ptr, arr.shape);
set_arg_ndarray_impl(arg_id, ptr, arr.shape);
}

void LaunchContextBuilder::set_arg_ndarray_with_grad(int arg_id,
Expand All @@ -179,23 +250,75 @@ void LaunchContextBuilder::set_arg_ndarray_with_grad(int arg_id,
intptr_t ptr_grad = arr_grad.get_device_allocation_ptr_as_int();
TI_ASSERT_INFO(arr.shape.size() <= taichi_max_num_indices,
"External array cannot have > {max_num_indices} indices");
ctx_->set_arg_ndarray(arg_id, ptr, arr.shape, true, ptr_grad);
set_arg_ndarray_impl(arg_id, ptr, arr.shape, true, ptr_grad);
}

void LaunchContextBuilder::set_arg_texture(int arg_id, const Texture &tex) {
intptr_t ptr = tex.get_device_allocation_ptr_as_int();
ctx_->set_arg_texture(arg_id, ptr);
set_arg_texture_impl(arg_id, ptr);
}

void LaunchContextBuilder::set_arg_rw_texture(int arg_id, const Texture &tex) {
intptr_t ptr = tex.get_device_allocation_ptr_as_int();
ctx_->set_arg_rw_texture(arg_id, ptr, tex.get_size());
set_arg_rw_texture_impl(arg_id, ptr, tex.get_size());
}

RuntimeContext &LaunchContextBuilder::get_context() {
return *ctx_;
}

void LaunchContextBuilder::set_arg_texture_impl(int arg_id,
intptr_t alloc_ptr) {
set_struct_arg({arg_id}, alloc_ptr);
ctx_->args[arg_id] =
taichi_union_cast_with_different_sizes<uint64>(alloc_ptr);
set_array_device_allocation_type(arg_id, DevAllocType::kTexture);
}

void LaunchContextBuilder::set_arg_rw_texture_impl(
int arg_id,
intptr_t alloc_ptr,
const std::array<int, 3> &shape) {
set_struct_arg({arg_id}, alloc_ptr);
ctx_->args[arg_id] =
taichi_union_cast_with_different_sizes<uint64>(alloc_ptr);
set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture);
TI_ASSERT(shape.size() <= taichi_max_num_indices);
for (int i = 0; i < shape.size(); i++) {
ctx_->extra_args[arg_id][i] = shape[i];
}
}

void LaunchContextBuilder::set_arg_ndarray_impl(int arg_id,
intptr_t devalloc_ptr,
const std::vector<int> &shape,
bool has_grad,
intptr_t devalloc_ptr_grad) {
// Set has_grad value
ctx_->has_grad[arg_id] = has_grad;

set_struct_arg({arg_id}, devalloc_ptr);
// Set args[arg_id] value
ctx_->args[arg_id] =
taichi_union_cast_with_different_sizes<uint64>(devalloc_ptr);

// Set grad_args[arg_id] value
if (has_grad) {
ctx_->grad_args[arg_id] =
taichi_union_cast_with_different_sizes<uint64>(devalloc_ptr_grad);
}

// Set device allocation type and runtime size
set_array_device_allocation_type(arg_id, DevAllocType::kNdarray);
TI_ASSERT(shape.size() <= taichi_max_num_indices);
size_t total_size = 1;
for (int i = 0; i < shape.size(); i++) {
ctx_->extra_args[arg_id][i] = shape[i];
total_size *= shape[i];
}
set_array_runtime_size(arg_id, total_size);
}

TypedConstant LaunchContextBuilder::fetch_ret(const std::vector<int> &index) {
const Type *dt = ret_type_->get_element_type(index);
int offset = ret_type_->get_element_offset(index);
Expand Down
40 changes: 39 additions & 1 deletion taichi/program/launch_context_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ namespace taichi::lang {

class LaunchContextBuilder {
public:
enum class DevAllocType : int8_t {
kNone = 0,
kNdarray = 1,
kTexture = 2,
kRWTexture = 3
};

LaunchContextBuilder(CallableBase *kernel, RuntimeContext *ctx);
explicit LaunchContextBuilder(CallableBase *kernel);

Expand All @@ -21,21 +28,52 @@ class LaunchContextBuilder {
void set_arg_int(int arg_id, int64 d);
void set_arg_uint(int arg_id, uint64 d);

void set_arg(int arg_id, TypedConstant d);
void set_array_runtime_size(int i, uint64 size);

void set_array_device_allocation_type(int i, DevAllocType usage);

template <typename T>
void set_arg(int i, T v);

template <typename T>
void set_grad_arg(int i, T v);

template <typename T>
void set_struct_arg(std::vector<int> index, T v);

template <typename T>
T get_arg(int i);

template <typename T>
T get_struct_arg(std::vector<int> index);

template <typename T>
T get_grad_arg(int i);

template <typename T>
T get_ret(int i);
void set_extra_arg_int(int i, int j, int32 d);

void set_arg_external_array_with_shape(int arg_id,
uintptr_t ptr,
uint64 size,
const std::vector<int64> &shape);

void set_arg_ndarray_impl(int arg_id,
intptr_t devalloc_ptr,
const std::vector<int> &shape,
bool has_grad = false,
intptr_t devalloc_ptr_grad = 0);
void set_arg_ndarray(int arg_id, const Ndarray &arr);
void set_arg_ndarray_with_grad(int arg_id,
const Ndarray &arr,
const Ndarray &arr_grad);

void set_arg_texture_impl(int arg_id, intptr_t alloc_ptr);
void set_arg_texture(int arg_id, const Texture &tex);
void set_arg_rw_texture_impl(int arg_id,
intptr_t alloc_ptr,
const std::array<int, 3> &shape);
void set_arg_rw_texture(int arg_id, const Texture &tex);

TypedConstant fetch_ret(const std::vector<int> &index);
Expand Down

0 comments on commit f82242b

Please sign in to comment.