Skip to content

Commit

Permalink
Make iter persistent and add kernel check (#2897)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2897

[fbgemm_gpu][GWD] Make `iter` persistent and add kernel check to fail when `prev_iter` >= `iter`.

Reviewed By: csmiler

Differential Revision: D58566070

fbshipit-source-id: 2ebe618b1e970bb155907b3a94c120459c04c785
  • Loading branch information
spcyppt authored and facebook-github-bot committed Jul 26, 2024
1 parent 3dad495 commit 42d1320
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/genscript/jinja_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str:
if is_global_weight_decay_kernel:
return """
const auto prev_iter = prev_iter_dev[linear_index];
CUDA_KERNEL_ASSERT(prev_iter < iter);
const auto global_weight_decay = prev_iter == 0 ? 1 : max(gwd_lower_bound, powf(weight_decay_base, iter - prev_iter - 1));
if (threadIdx.x == 0) {
prev_iter_dev[linear_index] = iter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ using namespace fbgemm_gpu;
{%- if is_gwd_kernel %}
// if l > L or prev_iter == 0, global_weight_decay = 1
const auto prev_it = prev_iter[idx];
CUDA_KERNEL_ASSERT(prev_it < iter);
const auto global_weight_decay = (l > L || prev_it == 0) ? 1 : max(gwd_lower_bound, powf(weight_decay_base, iter - prev_it - 1));
{%- endif %}

Expand Down

0 comments on commit 42d1320

Please sign in to comment.