diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
index 15ce4fe114..365aebbfce 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -33,6 +33,7 @@ class EmbOptimType(enum.Enum):
     SHAMPOO_V2 = "shampoo_v2"  # not currently supported for sparse embedding tables
     MADGRAD = "madgrad"
     EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad"  # deprecated
+    ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
     NONE = "none"
 
     def __str__(self) -> str:
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index b2b0fc2591..ddb8485681 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -25,7 +25,7 @@
 
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
 
-# from fbgemm_gpu.config import FeatureGateName
+from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
     AsyncSeriesTimer,
     TBEStatsReporter,
@@ -1549,6 +1549,17 @@ def forward(  # noqa: C901
                 offsets=self.row_counter_offsets,
                 placements=self.row_counter_placements,
             )
+
+        if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
+            if FeatureGate.is_enabled(FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD):
+                raise AssertionError(
+                    "ENSEMBLE_ROWWISE_ADAGRAD feature has not landed yet (see D60189486 stack)"
+                )
+            else:
+                logging.warning(
+                    "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
+                )
+
         if self._used_rowwise_adagrad_with_counter:
             if (
                 self._max_counter_update_freq > 0
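For readers skimming the patch, the net behavior of the new guard in `forward()` is: selecting the new optimizer value never proceeds silently; it either logs a warning (gate off) or raises (gate on, since the message says the backing kernel stack has not landed). A minimal, self-contained sketch of that control flow is below. The `OptimType`/`FeatureGate`/`FeatureGateName` stubs and the `check_ensemble_rowwise_adagrad` helper are illustrative stand-ins for this sketch only, not the real `fbgemm_gpu.config` or TBE APIs.

```python
# Illustrative sketch of the guard added to forward(); the classes below are
# stand-in stubs (assumptions), not the real fbgemm_gpu implementations.
import enum
import logging


class OptimType(enum.Enum):
    EXACT_ROWWISE_ADAGRAD = "exact_row_wise_adagrad"
    ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"  # new value from this diff


class FeatureGateName(enum.Enum):
    TBE_ENSEMBLE_ROWWISE_ADAGRAD = enum.auto()


class FeatureGate:
    @staticmethod
    def is_enabled(name: FeatureGateName) -> bool:
        # Stub: the gate is expected to be off until the kernel stack lands.
        return False


def check_ensemble_rowwise_adagrad(optimizer: OptimType) -> None:
    # Mirrors the new guard: raise if the gate is on (kernels not landed yet),
    # otherwise only warn and continue.
    if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
        if FeatureGate.is_enabled(FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD):
            raise AssertionError(
                "ENSEMBLE_ROWWISE_ADAGRAD feature has not landed yet (see D60189486 stack)"
            )
        else:
            logging.warning(
                "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
            )


# With the stub gate returning False, this only logs a warning and returns.
check_ensemble_rowwise_adagrad(OptimType.ENSEMBLE_ROWWISE_ADAGRAD)
```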