Skip to content

Commit

Permalink
Enable TBE VBE with SGD
Browse files Browse the repository at this point in the history
Summary: As titled

Differential Revision: D47575985

fbshipit-source-id: eb669978be976f20eea2b8a3fc484507cfc25adc
  • Loading branch information
sryap authored and facebook-github-bot committed Aug 16, 2023
1 parent a645ff8 commit d57bba6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_common_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def sgd() -> Dict[str, Any]:
"split_weight_update_cpu": split_weight_update_cpu,
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": False,
"has_vbe_support": True,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,9 +834,10 @@ def forward( # noqa: C901
total_unique_indices: Optional[int] = None,
) -> Tensor:
if batch_size_per_feature_per_rank is not None:
assert (
self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD only"
assert self.optimizer in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_SGD,
], "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD and OptimType.EXACT_SGD only"
assert (
self.pooling_mode != PoolingMode.NONE.value
), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
Expand Down
8 changes: 6 additions & 2 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,7 @@ def execute_backward_sgd_( # noqa C901
weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]),
weighted=st.booleans(),
mixed=st.booleans(),
mixed_B=st.booleans(),
use_cache=st.booleans(),
cache_algorithm=st.sampled_from(CacheAlgorithm),
long_segments=st.booleans(),
Expand Down Expand Up @@ -2138,6 +2139,7 @@ def test_backward_sgd( # noqa C901
weights_precision: SparseType,
weighted: bool,
mixed: bool,
mixed_B: bool,
use_cache: bool,
cache_algorithm: CacheAlgorithm,
long_segments: bool,
Expand All @@ -2153,7 +2155,7 @@ def test_backward_sgd( # noqa C901
weights_precision,
weighted,
mixed,
False, # mixed_B
mixed_B if not use_cpu else False,
use_cache,
cache_algorithm,
long_segments,
Expand All @@ -2171,6 +2173,7 @@ def test_backward_sgd( # noqa C901
L=st.integers(min_value=1, max_value=4),
weighted=st.booleans(),
mixed=st.booleans(),
mixed_B=st.booleans(),
use_cache=st.booleans(),
cache_algorithm=st.sampled_from(CacheAlgorithm),
)
Expand All @@ -2188,6 +2191,7 @@ def test_backward_sgd_really_long_segments( # noqa C901
L: int,
weighted: bool,
mixed: bool,
mixed_B: bool,
use_cache: bool,
cache_algorithm: CacheAlgorithm,
) -> None:
Expand All @@ -2200,7 +2204,7 @@ def test_backward_sgd_really_long_segments( # noqa C901
SparseType.FP32, # weights_precision
weighted,
mixed,
False, # mixed_B
mixed_B,
use_cache,
cache_algorithm,
True, # long_segments
Expand Down

0 comments on commit d57bba6

Please sign in to comment.