diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 4015c93534..4e86ebacd1 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -101,7 +101,7 @@ def nbit_construct_split_state(
             offsets.append(host_size)
             host_size += state_size
         elif location == EmbeddingLocation.DEVICE or location == EmbeddingLocation.MTIA:
-            placements.append(EmbeddingLocation.DEVICE)
+            placements.append(location)
             offsets.append(dev_size)
             dev_size += state_size
         else:
@@ -1176,7 +1176,10 @@ def split_embedding_weights_with_scale_bias(
         splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = []
         for t, (_, rows, dim, weight_ty, _) in enumerate(self.embedding_specs):
            placement = self.weights_physical_placements[t]
-            if placement == EmbeddingLocation.DEVICE.value:
+            if (
+                placement == EmbeddingLocation.DEVICE.value
+                or placement == EmbeddingLocation.MTIA.value
+            ):
                 weights = self.weights_dev
             elif placement == EmbeddingLocation.HOST.value:
                 weights = self.weights_host
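
The two hunks are related: the first stops coercing MTIA tables' placement to DEVICE when constructing the split state, and the second makes split_embedding_weights_with_scale_bias treat an MTIA placement the same as DEVICE when selecting the backing weights buffer, so both are served from weights_dev. The following is a minimal, self-contained sketch of that routing logic, not the actual FBGEMM code; Location, construct_split_state, and pick_weights_buffer are illustrative stand-ins for EmbeddingLocation and the patched functions.

    # Illustrative sketch only; names below are stand-ins, not FBGEMM APIs.
    from enum import IntEnum
    from typing import List, Tuple


    class Location(IntEnum):  # stand-in for EmbeddingLocation
        DEVICE = 0
        HOST = 1
        MTIA = 2


    def construct_split_state(
        locations: List[Location], sizes: List[int]
    ) -> Tuple[List[Location], List[int]]:
        """Record each table's real placement; MTIA shares the device
        offset accounting with DEVICE but keeps its own placement tag."""
        placements: List[Location] = []
        offsets: List[int] = []
        dev_size = host_size = 0
        for location, state_size in zip(locations, sizes):
            if location == Location.HOST:
                placements.append(Location.HOST)
                offsets.append(host_size)
                host_size += state_size
            elif location in (Location.DEVICE, Location.MTIA):
                placements.append(location)  # was: always Location.DEVICE
                offsets.append(dev_size)
                dev_size += state_size
        return placements, offsets


    def pick_weights_buffer(placement: Location, weights_dev, weights_host):
        """Both DEVICE and MTIA tables read from the device buffer."""
        if placement in (Location.DEVICE, Location.MTIA):
            return weights_dev
        return weights_host


    if __name__ == "__main__":
        placements, offsets = construct_split_state(
            [Location.DEVICE, Location.MTIA, Location.HOST], [128, 256, 64]
        )
        print(placements)  # [Location.DEVICE, Location.MTIA, Location.HOST]
        print(offsets)     # [0, 128, 0]

Without the first hunk, an MTIA table would be recorded with a DEVICE placement, so any later check against MTIA (as in the second hunk) could never match; with it, MTIA tables keep their tag while still being laid out in, and read from, the device weights buffer.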