Skip to content

Commit

Permalink
xl v2 eager mode embedding loading (#2288)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2288

Add support for MTIA model eager-mode embedding loading.

Reviewed By: wpc

Differential Revision: D53094023

fbshipit-source-id: 98c4a8269558160d29b3556dd5fd38a5381db338
  • Loading branch information
842974287 authored and facebook-github-bot committed Jan 30, 2024
1 parent 5921557 commit 0c7fa1a
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions fbgemm_gpu/include/fbgemm_gpu/embedding_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ div_round_up(uint32_t a, uint32_t b) {
return ((a + b - 1) / b);
}

C10_HOST_DEVICE C10_ALWAYS_INLINE int32_t
unpadded_row_size_in_bytes(int32_t dim, fbgemm_gpu::SparseType weight_ty) {
C10_HOST_DEVICE C10_ALWAYS_INLINE int32_t unpadded_row_size_in_bytes(
int32_t dim,
fbgemm_gpu::SparseType weight_ty,
const int32_t scale_bias_bytes = 4) {
if (weight_ty == fbgemm_gpu::SparseType::FP32) {
return dim * 4;
}
Expand All @@ -107,23 +109,24 @@ unpadded_row_size_in_bytes(int32_t dim, fbgemm_gpu::SparseType weight_ty) {
return dim;
}
if (weight_ty == fbgemm_gpu::SparseType::INT8) {
return dim + 4;
return dim + scale_bias_bytes;
}
if (weight_ty == fbgemm_gpu::SparseType::INT4) {
return dim / 2 + 4;
return dim / 2 + scale_bias_bytes;
}
if (weight_ty == fbgemm_gpu::SparseType::INT2) {
return dim / 4 + 4;
return dim / 4 + scale_bias_bytes;
}
return 0;
}

C10_HOST_DEVICE C10_ALWAYS_INLINE int32_t padded_row_size_in_bytes(
int32_t dim,
fbgemm_gpu::SparseType weight_ty,
int32_t row_alignment) {
auto r = unpadded_row_size_in_bytes(dim, weight_ty);
return round_up(r, row_alignment);
const int32_t row_alignment,
const int32_t scale_bias_bytes = 4) {
auto r = unpadded_row_size_in_bytes(dim, weight_ty, scale_bias_bytes);
return static_cast<int32_t>(round_up(r, row_alignment));
}

} // namespace nbit

0 comments on commit 0c7fa1a

Please sign in to comment.