Skip to content

Commit

Permalink
Add nobag kernel for embedding_dim <= 32
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuzilin committed Jul 8, 2022
1 parent 64a5c4a commit dfac00a
Showing 1 changed file with 159 additions and 0 deletions.
159 changes: 159 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,126 @@ constexpr size_t kForwardMaxThreads = 512;
using Tensor = at::Tensor;
using namespace fbgemm_gpu;

{% if not weighted %}
// Forward pass for the "nobag" (no pooling) unweighted case, specialized for
// small embedding dims. Each group of kThreadGroupSize lanes inside a warp
// copies one embedding row to the output, one Vec4T (4 elements) per lane,
// so the caller picks kThreadGroupSize such that kThreadGroupSize * 4 >= D
// (see the host-side dispatch, which maps D <= kEmbeddingSize to
// kEmbeddingSize / 4 lanes per row).
// Expected launch shape: blockDim = (kWarpSize, kForwardMaxThreads / kWarpSize),
// one (table, sample) pair per (blockIdx.x, threadIdx.y) slot.
template <
typename emb_t,          // storage type of the embedding tables
typename cache_t,        // storage type of the LXU cache rows
{% if not dense %}
typename output_t,       // output element type (split path only)
{% endif %}
typename index_t,        // element type of `indices` / `offsets`
size_t kThreadGroupSize  // lanes cooperating on one row; assumed to divide kWarpSize
>
__launch_bounds__(kForwardMaxThreads)
__global__ void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel(
const at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> dev_weights,
{% if not dense %}
const at::PackedTensorAccessor64<emb_t, 1, at::RestrictPtrTraits> uvm_weights,
const at::PackedTensorAccessor64<cache_t, 2, at::RestrictPtrTraits>
lxu_cache_weights,
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
weights_placements,
{% endif %}
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
int64_t D,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
{% if not dense %}
const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
lxu_cache_locations,
at::PackedTensorAccessor32<output_t, 2, at::RestrictPtrTraits>
output // [B][total_D],
{% else %}
at::PackedTensorAccessor32<at::acc_type<cache_t,true>, 2, at::RestrictPtrTraits>
output // [B][total_D],
{% endif %}
) {
// Problem geometry: T tables, B samples per table; this thread's slot
// handles the (t, b) pair derived from (blockIdx.x, threadIdx.y).
int32_t T = weights_offsets.size(0);
int32_t B = (offsets.size(0) - 1) / T;
int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
int32_t t = b_t / B;
int32_t b = b_t % B;

// Tail guard: the grid is rounded up, so trailing slots have no work.
if (b_t >= B * T) {
return;
}
int64_t weights_offset = weights_offsets[t];
// [indices_start, indices_end) is the slice of `indices` for sample (t, b).
index_t indices_start = offsets[t * B + b];
index_t indices_end = offsets[t * B + b + 1];
int32_t L = indices_end - indices_start;
const emb_t* __restrict__ weights;
{% if not dense %}
// Pick the backing storage for table t: device HBM or UVM.
const auto placement = static_cast<PlacementType>(weights_placements[t]);
if (placement == PlacementType::DEVICE) {
weights = &dev_weights[weights_offset];
} else {
weights = &uvm_weights[weights_offset];
}
{% else %}
weights = &dev_weights[weights_offset];
{% endif %}

// int8 rows carry their quantization parameters inline, so the physical
// row stride is D plus kINT8QparamsBytes.
int32_t D_emb = D;
if (std::is_same<emb_t, uint8_t>::value) {
D_emb += kINT8QparamsBytes;
}

// Partition the warp into groups of kThreadGroupSize lanes; each group
// emits one output row per inner-loop iteration, 4 elements per lane.
// NOTE(review): kNumThreadGroup is computed but never used below.
constexpr int32_t kNumThreadGroup = kWarpSize / kThreadGroupSize;
const int32_t group_start = threadIdx.x / kThreadGroupSize * kThreadGroupSize;
const int32_t group_end = group_start + kThreadGroupSize;
// First of the 4 consecutive elements this lane loads/stores.
const int32_t d = threadIdx.x % kThreadGroupSize * 4;

// Walk the L indices of this sample in warp-sized chunks; each lane stages
// one index (and cache slot), later broadcast to the warp via shfl_sync.
for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
{% if not dense %}
int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0;
{% endif %}
// Each group walks the lane ordinals in its own [group_start, group_end)
// range, skipping ordinals past the end of this sample's index list.
for (auto j = group_start; j < group_end && l_start + j < L; ++j) {
// Broadcast lane j's staged index to every lane in the warp.
int64_t idx_j = shfl_sync(idx, j);
// nobag output: one row per input index, laid out in input order.
int64_t output_j = indices_start + l_start + j;
{% if not dense %}
int32_t cache_idx_j = shfl_sync(cache_idx, j);
{% endif %}

{% if not dense %}
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
D,
nullptr);
float2 qparams_cache; // assume cache is fp16/fp32 which doesn't require qparams

{% endif %}
auto weight_row_emb = WeightRow<emb_t, cache_t, cache_t>(
const_cast<emb_t*>(&weights[idx_j * D_emb]),
nullptr,
D,
nullptr);
float2 qparams_emb;
if (std::is_same<emb_t, uint8_t>::value) {
qparams_emb = weight_row_emb.load_qparams();
}

// Guard for D not filling the whole group's 4-wide span.
if (d < D) {
{% if not dense %}
// Prefer the cached copy of the row when it is resident in the cache.
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
weight.store(&output[output_j][d]);
} else {
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
}
{% else %}
Vec4T<cache_t> weight = weight_row_emb.load(d, qparams_emb);
weight.store(&output[output_j][d]);
{% endif %}
}
}
}
}
{% endif %}

{% for nobag in [True, False] %}
{% if not nobag or not weighted %}
template <
Expand Down Expand Up @@ -439,6 +559,45 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
}
{% endfor %}
{% else %}
{% for kEmbeddingSize in [4, 8, 16, 32] %}
// Small-D fast path: for D <= 32, dispatch to the specialized nobag kernel,
// instantiated with kEmbeddingSize / 4 lanes per embedding row. The first
// matching size wins and the function returns right after the launch.
if (D <= {{ kEmbeddingSize }}) {
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_small_kernel<emb_t, cache_t, output_t, int64_t, {{ kEmbeddingSize // 4 }}><<<
{% else %}
dense_embedding_nobag_codegen_forward_unweighted_small_kernel<scalar_t, scalar_t, int64_t, {{ kEmbeddingSize // 4 }}><<<
{% endif %}
// One warp per (t, b) slot: blockDim.y warps per block, grid rounded up.
div_round_up((B * T), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
dev_weights.packed_accessor64<{{ "scalar_t" if dense else "emb_t" }}, 1, at::RestrictPtrTraits>(),
{% if not dense %}
uvm_weights.packed_accessor64<emb_t, 1, at::RestrictPtrTraits>(),
lxu_cache_weights.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
weights_placements.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
{% endif %}
weights_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
D,
indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
{% if not dense %}
lxu_cache_locations.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<
output_t,
2,
at::RestrictPtrTraits>()
);
{% else %}
output.packed_accessor32<
at::acc_type<scalar_t, true>,
2,
at::RestrictPtrTraits>()
);
{% endif %}
// NOTE(review): no launch-error check (e.g. C10_CUDA_KERNEL_LAUNCH_CHECK())
// before returning — confirm against the surrounding file's convention.
return;
}
{% endfor %}
{% if not dense %}
split_embedding_nobag_codegen_forward_unweighted_kernel<emb_t, cache_t, output_t, int64_t><<<
{% else %}
Expand Down

0 comments on commit dfac00a

Please sign in to comment.