Skip to content

Commit

Permalink
[aot] Improve C++ wrapper implementation (#6146)
Browse files Browse the repository at this point in the history
RT

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PENGUINLIONG and pre-commit-ci[bot] authored Sep 23, 2022
1 parent 3677d93 commit b272c96
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 48 deletions.
143 changes: 96 additions & 47 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ THandle move_handle(THandle &handle) {
class Memory {
TiRuntime runtime_{TI_NULL_HANDLE};
TiMemory memory_{TI_NULL_HANDLE};
size_t size_{0};
bool should_destroy_{false};

public:
Expand All @@ -88,10 +89,14 @@ class Memory {
Memory(Memory &&b)
: runtime_(detail::move_handle(b.runtime_)),
memory_(detail::move_handle(b.memory_)),
size_(std::exchange(b.size_, 0)),
should_destroy_(std::exchange(b.should_destroy_, false)) {
}
Memory(TiRuntime runtime, TiMemory memory, bool should_destroy)
: runtime_(runtime), memory_(memory), should_destroy_(should_destroy) {
Memory(TiRuntime runtime, TiMemory memory, size_t size, bool should_destroy)
: runtime_(runtime),
memory_(memory),
size_(size),
should_destroy_(should_destroy) {
}
~Memory() {
destroy();
Expand All @@ -102,10 +107,36 @@ class Memory {
destroy();
runtime_ = detail::move_handle(b.runtime_);
memory_ = detail::move_handle(b.memory_);
size_ = std::exchange(b.size_, 0);
should_destroy_ = std::exchange(b.should_destroy_, false);
return *this;
}

void *map() const {
return ti_map_memory(runtime_, memory_);
}
void unmap() const {
ti_unmap_memory(runtime_, memory_);
}

inline void read(void *dst, size_t size) const {
void *src = map();
if (src != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
}
inline void write(const void *src, size_t size) const {
void *dst = map();
if (dst != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
}

constexpr size_t size() const {
return size_;
}
constexpr TiMemory memory() const {
return memory_;
}
Expand All @@ -116,32 +147,28 @@ class Memory {

template <typename T>
class NdArray {
TiRuntime runtime_{TI_NULL_HANDLE};
Memory memory_{};
TiNdArray ndarray_{};
bool should_destroy_{false};

public:
constexpr bool is_valid() const {
return ndarray_.memory != nullptr;
return memory_.is_valid();
}
inline void destroy() {
if (should_destroy_) {
ti_free_memory(runtime_, ndarray_.memory);
ndarray_.memory = TI_NULL_HANDLE;
should_destroy_ = false;
}
memory_.destroy();
}

NdArray() {
}
NdArray(const NdArray<T> &) = delete;
NdArray(NdArray<T> &&b)
: runtime_(detail::move_handle(b.runtime_)),
ndarray_(std::exchange(b.ndarray_, {})),
should_destroy_(std::exchange(b.should_destroy_, false)) {
: memory_(std::move(b.memory_)), ndarray_(std::exchange(b.ndarray_, {})) {
}
NdArray(TiRuntime runtime, const TiNdArray &ndarray, bool should_destroy)
: runtime_(runtime), ndarray_(ndarray), should_destroy_(should_destroy) {
NdArray(Memory &&memory, const TiNdArray &ndarray)
: memory_(std::move(memory)), ndarray_(ndarray) {
if (ndarray.memory != memory_) {
ti_set_last_error(TI_ERROR_INVALID_ARGUMENT, "ndarray.memory != memory");
}
}
~NdArray() {
destroy();
Expand All @@ -150,44 +177,56 @@ class NdArray {
NdArray<T> &operator=(const NdArray<T> &) = delete;
NdArray<T> &operator=(NdArray<T> &&b) {
destroy();
runtime_ = detail::move_handle(b.runtime_);
memory_ = std::move(b.memory_);
ndarray_ = std::exchange(b.ndarray_, {});
should_destroy_ = std::exchange(b.should_destroy_, false);
return *this;
}

inline void *map() {
return ti_map_memory(runtime_, ndarray_.memory);
inline void *map() const {
return memory_.map();
}
inline void unmap() {
return ti_unmap_memory(runtime_, ndarray_.memory);
inline void unmap() const {
return memory_.unmap();
}

inline void read(T *dst, size_t size) {
T *src = (T *)map();
if (src != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
inline void read(T *dst, size_t size) const {
memory_.read(dst, size);
}
inline void read(std::vector<T> &dst) {
inline void read(std::vector<T> &dst) const {
read(dst.data(), dst.size() * sizeof(T));
}
inline void write(const T *src, size_t size) {
T *dst = (T *)map();
if (dst != nullptr) {
std::memcpy(dst, src, size);
}
unmap();
template <typename U>
inline void read(std::vector<U> &dst) const {
static_assert(sizeof(U) % sizeof(T) == 0,
"sizeof(U) must be a multiple of sizeof(T)");
read((T *)dst.data(), dst.size() * sizeof(U));
}
inline void write(const T *src, size_t size) const {
memory_.write(src, size);
}
inline void write(const std::vector<T> &src) {
inline void write(const std::vector<T> &src) const {
write(src.data(), src.size() * sizeof(T));
}
template <typename U>
inline void write(const std::vector<U> &src) const {
static_assert(sizeof(U) % sizeof(T) == 0,
"sizeof(U) must be a multiple of sizeof(T)");
write((const T *)src.data(), src.size() * sizeof(U));
}

constexpr TiMemory memory() const {
return ndarray_.memory;
constexpr TiDataType elem_type() const {
return ndarray_.elem_type;
}
constexpr const TiNdShape &shape() const {
return ndarray_.shape;
}
constexpr const TiNdShape &elem_shape() const {
return ndarray_.elem_shape;
}
constexpr const Memory &memory() const {
return memory_;
}
constexpr TiNdArray ndarray() const {
constexpr const TiNdArray &ndarray() const {
return ndarray_;
}
constexpr operator TiNdArray() const {
Expand Down Expand Up @@ -291,6 +330,9 @@ class Texture {
return *this;
}

constexpr const Image &image() const {
return image_;
}
constexpr TiTexture texture() const {
return texture_;
}
Expand Down Expand Up @@ -598,6 +640,7 @@ class Event {
};

class Runtime {
TiArch arch_{TI_ARCH_MAX_ENUM};
TiRuntime runtime_{TI_NULL_HANDLE};
bool should_destroy_{false};

Expand All @@ -617,14 +660,15 @@ class Runtime {
}
Runtime(const Runtime &) = delete;
Runtime(Runtime &&b)
: runtime_(detail::move_handle(b.runtime_)),
: arch_(std::exchange(b.arch_, TI_ARCH_MAX_ENUM)),
runtime_(detail::move_handle(b.runtime_)),
should_destroy_(std::exchange(b.should_destroy_, false)) {
}
Runtime(TiArch arch)
: runtime_(ti_create_runtime(arch)), should_destroy_(true) {
: arch_(arch), runtime_(ti_create_runtime(arch)), should_destroy_(true) {
}
Runtime(TiRuntime runtime, bool should_destroy)
: runtime_(runtime), should_destroy_(should_destroy) {
Runtime(TiArch arch, TiRuntime runtime, bool should_destroy)
: arch_(arch), runtime_(runtime), should_destroy_(should_destroy) {
}
~Runtime() {
destroy();
Expand All @@ -639,7 +683,7 @@ class Runtime {

Memory allocate_memory(const TiMemoryAllocateInfo &allocate_info) {
TiMemory memory = ti_allocate_memory(runtime_, &allocate_info);
return Memory(runtime_, memory, true);
return Memory(runtime_, memory, allocate_info.size, true);
}
Memory allocate_memory(size_t size) {
TiMemoryAllocateInfo allocate_info{};
Expand All @@ -648,8 +692,8 @@ class Runtime {
return allocate_memory(allocate_info);
}
template <typename T>
NdArray<T> allocate_ndarray(std::vector<uint32_t> shape,
std::vector<uint32_t> elem_shape,
NdArray<T> allocate_ndarray(const std::vector<uint32_t> &shape = {},
const std::vector<uint32_t> &elem_shape = {},
bool host_access = false) {
size_t size = sizeof(T);
TiNdArray ndarray{};
Expand All @@ -666,13 +710,15 @@ class Runtime {
}
ndarray.elem_shape.dim_count = elem_shape.size();
ndarray.elem_type = detail::templ2dtype<T>::value;

TiMemoryAllocateInfo allocate_info{};
allocate_info.size = size;
allocate_info.host_read = host_access;
allocate_info.host_write = host_access;
allocate_info.usage = TI_MEMORY_USAGE_STORAGE_BIT;
ndarray.memory = ti_allocate_memory(runtime_, &allocate_info);
return NdArray<T>(runtime_, std::move(ndarray), true);
Memory memory = allocate_memory(allocate_info);
ndarray.memory = memory;
return NdArray<T>(std::move(memory), ndarray);
}

Image allocate_image(const TiImageAllocateInfo &allocate_info) {
Expand Down Expand Up @@ -734,6 +780,9 @@ class Runtime {
ti_wait(runtime_);
}

constexpr TiArch arch() const {
return arch_;
}
constexpr TiRuntime runtime() const {
return runtime_;
}
Expand Down
4 changes: 4 additions & 0 deletions taichi/rhi/vulkan/vulkan_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,11 @@ VulkanCommandList::VulkanCommandList(VulkanDevice *ti_device,
: ti_device_(ti_device),
stream_(stream),
device_(ti_device->vk_device()),
#if !defined(__APPLE__)
query_pool_(vkapi::create_query_pool(ti_device->vk_device())),
#else
query_pool_(),
#endif
buffer_(buffer) {
VkCommandBufferBeginInfo info{};
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
Expand Down
2 changes: 1 addition & 1 deletion taichi/rhi/vulkan/vulkan_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ bool VulkanLoader::init(PFN_vkGetInstanceProcAddr get_proc_addr) {
// (penguinliong) So that MoltenVK instances can be imported.
if (get_proc_addr != nullptr) {
volkInitializeCustom(get_proc_addr);
initialized = check_vulkan_device();
initialized = true;
return;
}
#if defined(__APPLE__)
Expand Down

0 comments on commit b272c96

Please sign in to comment.