Refactor LXU cache logic in TBE fwd training (#1295)
Summary:
Pull Request resolved: #1295

The LXU cache logic is in the critical path of the forward TBE kernel.
Even when the LXU cache is not used, the kernel still checks at runtime
whether each row should be fetched from the cache or from HBM.  This
branching is harmless when the kernel is bound by the memory subsystem,
but it can add significant overhead when TBE is bound by conditional
branching.  (We have observed that the FP16 weight type is generally
compute or branch bound, while the FP32 weight type is memory bound.)

This diff adds a static (compile-time) conditional, a use_lxu_cache
boolean template parameter, to the forward TBE kernel to enable or
disable the LXU cache code path.  At runtime, the host selects the
kernel specialization with or without the cache path based on whether
the LXU cache is present.  A minimal illustrative sketch of this
pattern follows.
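
A minimal sketch of the compile-time specialization pattern, assuming a
simplified float layout and hypothetical names (forward_kernel,
launch_forward, cache_locations); this is not the committed FBGEMM
template, only an illustration of the technique:

    #include <cuda_runtime.h>

    constexpr int kCacheLocationMissing = -1;  // hypothetical sentinel

    // A boolean template parameter lets the compiler drop the cache
    // branch entirely when the cache is disabled.
    template <bool use_lxu_cache>
    __global__ void forward_kernel(
        const float* __restrict__ weights,        // rows in HBM/UVM
        const float* __restrict__ cache_weights,  // rows resident in the cache
        const int* __restrict__ cache_locations,  // per-index cache slot, or missing
        const int* __restrict__ indices,
        float* __restrict__ output,
        int num_indices,
        int D) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= num_indices) return;
      int row = indices[i];
      // When use_lxu_cache is false, this whole branch is compiled out.
      if (use_lxu_cache && cache_locations[i] != kCacheLocationMissing) {
        for (int d = 0; d < D; ++d)
          output[i * D + d] = cache_weights[cache_locations[i] * D + d];
      } else {
        for (int d = 0; d < D; ++d)
          output[i * D + d] = weights[row * D + d];
      }
    }

    // Host-side dispatch: pick the specialization based on whether the
    // LXU cache is actually present.
    void launch_forward(const float* weights, const float* cache_weights,
                        const int* cache_locations, const int* indices,
                        float* output, int num_indices, int D,
                        bool cache_present) {
      dim3 blocks((num_indices + 255) / 256), threads(256);
      if (cache_present) {
        forward_kernel<true><<<blocks, threads>>>(
            weights, cache_weights, cache_locations, indices, output,
            num_indices, D);
      } else {
        forward_kernel<false><<<blocks, threads>>>(
            weights, cache_weights, cache_locations, indices, output,
            num_indices, D);
      }
    }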

This diff also hoists the cache conditional outside the D loop, which
should give a small additional benefit for large D when the cache is
used, as sketched below.
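
A minimal sketch of the hoisting, again with hypothetical helper names
(accumulate_before, accumulate_after) and a simplified float layout in
place of the templated WeightRow logic:

    // Before: the cache check is evaluated inside the D loop.
    __device__ void accumulate_before(const float* emb_row, const float* cache_row,
                                      bool row_in_cache, float* acc, int D) {
      for (int d = threadIdx.x; d < D; d += blockDim.x) {
        if (row_in_cache) {  // re-evaluated on every inner iteration
          acc[d] += cache_row[d];
        } else {
          acc[d] += emb_row[d];
        }
      }
    }

    // After: the check is hoisted outside the D loop, so each inner
    // loop reads from a single source with no per-element branch.
    __device__ void accumulate_after(const float* emb_row, const float* cache_row,
                                     bool row_in_cache, float* acc, int D) {
      if (row_in_cache) {
        for (int d = threadIdx.x; d < D; d += blockDim.x) acc[d] += cache_row[d];
      } else {
        for (int d = threadIdx.x; d < D; d += blockDim.x) acc[d] += emb_row[d];
      }
    }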

Reviewed By: jspark1105

Differential Revision: D39353035

fbshipit-source-id: bfd3d842971091e954e49c6c8fad034db1fcbc9b
sryap committed Sep 10, 2022
1 parent 4e9d23f commit 9aff411
273 changes: 150 additions & 123 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -154,6 +154,7 @@ template <
typename cache_t,
{% if not dense %}
typename output_t,
bool use_lxu_cache,
{% endif %}
typename index_t
{% if not nobag %}
@@ -243,7 +244,7 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
{% if not dense %}
int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
int32_t cache_idx = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
{% endif %}
{% if weighted %}
at::acc_type<cache_t, true> idx_weight = l < L ? indice_weights[indices_start + l] : 0;
@@ -254,81 +255,81 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
int64_t output_j = indices_start + l_start + j;
{% endif %}
{% if not dense %}
int32_t cache_idx_j = shfl_sync(cache_idx, j);
int32_t cache_idx_j = use_lxu_cache ? shfl_sync(cache_idx, j) : 0;
{% endif %}

{% if weighted %}
at::acc_type<cache_t, true> idx_weight_j = shfl_sync(idx_weight, j);
{% endif %}

{% if not dense %}
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams

{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}

{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
// use_lxu_cache is a compile time condition
if (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams

{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
} else {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
}
{% else %}
for (int32_t i = 0; i < D; i+=4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
}
}
{% endif %}
}
else { // else row is not in cache
{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}
{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
{% endif %}
}
{% else %}
for (int32_t i = 0; i < D; i+=4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
} else {
}
{% else %}
for (int32_t i = 0; i < D; i+=4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
}
{% else %}
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
{% endif %}
}
}
{% endif %}
{% if not dense %}
} // else row is not in cache
{% endif %}
}
}
@@ -516,51 +517,67 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
output.scalar_type(),
{% endif %}
"batched_embedding{{ "_nobag" if nobag else "" }}_forward_kernel_2", [&] {
{% if not dense %}
// Check if LXU cache is used
bool use_lxu_cache = lxu_cache_weights.numel() > 0;
{% endif %}
{% if not nobag %}
{% for kMaxVecsPerThread in range(1, max_embedding_dim // items_per_warp + 1) %}
if (max_D <= {{ items_per_warp * kMaxVecsPerThread }}) {
{% if not dense %}
split_embedding_codegen_forward_{{ wdesc }}_kernel<emb_t, cache_t, output_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% else %}
dense_embedding_codegen_forward_{{ wdesc }}_kernel<scalar_t, scalar_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
FixedDivisor(B),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
pooling_mode,
{% if weighted %}
indice_weights.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(),
{% endif %}
{% for use_cache in ["false", "true"] %}
// The dense case does not have cache so we have to generate code for
// only one case (value of use_cache does not matter)
{% if (not dense) or (use_cache == "true") %}
{% if not dense %}
if (use_lxu_cache == {{ use_cache }}) {
{% endif %}
{% for kMaxVecsPerThread in range(1, max_embedding_dim // items_per_warp + 1) %}
if (max_D <= {{ items_per_warp * kMaxVecsPerThread }}) {
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
split_embedding_codegen_forward_{{ wdesc }}_kernel<emb_t, cache_t, output_t, {{ use_cache }}, int64_t, {{ kMaxVecsPerThread }}><<<
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
dense_embedding_codegen_forward_{{ wdesc }}_kernel<scalar_t, scalar_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
FixedDivisor(B),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
pooling_mode,
{% if weighted %}
indice_weights.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(),
{% endif %}
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
return;
}
{% endfor %}
return;
}
{% endfor %}
{% if not dense %}
} // if (use_lxu_cache == {{ use_cache }})
{% endif %}
{% endif %} // if (not dense) or (use_cache == "true")
{% endfor %} // for use_cache in ["false", "true"]
{% else %}
{% for kEmbeddingSize in [4, 8, 16, 32] %}
if (D <= {{ kEmbeddingSize }}) {
@@ -602,42 +619,52 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
return;
}
{% endfor %}
{% for use_cache in ["false", "true"] %}
// The dense case does not have cache so we have to generate code for
// only one case (value of use_cache does not matter)
{% if (not dense) or (use_cache == "true") %}
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, int64_t><<<
if (use_lxu_cache == {{ use_cache }}) {
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, {{ use_cache }}, int64_t><<<
{% else %}
dense_embedding_nobag_codegen_forward_unweighted_kernel<scalar_t, scalar_t, int64_t><<<
dense_embedding_nobag_codegen_forward_unweighted_kernel<scalar_t, scalar_t, int64_t><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
FixedDivisor(B),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
FixedDivisor(B),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
return;
return;
{% if not dense %}
} // if (use_lxu_cache == {{ use_cache }})
{% endif %}
{% endif %} // if (not dense) or (use_cache == "true")
{% endfor %} // for use_cache in ["false", "true"]
{% endif %}
});
