diff --git a/src/gpu/jit/binary_format.cpp b/src/gpu/jit/binary_format.cpp index 3181bd58902..4304bfe3849 100644 --- a/src/gpu/jit/binary_format.cpp +++ b/src/gpu/jit/binary_format.cpp @@ -144,40 +144,45 @@ class binary_format_kernel_t : public jit_generator { threadend(SWSB(sb2, 1), r127); } - static compute::kernel_t make_kernel(compute::compute_engine_t *engine) { + static compute::kernel_t make_kernel( + compute::compute_engine_t *engine, bool *skip_check) { compute::kernel_t kernel; + *skip_check = false; + if (hw != HW::Unknown) { binary_format_kernel_t binary_format_kernel; auto status = engine->create_kernel(&kernel, &binary_format_kernel, {}); + if (status != status::success) return nullptr; + *skip_check = binary_format_kernel.binaryIsZebin(); } else { switch (engine->device_info()->gpu_arch()) { case compute::gpu_arch_t::gen9: kernel = binary_format_kernel_t::make_kernel( - engine); + engine, skip_check); break; case compute::gpu_arch_t::gen11: kernel = binary_format_kernel_t::make_kernel( - engine); + engine, skip_check); break; case compute::gpu_arch_t::xe_lp: kernel = binary_format_kernel_t::make_kernel( - engine); + engine, skip_check); break; case compute::gpu_arch_t::xe_hp: kernel = binary_format_kernel_t::make_kernel( - engine); + engine, skip_check); break; case compute::gpu_arch_t::xe_hpg: kernel = binary_format_kernel_t::make_kernel( - engine); + engine, skip_check); break; case compute::gpu_arch_t::xe_hpc: kernel = binary_format_kernel_t::make_kernel( - engine); + engine, skip_check); break; case compute::gpu_arch_t::unknown: kernel = nullptr; break; } @@ -199,9 +204,16 @@ status_t gpu_supports_binary_format(bool *ok, engine_t *engine) { auto stream = utils::downcast(stream_generic); if (!stream) return status::invalid_arguments; - auto kernel = binary_format_kernel_t::make_kernel(gpu_engine); + bool skip_check = false; + auto kernel = binary_format_kernel_t::make_kernel( + gpu_engine, &skip_check); if (!kernel) return status::success; + if (skip_check) { + *ok = true; + return status::success; + } + // Binary kernel check. uint32_t magic0 = MAGIC0; uint64_t magic1 = MAGIC1; diff --git a/src/gpu/jit/ngen/ngen_opencl.hpp b/src/gpu/jit/ngen/ngen_opencl.hpp index 34dceea38de..319821ce9be 100644 --- a/src/gpu/jit/ngen/ngen_opencl.hpp +++ b/src/gpu/jit/ngen/ngen_opencl.hpp @@ -58,6 +58,8 @@ class OpenCLCodeGenerator : public ELFCodeGenerator inline std::vector getBinary(cl_context context, cl_device_id device, const std::string &options = "-cl-std=CL2.0"); inline cl_kernel getKernel(cl_context context, cl_device_id device, const std::string &options = "-cl-std=CL2.0"); + bool binaryIsZebin() { return isZebin; } + static inline HW detectHW(cl_context context, cl_device_id device); static inline void detectHWInfo(cl_context context, cl_device_id device, HW &outHW, Product &outProduct); @@ -65,6 +67,7 @@ class OpenCLCodeGenerator : public ELFCodeGenerator static inline void detectHWInfo(cl_context context, cl_device_id device, HW &outHW, int &outStepping); private: + bool isZebin = false; inline std::vector getPatchTokenBinary(cl_context context, cl_device_id device, const std::vector *code = nullptr, const std::string &options = "-cl-std=CL2.0"); }; @@ -162,6 +165,7 @@ std::vector OpenCLCodeGenerator::getBinary(cl_context context, cl_d for (bool defaultFormat : {true, false}) { bool legacy = defaultFormat ^ zebinFirst; + isZebin = !legacy; if (legacy) { try {