[Metal] Support fast_math, preps for saturating_grid_dim #1443
(First changed file: the Metal backend C++ sources.)
```diff
@@ -84,15 +84,17 @@ using InputBuffersMap = std::unordered_map<BufferEnum, MTLBuffer *>;
 class CompiledMtlKernelBase {
  public:
   struct Params {
-    const KernelAttributes *kernel_attribs;
     bool is_jit_evaluator;
+    const CompileConfig *config;
+    const KernelAttributes *kernel_attribs;
     MTLDevice *device;
     MTLFunction *mtl_func;
   };

   explicit CompiledMtlKernelBase(Params &params)
-      : is_jit_evalutor_(params.is_jit_evaluator),
-        kernel_attribs_(*params.kernel_attribs),
+      : kernel_attribs_(*params.kernel_attribs),
+        config_(params.config),
+        is_jit_evalutor_(params.is_jit_evaluator),
         pipeline_state_(
             new_compute_pipeline_state_with_function(params.device,
                                                      params.mtl_func)) {
```
```diff
@@ -113,7 +115,7 @@ class CompiledMtlKernelBase {

   void launch_if_not_empty(BindBuffers buffers,
                            MTLCommandBuffer *command_buffer) {
-    const int num_threads = kernel_attribs_.num_threads;
+    const int num_threads = get_total_num_threads();
     if (num_threads == 0) {
       return;
     }
```
```diff
@@ -130,40 +132,49 @@ class CompiledMtlKernelBase {
       set_mtl_buffer(encoder.get(), b.first, /*offset=*/0, bi);
     }

-    const int native_block_dim =
-        get_max_total_threads_per_threadgroup(pipeline_state_.get());
-
-    int num_threads_per_group = 0;
-    // Sometimes it is helpful to limit the maximum GPU block dim for the
-    // kernels. E.g., when you are generating iPhone shaders on a Mac.
-    const int prescribed_block_dim =
-        (std::size_t)get_current_program().config.max_block_dim;
-    if (prescribed_block_dim != 0) {
-      num_threads_per_group = std::min(native_block_dim, prescribed_block_dim);
-    } else {
-      num_threads_per_group = native_block_dim;
-    }
-
```
Review comment (on lines -136 to -145): OFT?

Reply: Not really :) These are the necessary changes for supporting …
```diff
+    const int num_threads_per_group = get_num_threads_per_group(num_threads);
     const int num_groups =
         ((num_threads + num_threads_per_group - 1) / num_threads_per_group);

-    const int dispatch_num_threads =
-        std::min(num_threads, num_threads_per_group);
-
     if (!is_jit_evalutor_) {
       ActionRecorder::get_instance().record(
           "launch_kernel",
           {ActionArg("kernel_name", kernel_attribs_.name),
            ActionArg("num_groups", num_groups),
-           ActionArg("num_threads_per_group", dispatch_num_threads)});
+           ActionArg("num_threads_per_group", num_threads_per_group)});
     }

-    dispatch_threadgroups(encoder.get(), num_groups, dispatch_num_threads);
+    dispatch_threadgroups(encoder.get(), num_groups, num_threads_per_group);
     end_encoding(encoder.get());
   }

-  const bool is_jit_evalutor_;
+  int get_total_num_threads() const {
+    int num_threads = kernel_attribs_.num_threads;
+    // TODO(k-ye): Surface |saturating_grid_dim| to ti.init() once #1396 is in.
+    // const int prescribed_grid_dim = config_->saturating_grid_dim;
+    // if (prescribed_grid_dim > 0) {
+    //   num_threads = std::min(num_threads, prescribed_grid_dim);
+    // }
+    return num_threads;
+  }
+
+  int get_num_threads_per_group(int total_num_threads) const {
+    int num_threads_per_group =
+        get_max_total_threads_per_threadgroup(pipeline_state_.get());
+    // Sometimes it is helpful to limit the maximum GPU block dim for the
+    // kernels. E.g., when you are generating iPhone shaders on a Mac.
+    const int prescribed_block_dim = config_->max_block_dim;
+    if (prescribed_block_dim > 0) {
+      num_threads_per_group =
+          std::min(num_threads_per_group, prescribed_block_dim);
+    }
+    // Cap by |total_num_threads| in case this is a very small kernel.
+    return std::min(num_threads_per_group, total_num_threads);
+  }
+
   KernelAttributes kernel_attribs_;
+  const CompileConfig *const config_;
+  const bool is_jit_evalutor_;
   nsobj_unique_ptr<MTLComputePipelineState> pipeline_state_;
 };
```
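To make the refactored launch math above easier to follow, here is a minimal sketch (plain Python, illustration only; `plan_launch` is not a Taichi or Metal API) of what `get_total_num_threads()` and `get_num_threads_per_group()` compute together with the ceiling division in `launch_if_not_empty()`:

```python
# Illustration of the launch-sizing arithmetic in the diff above.
# Assumes num_threads > 0 (the C++ code returns early when it is 0).
def plan_launch(num_threads, native_block_dim, max_block_dim=0):
    # Start from the device limit (maxTotalThreadsPerThreadgroup), optionally
    # cap it by the user-prescribed block dim, then by the kernel's own size.
    threads_per_group = native_block_dim
    if max_block_dim > 0:
        threads_per_group = min(threads_per_group, max_block_dim)
    threads_per_group = min(threads_per_group, num_threads)
    # Grid dim via ceiling division, so every logical thread is covered.
    num_groups = (num_threads + threads_per_group - 1) // threads_per_group
    return num_groups, threads_per_group

print(plan_launch(1000, 512))       # (2, 512): 1024 hardware threads cover 1000
print(plan_launch(100, 512))        # (1, 100): a tiny kernel gets one small group
print(plan_launch(1000, 512, 128))  # (8, 128): block dim capped by max_block_dim
```

Once the commented-out `saturating_grid_dim` branch in `get_total_num_threads()` is enabled, `num_threads` itself would additionally be capped by `prescribed_grid_dim`, which is presumably the "preps" the PR title refers to.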
```diff
@@ -253,15 +264,16 @@ class CompiledTaichiKernel {
     MTLDevice *device;
     MemoryPool *mem_pool;
     KernelProfilerBase *profiler;
+    const CompileConfig *compile_config;
   };

   CompiledTaichiKernel(Params params)
       : ti_kernel_attribs(*params.ti_kernel_attribs),
         ctx_attribs(*params.ctx_attribs) {
     auto *const device = params.device;
     auto kernel_lib = new_library_with_source(
-        device, params.mtl_source_code,
-        infer_msl_version(ti_kernel_attribs.used_features));
+        device, params.mtl_source_code, params.compile_config->fast_math,
+        infer_msl_version(params.ti_kernel_attribs->used_features));
     if (kernel_lib == nullptr) {
       TI_ERROR("Failed to compile Metal kernel! Generated code:\n\n{}",
                params.mtl_source_code);

@@ -286,6 +298,7 @@ class CompiledTaichiKernel {
       RuntimeListOpsMtlKernel::Params kparams;
       kparams.kernel_attribs = &ka;
       kparams.is_jit_evaluator = ti_kernel_attribs.is_jit_evaluator;
+      kparams.config = params.compile_config;
       kparams.device = device;
       kparams.mtl_func = mtl_func.get();
       kparams.mem_pool = params.mem_pool;

@@ -295,6 +308,7 @@ class CompiledTaichiKernel {
       UserMtlKernel::Params kparams;
       kparams.kernel_attribs = &ka;
       kparams.is_jit_evaluator = ti_kernel_attribs.is_jit_evaluator;
+      kparams.config = params.compile_config;
      kparams.device = device;
       kparams.mtl_func = mtl_func.get();
       kernel = std::make_unique<UserMtlKernel>(kparams);

@@ -574,6 +588,7 @@ class KernelManager::Impl {
     params.device = device_.get();
     params.mem_pool = mem_pool_;
     params.profiler = profiler_;
+    params.compile_config = config_;
     compiled_taichi_kernels_[taichi_kernel_name] =
         std::make_unique<CompiledTaichiKernel>(params);
     TI_DEBUG("Registered Taichi kernel <{}>", taichi_kernel_name);
```
(Second changed file: the Python tests.)
```diff
@@ -10,7 +10,7 @@ def _c_mod(a, b):

 @pytest.mark.parametrize('lhs_is_mat,rhs_is_mat', [(True, True), (True, False),
                                                    (False, True)])
-@ti.all_archs
+@ti.all_archs_with(fast_math=False)
 def test_binary_f(lhs_is_mat, rhs_is_mat):
     x = ti.Matrix(3, 2, ti.f32, 16)
     if lhs_is_mat:
```

Review comment: Does …

Reply: Nope, it failed at …
```diff
@@ -145,7 +145,7 @@ def func():


 @pytest.mark.parametrize('rhs_is_mat', [True, False])
-@ti.all_archs
+@ti.all_archs_with(fast_math=False)
 def test_writeback_binary_f(rhs_is_mat):
     x = ti.Matrix(3, 2, ti.f32, 9)
     y = ti.Matrix(3, 2, ti.f32, ())
```
Review comment: Cool! Not sure about how `fast_math` works. IIUC, is this the same as specifying `precision mediump float;` in OpenGL?

Reply: They are both for optimizing performance, but take different approaches. From what I can tell, `precision mediump float` reduces the number of bits used to represent a float; Metal has a similar concept, `half` (16 bits). Fast math, on the other hand, reduces the instructions in the computation to produce an approximated result.
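To make the distinction concrete, here is a small illustration (NumPy standing in for GPU code; illustration only, not from this PR). Reduced precision keeps the same operations but stores fewer bits per value, while fast math keeps the full width but lets the compiler reorder or approximate operations, e.g. by reassociating sums:

```python
import numpy as np

# "mediump" / half style: same arithmetic, fewer bits per value.
print(np.float32(1.0) / np.float32(3.0))   # 0.33333334 (24-bit significand)
print(np.float16(1.0) / np.float16(3.0))   # 0.3333     (11-bit significand)

# Fast-math style: full width, but the compiler may reassociate or substitute
# cheaper approximations. Reassociation alone can change a result:
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
print((a + b) + c)   # 1.0 -- the order written in the source
print(a + (b + c))   # 0.0 -- an order a fast-math compiler is free to choose
```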