diff --git a/taichi/codegen/codegen_metal.cpp b/taichi/codegen/codegen_metal.cpp index f751132e9ad09..4a31408b3b99b 100644 --- a/taichi/codegen/codegen_metal.cpp +++ b/taichi/codegen/codegen_metal.cpp @@ -17,8 +17,7 @@ constexpr char kArgsContextName[] = "args_ctx_"; class MetalKernelCodegen : public IRVisitor { public: MetalKernelCodegen(const std::string &mtl_kernel_prefix, - const std::string &root_snode_type_name, - Kernel *kernel, + const std::string &root_snode_type_name, Kernel *kernel, const StructCompiledResult *compiled_snode_structs) : mtl_kernel_prefix_(mtl_kernel_prefix), root_snode_type_name_(root_snode_type_name), @@ -33,9 +32,7 @@ class MetalKernelCodegen : public IRVisitor { return args_attribs_; } - const std::string &kernel_source_code() const { - return kernel_src_code_; - } + const std::string &kernel_source_code() const { return kernel_src_code_; } const std::vector &kernels_attribs() const { return mtl_kernels_attribs_; @@ -394,7 +391,7 @@ class MetalKernelCodegen : public IRVisitor { void generate_common_functions() { #define TI_INSIDE_METAL_CODEGEN -#include +#include "taichi/platform/metal/shaders/helpers.metal.h" kernel_src_code_ += kMetalHelpersSourceCode; #undef TI_INSIDE_METAL_CODEGEN emit("\n"); @@ -541,9 +538,7 @@ class MetalKernelCodegen : public IRVisitor { } } - void push_indent() { - indent_ += " "; - } + void push_indent() { indent_ += " "; } void pop_indent() { indent_.pop_back(); @@ -578,11 +573,9 @@ MetalCodeGen::MetalCodeGen(const std::string &kernel_name, const StructCompiledResult *struct_compiled) : id_(Program::get_kernel_id()), taichi_kernel_name_(fmt::format("mtl_k{:04d}_{}", id_, kernel_name)), - struct_compiled_(struct_compiled) { -} + struct_compiled_(struct_compiled) {} -FunctionType MetalCodeGen::compile(Program &, - Kernel &kernel, +FunctionType MetalCodeGen::compile(Program &, Kernel &kernel, MetalRuntime *runtime) { this->prog_ = &kernel.program; this->kernel_ = &kernel; diff --git a/taichi/platform/metal/shaders/atomic_stubs.h b/taichi/platform/metal/shaders/atomic_stubs.h new file mode 100644 index 0000000000000..5f580a7ab356a --- /dev/null +++ b/taichi/platform/metal/shaders/atomic_stubs.h @@ -0,0 +1,54 @@ +#pragma once + +using atomic_int = int; +using atomic_uint = unsigned int; + +namespace metal { + +using memory_order = bool; +memory_order memory_order_relaxed = false; + +} // namespace metal + +template +bool atomic_compare_exchange_weak_explicit(T *object, T *expected, T desired, + metal::memory_order) { + const T val = *object; + if (val == *expected) { + *object = desired; + return true; + } + *expected = val; + return false; +} + +template +bool atomic_fetch_or_explicit(T *object, T operand, metal::memory_order) { + const T result = *object; + *object = (result | operand); + return result; +} + +template +bool atomic_fetch_and_explicit(T *object, T operand, metal::memory_order) { + const T result = *object; + *object = (result & operand); + return result; +} + +template +T atomic_fetch_add_explicit(T *object, T operand, metal::memory_order) { + const T result = *object; + *object += operand; + return result; +} + +template +T atomic_load_explicit(T *object, metal::memory_order) { + return *object; +} + +template +void atomic_store_explicit(T *object, T desired, metal::memory_order) { + *object = desired; +} diff --git a/taichi/platform/metal/shaders/epilog.h b/taichi/platform/metal/shaders/epilog.h new file mode 100644 index 0000000000000..130389f8ed7e5 --- /dev/null +++ b/taichi/platform/metal/shaders/epilog.h @@ -0,0 +1,9 @@ +#undef device +#undef constant +#undef thread +#undef kernel + +#undef byte + +#undef STR2 +#undef STR diff --git a/taichi/platform/metal/helpers.metal.h b/taichi/platform/metal/shaders/helpers.metal.h similarity index 75% rename from taichi/platform/metal/helpers.metal.h rename to taichi/platform/metal/shaders/helpers.metal.h index 51912e113697f..ab869ba31b49e 100644 --- a/taichi/platform/metal/helpers.metal.h +++ b/taichi/platform/metal/shaders/helpers.metal.h @@ -1,38 +1,27 @@ +#include "taichi/platform/metal/shaders/prolog.h" + #ifdef TI_INSIDE_METAL_CODEGEN +#ifndef TI_METAL_NESTED_INCLUDE #define METAL_BEGIN_HELPERS_DEF constexpr auto kMetalHelpersSourceCode = #define METAL_END_HELPERS_DEF ; - -#define STR2(...) #__VA_ARGS__ -#define STR(...) STR2(__VA_ARGS__) - #else - #define METAL_BEGIN_HELPERS_DEF #define METAL_END_HELPERS_DEF -#define STR(...) __VA_ARGS__ - -#define device -#define constant -#define thread +#endif // TI_METAL_NESTED_INCLUDE -using atomic_int = int; +#else -template -bool atomic_compare_exchange_weak_explicit(Args...) { - static_assert(false, "Do not include"); -} +static_assert(false, "Do not include"); -namespace metal { -bool memory_order_relaxed = false; -} // namespace metal +#define METAL_BEGIN_HELPERS_DEF +#define METAL_END_HELPERS_DEF #endif // TI_INSIDE_METAL_CODEGEN METAL_BEGIN_HELPERS_DEF STR( - template - T union_cast(G g) { + template T union_cast(G g) { // For some reason, if I emit taichi/common.h's union_cast(), Metal failed // to compile. More strangely, if I copy the generated code to XCode as a // Metal kernel, it compiled successfully... @@ -64,5 +53,5 @@ METAL_END_HELPERS_DEF #undef METAL_BEGIN_HELPERS_DEF #undef METAL_END_HELPERS_DEF -#undef STR2 -#undef STR + +#include "taichi/platform/metal/shaders/epilog.h" diff --git a/taichi/platform/metal/shaders/prolog.h b/taichi/platform/metal/shaders/prolog.h new file mode 100644 index 0000000000000..f99799f43f73f --- /dev/null +++ b/taichi/platform/metal/shaders/prolog.h @@ -0,0 +1,29 @@ +#ifdef TI_INSIDE_METAL_CODEGEN + +#ifndef TI_METAL_NESTED_INCLUDE +#define STR2(...) #__VA_ARGS__ +#define STR(...) STR2(__VA_ARGS__) +#else +// If we are emitting to Metal source code, and the shader file is included by +// some other shader file, then skip emitting the code for the nested shader, +// otherwise there could be a symbol redefinition error. That is, we only emit +// the source code for the shader being directly included by the host side. +#define STR(...) +#endif // TI_METAL_NESTED_INCLUDE + +#else + +#include + +#define STR(...) __VA_ARGS__ + +#define device +#define constant +#define thread +#define kernel + +#define byte char + +#include "taichi/platform/metal/shaders/atomic_stubs.h" + +#endif // TI_INSIDE_METAL_CODEGEN