Skip to content

Commit

Permalink
gpu: lnorm: align BWD fused kernel selection with changed tag matching semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
skazakov1 authored and karturov committed Oct 25, 2023
1 parent a96e9b1 commit a2ec0a0
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/gpu/ocl/vectorized_lnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ bool is_fused_kernel_applicable(lnorm_conf_t &conf,

auto gpu_arch = compute_engine->device_info()->gpu_arch();
memory_desc_wrapper src_mdw(pd->src_md());
memory_desc_wrapper dst_mdw(pd->src_md());
memory_desc_wrapper stat_mdw(pd->stat_md());
auto eu_count = compute_engine->device_info()->eu_count();
auto max_eus_per_wg = device_info_t::max_eus_per_wg(gpu_arch);
Expand All @@ -53,9 +52,11 @@ bool is_fused_kernel_applicable(lnorm_conf_t &conf,
const size_t max_slm_size = device_info_t::max_slm_size(gpu_arch);

// Plain layout only
if (!(src_mdw.matches_one_of_tag(ab, abc, abcd, abcde)
&& stat_mdw.matches_one_of_tag(a, ab)))
return false;
const bool is_plain = src_mdw.matches_one_of_tag(ab, abc)
&& stat_mdw.matches_one_of_tag(a, ab)
// kernel does not support M x 1 x N layout
&& IMPLICATION(src_mdw.ndims() == 3, src_mdw.dims()[1] != 1);
if (!is_plain) return false;

const int desired_sg_size = 16; // based on PVC performance data
conf.sub_group_size = mayiuse_sg(desired_sg_size, engine)
Expand Down

0 comments on commit a2ec0a0

Please sign in to comment.