Skip to content

Commit

Permalink
make fill_random_weights work for MTIA tbe module (pytorch#2286)
Browse files Browse the repository at this point in the history
Summary:

att, just a couple places left over previously when adding MTIA EmbeddingLocation.

Reviewed By: jspark1105

Differential Revision: D53062844
  • Loading branch information
842974287 authored and facebook-github-bot committed Jan 26, 2024
1 parent 9b2fa10 commit 1618a56
Showing 1 changed file with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def nbit_construct_split_state(
offsets.append(host_size)
host_size += state_size
elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA:
placements.append(EmbeddingLocation.DEVICE)
placements.append(location)
offsets.append(dev_size)
dev_size += state_size
else:
Expand Down Expand Up @@ -1176,7 +1176,10 @@ def split_embedding_weights_with_scale_bias(
splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
placement = self.weights_physical_placements[t]
if placement == EmbeddingLocation.DEVICE.value:
if (
placement == EmbeddingLocation.DEVICE.value
or placement == EmbeddingLocation.MTIA.value
):
weights = self.weights_dev
elif placement == EmbeddingLocation.HOST.value:
weights = self.weights_host
Expand Down

0 comments on commit 1618a56

Please sign in to comment.