Commit

deprecate the old split_embedding_weights impl (#1817)

Summary:
Pull Request resolved: #1817

Clean up: remove the old standalone implementation of split_embedding_weights and reimplement it as a thin wrapper around split_embedding_weights_with_scale_bias, keeping the existing signature for backward compatibility.

Reviewed By: sryap

Differential Revision: D46585801

fbshipit-source-id: b40dcf9f10ed35e3e8955d1eff8a4628942003b5
jianyuh authored and facebook-github-bot committed Jun 12, 2023
1 parent b9bcddd commit 46c764b
Showing 1 changed file with 22 additions and 72 deletions.
@@ -1137,78 +1137,6 @@ def reset_cache_states(self) -> None:
        self.lxu_state.fill_(0)
        self.timestep_counter.reset()

    @torch.jit.export
    def split_embedding_weights(
        self,
        split_scale_shifts: bool = True
        # When true, return list of two tensors, the first with weights and
        # the second with scale_bias.
        # This should've been named as split_scale_bias.
        # Keep as is for backward compatibility.
    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
        """
        Returns a list of weights, split by table
        """
        assert self.weight_initialized
        splits: List[Tuple[Tensor, Optional[Tensor]]] = []
        for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
            placement = self.weights_physical_placements[t]
            if placement == EmbeddingLocation.DEVICE.value:
                weights = self.weights_dev
            elif placement == EmbeddingLocation.HOST.value:
                weights = self.weights_host
            else:
                weights = self.weights_uvm
            offset = self.weights_physical_offsets[t]
            weights_shifts = weights.detach()[
                offset : offset
                + rows
                * rounded_row_size_in_bytes(
                    dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
                )
            ].view(
                rows,
                rounded_row_size_in_bytes(
                    dim, weight_ty, self.row_alignment, self.scale_bias_size_in_bytes
                ),
            )

            if split_scale_shifts:
                # remove the padding at the end of each row.
                weights_shifts = weights_shifts[
                    :,
                    : unpadded_row_size_in_bytes(
                        dim, weight_ty, self.scale_bias_size_in_bytes
                    ),
                ]
                if (
                    weight_ty == SparseType.INT8
                    or weight_ty == SparseType.INT4
                    or weight_ty == SparseType.INT2
                ):
                    splits.append(
                        (
                            weights_shifts[:, self.scale_bias_size_in_bytes :],
                            weights_shifts[:, : self.scale_bias_size_in_bytes],
                        )
                    )
                else:
                    assert (
                        weight_ty == SparseType.FP8
                        or weight_ty == SparseType.FP16
                        or weight_ty == SparseType.FP32
                    )
                    splits.append(
                        (
                            weights_shifts,
                            None,
                        )
                    )
            else:
                splits.append((weights_shifts, None))

        return splits

    @torch.jit.export
    def split_embedding_weights_with_scale_bias(
        self, split_scale_bias_mode: int = 1
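For reference, the removed loop carves each table's rows out of the flat byte buffer and, for INT8/INT4/INT2 tables, splits the leading scale_bias_size_in_bytes bytes of every row from the quantized weight bytes. Below is a minimal standalone sketch of that per-row split; the helper name and the example sizes are illustrative only, not fbgemm_gpu API.

# Minimal standalone sketch (not fbgemm_gpu API) of the per-row byte layout the
# removed loop relied on: each quantized row stores its scale/bias bytes first,
# then the quantized weight bytes, padded out to the aligned (rounded) row size.
import torch

def split_row_bytes(
    padded_row: torch.Tensor,  # uint8 view of one padded row
    unpadded_bytes: int,       # logical row size without alignment padding
    scale_bias_bytes: int,     # bytes reserved for scale/bias at the row start
):
    row = padded_row[:unpadded_bytes]    # drop the padding at the end of the row
    scale_bias = row[:scale_bias_bytes]  # leading scale/bias bytes
    weights = row[scale_bias_bytes:]     # remaining quantized weight bytes
    return weights, scale_bias

# Example: a 12-byte logical row (4 scale/bias bytes + 8 weight bytes), padded to 16.
padded = torch.arange(16, dtype=torch.uint8)
w, sb = split_row_bytes(padded, unpadded_bytes=12, scale_bias_bytes=4)
assert w.numel() == 8 and sb.numel() == 4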
@@ -1300,6 +1228,28 @@ def split_embedding_weights_with_scale_bias(

        return splits

    @torch.jit.export
    def split_embedding_weights(
        self,
        split_scale_shifts: bool = True
        # When true, return list of two tensors, the first with weights and
        # the second with scale_bias.
        # This should've been named as split_scale_bias.
        # Keep as is for backward compatibility.
    ) -> List[Tuple[Tensor, Optional[Tensor]]]:
        """
        Returns a list of weights, split by table
        """
        splits: List[
            Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
        ] = self.split_embedding_weights_with_scale_bias(
            split_scale_bias_mode=(1 if split_scale_shifts else 0)
        )
        return [
            (split_weight_scale_bias[0], split_weight_scale_bias[1])
            for split_weight_scale_bias in splits
        ]

    @torch.jit.export
    def initialize_weights(self) -> None:
        if not self.weight_initialized:
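After this change, split_embedding_weights keeps its signature but forwards to split_embedding_weights_with_scale_bias and drops the third element of each returned tuple. A hedged usage sketch follows; the variable `tbe` and its construction are assumptions (an already-initialized nbit inference TBE module exposing the methods in this diff), not part of the commit.

# Usage sketch: `tbe` is assumed to be an initialized nbit inference TBE module
# exposing the methods shown in this diff; construction details are omitted.
pairs = tbe.split_embedding_weights(split_scale_shifts=True)
for t, (weights, scale_bias) in enumerate(pairs):
    # For INT8/INT4/INT2 tables, scale_bias holds the per-row scale/bias bytes;
    # for FP8/FP16/FP32 tables it is None.
    print(t, weights.shape, None if scale_bias is None else scale_bias.shape)

# The wrapper maps split_scale_shifts=True to split_scale_bias_mode=1 and
# False to mode=0, so the same data is reachable through the newer API:
triples = tbe.split_embedding_weights_with_scale_bias(split_scale_bias_mode=1)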
