Support MTIA in DenseTableBatchedEmbeddingBagsCodegen (pytorch#2070)
Summary:
X-link: pytorch/FBGEMM#2680
As titled

Differential Revision: D58137460
gnahzg authored and facebook-github-bot committed Jun 5, 2024
1 parent f699979 commit 0c106eb
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -675,16 +675,20 @@ def __init__(
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
         output_dtype = fused_params.get("output_dtype", SparseType.FP32)
+        use_cpu: bool = (
+            device is None
+            or device.type == "cpu"
+            or (not (torch.cuda.is_available() or torch.mtia.is_available()))
+        )
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=PoolingMode.NONE,
-                use_cpu=device is None
-                or device.type == "cpu"
-                or not torch.cuda.is_available(),
+                use_cpu=use_cpu,
                 weights_precision=weights_precision,
                 output_dtype=output_dtype,
+                use_mtia=device is not None and device.type == "mtia",
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
@@ -975,16 +979,20 @@ def __init__(
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
         output_dtype = fused_params.get("output_dtype", SparseType.FP32)
+        use_cpu: bool = (
+            device is None
+            or device.type == "cpu"
+            or (not (torch.cuda.is_available() or torch.mtia.is_available()))
+        )
         self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
             DenseTableBatchedEmbeddingBagsCodegen(
                 list(zip(self._local_rows, self._local_cols)),
                 feature_table_map=self._feature_table_map,
                 pooling_mode=self._pooling,
-                use_cpu=device is None
-                or device.type == "cpu"
-                or not torch.cuda.is_available(),
+                use_cpu=use_cpu,
                 weights_precision=weights_precision,
                 output_dtype=output_dtype,
+                use_mtia=device is not None and device.type == "mtia",
             )
         )
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
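
Both hunks make the same device-selection change in the two `__init__` methods: `use_cpu` now falls back to CPU only when neither CUDA nor MTIA is available, and the new `use_mtia` flag is threaded through to FBGEMM. A minimal sketch of that logic, pulled out into a standalone helper for illustration (the helper name `select_execution_flags` is hypothetical, and `torch.mtia.is_available()` assumes a PyTorch build that ships the `torch.mtia` module):

```python
from typing import Optional, Tuple

import torch


def select_execution_flags(device: Optional[torch.device]) -> Tuple[bool, bool]:
    """Compute (use_cpu, use_mtia) the way the patched __init__ methods do."""
    # Fall back to CPU when no device is given, when the device is CPU,
    # or when neither a CUDA nor an MTIA backend is available.
    use_cpu = (
        device is None
        or device.type == "cpu"
        or (not (torch.cuda.is_available() or torch.mtia.is_available()))
    )
    # MTIA is used only when an MTIA device was explicitly requested.
    use_mtia = device is not None and device.type == "mtia"
    return use_cpu, use_mtia


print(select_execution_flags(torch.device("cpu")))  # (True, False)
```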

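For context, a hypothetical standalone construction showing where the two flags land. This is not part of the commit: the import paths assume a recent fbgemm_gpu layout, the table spec values are made up, and the `use_mtia` keyword exists only in builds that include the X-linked pytorch/FBGEMM#2680.

```python
import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    DenseTableBatchedEmbeddingBagsCodegen,
)

device = torch.device("cpu")  # or torch.device("mtia") on an MTIA host
use_cpu, use_mtia = select_execution_flags(device)  # helper sketched above

emb_module = DenseTableBatchedEmbeddingBagsCodegen(
    [(1000, 64)],  # one table: 1000 rows, embedding dim 64
    pooling_mode=PoolingMode.SUM,
    use_cpu=use_cpu,
    weights_precision=SparseType.FP32,
    output_dtype=SparseType.FP32,
    use_mtia=use_mtia,  # new kwarg threaded through by this change
)
```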