Add reverse qparam option for MTIA (#2109)
Summary:
Pull Request resolved: #2109

As title

Reviewed By: sryap

Differential Revision: D50945909

Privacy Context Container: L1188860

fbshipit-source-id: 7d6c3062c050d12fb614ef16af35818a2ecbe9ef
Siyan Lin authored and facebook-github-bot committed Nov 3, 2023
1 parent df8d189 commit 5afde39
Showing 1 changed file with 57 additions and 20 deletions.
@@ -169,6 +169,7 @@ def __init__(  # noqa C901
         scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES,
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
+        reverse_qparam: bool = False,  # True to load qparams at the end of each row; False to load qparams at the beginning of each row.
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()

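For orientation, a minimal sketch (not part of this patch) of what the flag changes about a row's byte layout, assuming the default of an fp16 scale plus an fp16 bias (scale_bias_size_in_bytes = 4); the helper make_row is hypothetical:

import torch

def make_row(qweights: torch.Tensor, scale: float, bias: float, reverse_qparam: bool) -> torch.Tensor:
    # Pack the fp16 scale and fp16 bias into 4 raw bytes.
    qparams = torch.tensor([scale, bias], dtype=torch.float16).view(torch.uint8)
    # reverse_qparam=False (default): [scale | bias | quantized weights]
    # reverse_qparam=True:            [quantized weights | scale | bias]
    return torch.cat([qweights, qparams]) if reverse_qparam else torch.cat([qparams, qweights])

row = make_row(torch.zeros(8, dtype=torch.uint8), scale=0.05, bias=0.0, reverse_qparam=True)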
@@ -324,6 +325,7 @@ def max_ty_D(ty: SparseType) -> int:
         self.initialize_physical_weights_placements_and_offsets(cacheline_alignment)
         self.enforce_hbm: bool = enforce_hbm
 
+        self.reverse_qparam = reverse_qparam
         # Assign weights after weights and weights_offsets are initialized.
         if weight_lists:
             self._apply_split(
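A hedged usage sketch (not from this commit) of how a caller might opt in. The import paths and the embedding_specs tuple layout follow fbgemm_gpu's inference module as of this era and may differ across versions:

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

tbe = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (feature name, rows, embedding dim, weight type, location)
        ("t0", 1000, 128, SparseType.INT4, EmbeddingLocation.HOST),
    ],
    reverse_qparam=True,  # qparams stored at the end of each row (e.g. for MTIA)
)
tbe.fill_random_weights()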
@@ -1197,28 +1199,63 @@ def split_embedding_weights_with_scale_bias(
                 or weight_ty.value == SparseType.INT2.value
             ):
                 if split_scale_bias_mode == 1:
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[:, : self.scale_bias_size_in_bytes],
-                            None,
+                    if self.reverse_qparam:
+                        splits.append(
+                            (
+                                weights_shifts[
+                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
+                                ],
+                                weights_shifts[
+                                    :, (0 - self.scale_bias_size_in_bytes) :
+                                ],
+                                None,
+                            )
                         )
-                    )
-                else:  # 2
-                    # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
-                    splits.append(
-                        (
-                            weights_shifts[:, self.scale_bias_size_in_bytes :],
-                            weights_shifts[
-                                :, : self.scale_bias_size_in_bytes // 2
-                            ].view(torch.float16),
-                            weights_shifts[
-                                :,
-                                self.scale_bias_size_in_bytes
-                                // 2 : self.scale_bias_size_in_bytes,
-                            ].view(torch.float16),
+                    else:
+                        splits.append(
+                            (
+                                weights_shifts[:, self.scale_bias_size_in_bytes :],
+                                weights_shifts[:, : self.scale_bias_size_in_bytes],
+                                None,
+                            )
                         )
-                    )
+                elif split_scale_bias_mode == 2:
+                    if self.reverse_qparam:
+                        # weights_shifts: [0:-4] is real weights; [-4:-2] is scale; [-2:] is bias
+                        splits.append(
+                            (
+                                weights_shifts[
+                                    :, 0 : (0 - self.scale_bias_size_in_bytes)
+                                ],
+                                weights_shifts[
+                                    :,
+                                    (0 - self.scale_bias_size_in_bytes) : (
+                                        0 - self.scale_bias_size_in_bytes // 2
+                                    ),
+                                ].view(torch.float16),
+                                weights_shifts[
+                                    :, (0 - self.scale_bias_size_in_bytes // 2) :
+                                ].view(torch.float16),
+                            )
+                        )
+                    else:
+                        # weights_shifts: [0:2] is scale; [2:4] is bias; [4:] is real weights
+                        splits.append(
+                            (
+                                weights_shifts[:, self.scale_bias_size_in_bytes :],
+                                weights_shifts[
+                                    :, : self.scale_bias_size_in_bytes // 2
+                                ].view(torch.float16),
+                                weights_shifts[
+                                    :,
+                                    self.scale_bias_size_in_bytes
+                                    // 2 : self.scale_bias_size_in_bytes,
+                                ].view(torch.float16),
+                            )
+                        )
+                else:
+                    raise ValueError("split_scale_bias_mode is not supported")
 
             elif (
                 weight_ty.value == SparseType.FP8.value
                 or weight_ty.value == SparseType.FP16.value
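To make the new slicing concrete, a hypothetical worked example (not part of the patch) for split_scale_bias_mode == 2 with scale_bias_size_in_bytes = 4, where a 12-byte row holds 8 bytes of quantized weights followed by the qparams:

import torch

sbs = 4  # scale_bias_size_in_bytes
row = torch.zeros(1, 12, dtype=torch.uint8)  # one 12-byte row

weights = row[:, 0 : (0 - sbs)]                                 # bytes [0:8]   -> quantized weights
scale = row[:, (0 - sbs) : (0 - sbs // 2)].view(torch.float16)  # bytes [8:10]  -> fp16 scale
bias = row[:, (0 - sbs // 2) :].view(torch.float16)             # bytes [10:12] -> fp16 bias
assert weights.shape == (1, 8) and scale.shape == (1, 1) and bias.shape == (1, 1)

Here (0 - sbs) is simply -sbs written out explicitly, so the three slices match the [0:-4], [-4:-2], [-2:] layout described in the code comment.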
