Skip to content

Commit

Permalink
Support MTIA in DenseTableBatchedEmbeddingBagsCodegen
Browse files Browse the repository at this point in the history
Summary: Add a `use_mtia` flag to `DenseTableBatchedEmbeddingBagsCodegen.__init__` so the module can run on an MTIA device: the flag defaults to `False` (backward compatible), is asserted to be mutually exclusive with `use_cpu`, and when set selects `torch.mtia.current_device()` instead of the CPU/CUDA device.

Differential Revision: D58137460
  • Loading branch information
gnahzg authored and facebook-github-bot committed Jun 4, 2024
1 parent 168ae4c commit 984a435
Showing 1 changed file with 12 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2441,6 +2441,7 @@ def __init__(
pooling_mode: PoolingMode = PoolingMode.SUM,
use_cpu: bool = False,
output_dtype: SparseType = SparseType.FP32,
use_mtia: bool = False,
) -> None: # noqa C901 # tuple of (rows, dims,)
super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand All @@ -2449,7 +2450,10 @@ def __init__(
self.output_dtype: int = output_dtype.as_int()
table_embedding_dtype = weights_precision.as_dtype()

self.use_cpu = use_cpu
self.use_cpu: bool = use_cpu
self.use_mtia: bool = use_mtia

assert not (use_cpu and use_mtia), "Cannot use CPU and MTIA at the same time"

if self.use_cpu or self.pooling_mode == PoolingMode.NONE:
assert output_dtype in [
Expand All @@ -2460,7 +2464,13 @@ def __init__(

# pyre-fixme[8]: Attribute has type `device`; used as `Union[int, device]`.
self.current_device: torch.device = (
torch.device("cpu") if self.use_cpu else torch.cuda.current_device()
torch.device("cpu")
if self.use_cpu
else (
torch.mtia.current_device()
if self.use_mtia
else torch.cuda.current_device()
)
)

self.embedding_specs = embedding_specs
Expand Down

0 comments on commit 984a435

Please sign in to comment.