Add nobag kernel for embedding_dim <= 32 #1197

Closed · wants to merge 1 commit
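For context, a rough illustration (not part of the diff) of the arithmetic the new small-D kernel uses: the dispatch added below picks kThreadGroupSize = kEmbeddingSize / 4 for D <= 4, 8, 16, or 32, so a group of kThreadGroupSize lanes covers one embedding row, each lane writing a Vec4 of 4 consecutive elements. The helper name `print_lane_mapping` and the standalone `main` are illustrative assumptions, not code from this PR.

```cpp
// Illustrative only: reproduces the group_start / d arithmetic from the kernel
// (threadIdx.x replaced by a plain loop variable) to show which 4-element slice
// of a row each lane writes for a given kThreadGroupSize.
#include <cstdio>

constexpr int kWarpSize = 32;

void print_lane_mapping(int D) {
  const int kThreadGroupSize = D / 4;  // chosen by the dispatch for D in {4, 8, 16, 32}
  printf("D=%d, group size=%d, groups per warp=%d\n",
         D, kThreadGroupSize, kWarpSize / kThreadGroupSize);
  for (int lane = 0; lane < kWarpSize; ++lane) {
    const int group_start = lane / kThreadGroupSize * kThreadGroupSize;
    const int d = lane % kThreadGroupSize * 4;  // first column this lane writes
    printf("lane %2d: group starts at lane %2d, writes columns [%2d, %2d)\n",
           lane, group_start, d, d + 4);
  }
}

int main() {
  print_lane_mapping(32);  // kThreadGroupSize = 8
  return 0;
}
```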
159 changes: 159 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -26,6 +26,126 @@ constexpr size_t kForwardMaxThreads = 512;
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

{% if not weighted %}
template <
typename emb_t,
typename cache_t,
{% if not dense %}
typename output_t,
{% endif %}
typename index_t,
size_t kThreadGroupSize
>
__launch_bounds__(kForwardMaxThreads)
__global__ void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel(
const at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
{% if not dense %}
const at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
const at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits>
lxu_cache_weights,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
{% endif %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
int64_t D,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
{% if not dense %}
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations,
at::PackedTensorAccessor32<output_t, 2, at::RestrictPtrTraits>
output // [total_L][D],
{% else %}
at::PackedTensorAccessor32<at::acc_type<cache_t,true>, 2, at::RestrictPtrTraits>
output // [total_L][D],
{% endif %}
) {
int32_t T = weights_offsets.size(0);
int32_t B = (offsets.size(0) - 1) / T;
int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
int32_t t = b_t / B;
int32_t b = b_t % B;

if (b_t >= B * T) {
return;
}
int64_t weights_offset = weights_offsets[t];
index_t indices_start = offsets[t * B + b];
index_t indices_end = offsets[t * B + b + 1];
int32_t L = indices_end - indices_start;
const emb_t* __restrict__ weights;
{% if not dense %}
const auto placement = static_cast<PlacementType>(weights_placements[t]);
if (placement == PlacementType::DEVICE) {
weights = &dev_weights[weights_offset];
} else {
weights = &uvm_weights[weights_offset];
}
{% else %}
weights = &dev_weights[weights_offset];
{% endif %}

int32_t D_emb = D;
if (std::is_same<emb_t, uint8_t>::value) {
D_emb += kINT8QparamsBytes;
}

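// Lanes within the warp are partitioned into groups of kThreadGroupSize;
// a group cooperates on one embedding row at a time, each lane covering the
// 4 consecutive elements starting at column d.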
constexpr int32_t kNumThreadGroup = kWarpSize / kThreadGroupSize;
const int32_t group_start = threadIdx.x / kThreadGroupSize * kThreadGroupSize;
const int32_t group_end = group_start + kThreadGroupSize;
const int32_t d = threadIdx.x % kThreadGroupSize * 4;

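// Each outer iteration loads up to kWarpSize indices (one per lane); the inner
// loop then lets every group walk over the rows whose indices are held by its
// own lanes, fetched via shfl_sync.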
for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
{% if not dense %}
int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
{% endif %}
for (auto j = group_start; j < group_end && l_start + j < L; ++j) {
int64_t idx_j = shfl_sync(idx, j);
int64_t output_j = indices_start + l_start + j;
{% if not dense %}
int32_t cache_idx_j = shfl_sync(cache_idx, j);
{% endif %}

{% if not dense %}
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams

{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}

if (d < D) {
{% if not dense %}
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
} else {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
}
{% else %}
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
{% endif %}
}
}
}
}
{% endif %}

{% for nobag in [True, False] %}
{% if not nobag or not weighted %}
template <
@@ -439,6 +559,45 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
}
{% endfor %}
{% else %}
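// Small-D specialization: pick the smallest supported embedding-size bucket
// that fits D and launch the nobag "small" kernel with
// kThreadGroupSize = kEmbeddingSize / 4.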
{% for kEmbeddingSize in [4, 8, 16, 32] %}
if (D <= {{ kEmbeddingSize }}) {
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_small_kernel<emb_t, cache_t, output_t, int64_t, {{ kEmbeddingSize // 4 }}><<<
{% else %}
dense_embedding_nobag_codegen_forward_unweighted_small_kernel<scalar_t, scalar_t, int64_t, {{ kEmbeddingSize // 4 }}><<<
{% endif %}
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}

return;
}
{% endfor %}
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, int64_t><<<
{% else %}