Add reverse qparam option for MTIA
Summary: Add a reverse_qparam option that loads the quantization parameters (scale and bias) from the end of each embedding row instead of the beginning, for MTIA.

Differential Revision: D50945909

Privacy Context Container: L1188860
Siyan Lin authored and facebook-github-bot committed Nov 2, 2023
1 parent cbb3130 commit dd5a350
Showing 1 changed file with 57 additions and 20 deletions.
@@ -169,6 +169,7 @@ def __init__(  # noqa C901
         scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
+        reverse_qparam: bool = False,  # True to load qparams at the end of each row; False to load qparams at the beginning of each row.
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()

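For context: the new flag only changes where the fused quantization parameters sit inside each row's byte buffer. Below is a minimal illustrative sketch, not part of the diff, assuming the 4-byte qparam block (an fp16 scale followed by an fp16 bias) used for fused INT8/INT4/INT2 rows, with a hypothetical 8-byte weight payload:

    import torch

    scale_bias_size_in_bytes = 4  # fp16 scale (2 bytes) + fp16 bias (2 bytes)
    D_bytes = 8  # hypothetical quantized weight payload per row, in bytes

    # Reinterpret an fp16 (scale, bias) pair as 4 raw bytes.
    qparams = torch.tensor([0.5, -1.0], dtype=torch.float16).view(torch.uint8)
    weights = torch.arange(D_bytes, dtype=torch.uint8)

    row_default = torch.cat([qparams, weights])  # reverse_qparam=False: [scale|bias|weights]
    row_reverse = torch.cat([weights, qparams])  # reverse_qparam=True:  [weights|scale|bias]

    # Same qparam bytes, read from opposite ends of the row.
    assert row_default[:scale_bias_size_in_bytes].equal(
        row_reverse[-scale_bias_size_in_bytes:]
    )
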
@@ -324,6 +325,7 @@ def max_ty_D(ty: SparseType) -> int:
         self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
         self.enforce_hbm: bool = enforce_hbm

+        self.reverse_qparam = reverse_qparam
         # Assign weights after weights and weights_offsets are initialized.
         if weight_lists:
             self._apply_split(
@@ -1197,28 +1199,63 @@ def split_embedding_weights_with_scale_bias(
                 or weight_ty.value == SparseType.INT2.value
             ):
                 if split_scale_bias_mode == 1:
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[:, : self.scale_bias_size_in_bytes],
-                            None,
-                        )
-                    )
-                else:  # 2
-                    # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[
-                                :, : self.scale_bias_size_in_bytes // 2
-                            ].view(torch.float16),
-                            weights_shifts[
-                                :,
-                                self.scale_bias_size_in_bytes
-                                // 2 : self.scale_bias_size_in_bytes,
-                            ].view(torch.float16),
-                        )
-                    )
+                    if self.reverse_qparam:
+                        splits.append(
+                            (
+                                weights_shifts[
+                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
+                                ],
+                                weights_shifts[
+                                    :, (0 - self.scale_bias_size_in_bytes) :
+                                ],
+                                None,
+                            )
+                        )
+                    else:
+                        splits.append(
+                            (
+                                weights_shifts[:, self.scale_bias_size_in_bytes :],
+                                weights_shifts[:, : self.scale_bias_size_in_bytes],
+                                None,
+                            )
+                        )
+                elif split_scale_bias_mode == 2:
+                    if self.reverse_qparam:
+                        # weights_shifts: [0:-4] is real weights; [-4:-2] is scale; [-2:] is bias
+                        splits.append(
+                            (
+                                weights_shifts[
+                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
+                                ],
+                                weights_shifts[
+                                    :,
+                                    (0 - self.scale_bias_size_in_bytes) : (
+                                        0 - self.scale_bias_size_in_bytes // 2
+                                    ),
+                                ].view(torch.float16),
+                                weights_shifts[
+                                    :, (0 - self.scale_bias_size_in_bytes // 2) :
+                                ].view(torch.float16),
+                            )
+                        )
+                    else:
+                        # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
+                        splits.append(
+                            (
+                                weights_shifts[:, self.scale_bias_size_in_bytes :],
+                                weights_shifts[
+                                    :, : self.scale_bias_size_in_bytes // 2
+                                ].view(torch.float16),
+                                weights_shifts[
+                                    :,
+                                    self.scale_bias_size_in_bytes
+                                    // 2 : self.scale_bias_size_in_bytes,
+                                ].view(torch.float16),
+                            )
+                        )
+                else:
+                    raise ValueError("split_scale_bias_mode is not supported")

             elif (
                 weight_ty.value == SparseType.FP8.value
                 or weight_ty.value == SparseType.FP16.value
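
To make the negative-offset slicing concrete, here is a minimal standalone sketch of the reverse_qparam branch for split_scale_bias_mode == 2, with hypothetical row counts and sizes (the real method does this per table on its weight buffers):

    import torch

    scale_bias_size_in_bytes = 4
    rows, D_bytes = 2, 8  # hypothetical table: 2 rows, 8 weight bytes each

    # reverse_qparam layout: [weights (8 bytes) | scale (2 bytes) | bias (2 bytes)]
    weights_shifts = torch.zeros(
        rows, D_bytes + scale_bias_size_in_bytes, dtype=torch.uint8
    )

    # Everything up to the last 4 bytes is the quantized weight payload.
    weight = weights_shifts[:, 0 : (0 - scale_bias_size_in_bytes)]
    # Bytes [-4:-2] reinterpreted as the fp16 scale, bytes [-2:] as the fp16 bias.
    scale = weights_shifts[
        :, (0 - scale_bias_size_in_bytes) : (0 - scale_bias_size_in_bytes // 2)
    ].view(torch.float16)
    bias = weights_shifts[:, (0 - scale_bias_size_in_bytes // 2) :].view(torch.float16)

    assert weight.shape == (rows, D_bytes)
    assert scale.shape == (rows, 1) and bias.shape == (rows, 1)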
