Refactor LXU cache logic in TBE fwd training
Summary:
The LXU cache logic is on the critical path of the forward TBE kernel.
Even when the LXU cache is not used, the kernel still checks at runtime
whether each row should be fetched from the cache or from HBM.  This
branching logic should be harmless in the memory-(subsystem-)bound case.
However, it can add significant overhead when TBE is bound by
conditional (branch) instructions.  (We have observed that the FP16
weight type is generally compute or branch bound, while the FP32 weight
type is memory bound.)
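
To make the cost concrete, here is a minimal, standalone sketch of the
pre-refactor pattern (not the actual TBE kernel; the kernel name,
`kCacheMiss`, and the flat row layout are all hypothetical): every row
fetch pays for a runtime cache check, even when no cache is attached.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

constexpr int32_t kCacheMiss = -1;

// One thread per index; each element load branches on the cache location,
// whether or not a cache actually exists.
__global__ void gather_rows_runtime_check(
    const float* __restrict__ weights,            // rows in HBM/UVM
    const float* __restrict__ cache_weights,      // rows resident in the cache
    const int32_t* __restrict__ cache_locations,  // per-index slot, or kCacheMiss
    const int64_t* __restrict__ indices,
    float* __restrict__ output,
    int64_t num_indices,
    int32_t D) {
  const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i >= num_indices) {
    return;
  }
  const int32_t cache_idx = cache_locations[i];
  for (int32_t d = 0; d < D; ++d) {
    // Evaluated on every iteration, even in the cache-less configuration.
    const float w = (cache_idx != kCacheMiss)
        ? cache_weights[(int64_t)cache_idx * D + d]
        : weights[indices[i] * D + d];
    output[i * D + d] = w;
  }
}
```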

This diff adds a compile-time conditional to the forward TBE kernel to
enable/disable the LXU cache code path.  At runtime, the host selects
the kernel instantiation with or without the cache enabled, based on
whether the LXU cache is present.
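
A hedged sketch of that idea, reusing the hypothetical kernel above
rather than the real split-TBE templates: the cache path is guarded by a
`bool` template parameter, so the no-cache instantiation contains no
cache branch at all, and the host picks the instantiation once per
launch.

```cuda
#include <cuda_runtime.h>
#include <cstdint>

constexpr int32_t kCacheMiss = -1;

template <bool use_lxu_cache>
__global__ void gather_rows_static_check(
    const float* __restrict__ weights,
    const float* __restrict__ cache_weights,
    const int32_t* __restrict__ cache_locations,
    const int64_t* __restrict__ indices,
    float* __restrict__ output,
    int64_t num_indices,
    int32_t D) {
  const int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i >= num_indices) {
    return;
  }
  // With use_lxu_cache == false the compiler folds this to kCacheMiss and
  // removes the cache branch from the loop body entirely.
  const int32_t cache_idx = use_lxu_cache ? cache_locations[i] : kCacheMiss;
  for (int32_t d = 0; d < D; ++d) {
    float w;
    if (use_lxu_cache && cache_idx != kCacheMiss) {
      w = cache_weights[(int64_t)cache_idx * D + d];
    } else {
      w = weights[indices[i] * D + d];
    }
    output[i * D + d] = w;
  }
}

// Host side: select the instantiation once, based on whether a cache exists.
void launch_gather_rows(
    const float* weights, const float* cache_weights,
    const int32_t* cache_locations, const int64_t* indices,
    float* output, int64_t num_indices, int32_t D,
    bool cache_present, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (int)((num_indices + threads - 1) / threads);
  if (cache_present) {
    gather_rows_static_check<true><<<blocks, threads, 0, stream>>>(
        weights, cache_weights, cache_locations, indices, output, num_indices, D);
  } else {
    gather_rows_static_check<false><<<blocks, threads, 0, stream>>>(
        weights, cache_weights, cache_locations, indices, output, num_indices, D);
  }
}
```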

This diff also moves the cache conditional outside the D loop.  This
should add a small benefit for large-D cases when the cache is used.
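
A small illustration of that hoisting (again hypothetical code, not the
template itself): the cache decision is made once per row, and each
branch then runs a D loop with no per-element cache check.

```cuda
#include <cstdint>

constexpr int32_t kCacheMiss = -1;

// Accumulate one row into `acc`; the cache check sits outside the D loop.
__device__ void accumulate_row(
    const float* __restrict__ weights,
    const float* __restrict__ cache_weights,
    int64_t row_idx,
    int32_t cache_idx,
    int32_t D,
    float* __restrict__ acc) {
  if (cache_idx != kCacheMiss) {
    const float* row = cache_weights + (int64_t)cache_idx * D;
    for (int32_t d = 0; d < D; ++d) {
      acc[d] += row[d];  // no cache branch inside the loop
    }
  } else {
    const float* row = weights + row_idx * D;
    for (int32_t d = 0; d < D; ++d) {
      acc[d] += row[d];
    }
  }
}
```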

Differential Revision: D39353035

fbshipit-source-id: 3b6b8d84e2da95be2b95aca0012b6114aadaf1ce
sryap committed Sep 10, 2022
1 parent 4d6209f commit 67f0aad
Showing 1 changed file with 147 additions and 120 deletions.
267 changes: 147 additions & 120 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -153,6 +153,7 @@ template <
typename cache_t,
{% if not dense %}
typename output_t,
bool use_lxu_cache,
{% endif %}
typename index_t
{% if not nobag %}
@@ -241,7 +242,7 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
{% if not dense %}
int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
int32_t cache_idx = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
{% endif %}
{% if weighted %}
at::acc_type<cache_t, true> idx_weight = l < L ? indice_weights[indices_start + l] : 0;
@@ -260,73 +261,73 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba
{% endif %}

{% if not dense %}
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams

{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}

{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
// use_lxu_cache is a compile time condition
if (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams

{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
} else {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
}
{% else %}
for (int32_t i = 0; i < D; i+=4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
}
}
{% endif %}
}
else { // else row is not in cache
{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}
{% if not nobag %}
#pragma unroll kMaxVecsPerThread
for (int32_t i = 0;
i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D;
++i) {
int32_t d = 4 * kWarpSize * i + threadIdx.x * 4;
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
{% if weighted %}
accumulators[i].fma_(weight, idx_weight_j);
{% else %}
accumulators[i].add_(weight);
{% endif %}
{% endif %}
}
{% else %}
for (int32_t i = 0; i < D; i+=4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
} else {
}
{% else %}
for (int32_t i = 0; i < D; i+=4 * kWarpSize) {
int32_t d = i + threadIdx.x * 4;
if (d < D) {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
}
{% else %}
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
{% endif %}
}
}
{% endif %}
{% if not dense %}
} // else row is not in cache
{% endif %}
}
}
@@ -514,50 +515,66 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
output.scalar_type(),
{% endif %}
"batched_embedding{{ "_nobag" if nobag else "" }}_forward_kernel_2", [&] {
{% if not dense %}
// Check if LXU cache is used
bool use_lxu_cache = lxu_cache_weights.numel() > 0;
{% endif %}
{% if not nobag %}
{% for kMaxVecsPerThread in range(1, max_embedding_dim // items_per_warp + 1) %}
if (max_D <= {{ items_per_warp * kMaxVecsPerThread }}) {
{% if not dense %}
split_embedding_codegen_forward_{{ wdesc }}_kernel<emb_t, cache_t, output_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% else %}
dense_embedding_codegen_forward_{{ wdesc }}_kernel<scalar_t, scalar_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
pooling_mode,
{% if weighted %}
indice_weights.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(),
{% endif %}
{% for use_cache in ["false", "true"] %}
// The dense case does not have cache so we have to generate code for
// only one case (value of use_cache does not matter)
{% if (not dense) or (use_cache == "true") %}
{% if not dense %}
if (use_lxu_cache == {{ use_cache }}) {
{% endif %}
{% for kMaxVecsPerThread in range(1, max_embedding_dim // items_per_warp + 1) %}
if (max_D <= {{ items_per_warp * kMaxVecsPerThread }}) {
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
split_embedding_codegen_forward_{{ wdesc }}_kernel<emb_t, cache_t, output_t, {{ use_cache }}, int64_t, {{ kMaxVecsPerThread }}><<<
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
dense_embedding_codegen_forward_{{ wdesc }}_kernel<scalar_t, scalar_t, int64_t, {{ kMaxVecsPerThread }}><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D_offsets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
pooling_mode,
{% if weighted %}
indice_weights.packed_accessor32<at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 1, at::RestrictPtrTraits>(),
{% endif %}
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
return;
}
{% endfor %}
return;
}
{% endfor %}
{% if not dense %}
} // if (use_lxu_cache == {{ use_cache }})
{% endif %}
{% endif %} // if (not dense) or (use_cache == "true")
{% endfor %} // for use_cache in ["false", "true"]
{% else %}
{% for kEmbeddingSize in [4, 8, 16, 32] %}
if (D <= {{ kEmbeddingSize }}) {
@@ -598,41 +615,51 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
return;
}
{% endfor %}
{% for use_cache in ["false", "true"] %}
// The dense case does not have cache so we have to generate code for
// only one case (value of use_cache does not matter)
{% if (not dense) or (use_cache == "true") %}
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, int64_t><<<
if (use_lxu_cache == {{ use_cache }}) {
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, {{ use_cache }}, int64_t><<<
{% else %}
dense_embedding_nobag_codegen_forward_unweighted_kernel<scalar_t, scalar_t, int64_t><<<
dense_embedding_nobag_codegen_forward_unweighted_kernel<scalar_t, scalar_t, int64_t><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
return;
return;
{% if not dense %}
} // if (use_lxu_cache == {{ use_cache }})
{% endif %}
{% endif %} // if (not dense) or (use_cache == "true")
{% endfor %} // for use_cache in ["false", "true"]
{% endif %}
});
