diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index c161363f08..bedf38f5c8 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -54,6 +54,7 @@ class DoesNotHavePrefix(Exception):
 class ComputeDevice(enum.IntEnum):
     CPU = 0
     CUDA = 1
+    MTIA = 2
 
 
 class WeightDecayMode(enum.IntEnum):
@@ -366,7 +367,13 @@ def __init__(  # noqa C901
         assert all(
             cd == compute_devices[0] for cd in compute_devices
         ), "Heterogenous compute_devices are NOT supported!"
-        self.use_cpu: bool = all(cd == ComputeDevice.CPU for cd in compute_devices)
+        # Split TBE has different function schemas for CUDA and CPU.
+        # For MTIA device type, it uses the CPU one.
+        self.use_cpu: bool = (
+            compute_devices[0] == ComputeDevice.CPU
+            or compute_devices[0] == ComputeDevice.MTIA
+        )
+
         assert not self.use_cpu or all(
             loc == EmbeddingLocation.HOST for loc in locations
         ), "ComputeDevice.CPU is only for EmbeddingLocation.HOST!"
@@ -998,7 +1005,7 @@ def forward(  # noqa: C901
             placements=self.momentum2_placements,
         )
         # Ensure iter is always on CPU so the increment doesn't synchronize.
-        if self.iter.is_cuda:
+        if not self.iter.is_cpu:
             self.iter = self.iter.cpu()
         self.iter[0] += 1
 
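
For context, here is a minimal standalone sketch of what the patch changes. The `ComputeDevice` enum mirrors the one in the diff; the `uses_cpu_schema` helper is hypothetical (not part of FBGEMM's API) and only illustrates the new dispatch rule: MTIA reuses the CPU operator schema, so only CUDA takes the GPU path, and the step counter stays on CPU so incrementing it never forces a device synchronization.

```python
import enum
from typing import List

import torch


class ComputeDevice(enum.IntEnum):
    # Mirrors the enum in the patch; MTIA is the newly added member.
    CPU = 0
    CUDA = 1
    MTIA = 2


def uses_cpu_schema(compute_devices: List[ComputeDevice]) -> bool:
    # Hypothetical helper (not FBGEMM API): Split TBE has separate CUDA and
    # CPU function schemas, and MTIA uses the CPU one, so every compute
    # device except CUDA routes through the CPU path.
    assert all(
        cd == compute_devices[0] for cd in compute_devices
    ), "Heterogenous compute_devices are NOT supported!"
    return compute_devices[0] in (ComputeDevice.CPU, ComputeDevice.MTIA)


# The second change generalizes the iter check from `is_cuda` to
# `not is_cpu`: the counter is kept on CPU regardless of compute device,
# since incrementing a CPU tensor is a host-side operation that never
# triggers a device synchronization.
iter_counter = torch.zeros(1, dtype=torch.int64)
if not iter_counter.is_cpu:  # covers CUDA, MTIA, and any other device
    iter_counter = iter_counter.cpu()
iter_counter[0] += 1

assert uses_cpu_schema([ComputeDevice.MTIA, ComputeDevice.MTIA])
```

Using `not is_cpu` rather than enumerating non-CPU device types means the counter-relocation logic keeps working if further compute devices are added later.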