Skip to content

Commit

Permalink
Enable TBE VBE with SGD (#1943)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1943

As titled

Differential Revision: D47575985

fbshipit-source-id: 58c63968edcba5e43edf3e262891aa73f6ae997f
  • Loading branch information
sryap authored and facebook-github-bot committed Aug 16, 2023
1 parent a645ff8 commit 5ef984b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_common_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def sgd() -> Dict[str, Any]:
"split_weight_update_cpu": split_weight_update_cpu,
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": False,
"has_vbe_support": True,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,7 @@ def forward( # noqa: C901
if batch_size_per_feature_per_rank is not None:
assert (
self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
or self.optimizer == OptimType.EXACT_SGD
), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD only"
assert (
self.pooling_mode != PoolingMode.NONE.value
Expand Down
8 changes: 6 additions & 2 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,7 @@ def execute_backward_sgd_( # noqa C901
weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]),
weighted=st.booleans(),
mixed=st.booleans(),
mixed_B=st.booleans(),
use_cache=st.booleans(),
cache_algorithm=st.sampled_from(CacheAlgorithm),
long_segments=st.booleans(),
Expand Down Expand Up @@ -2138,6 +2139,7 @@ def test_backward_sgd( # noqa C901
weights_precision: SparseType,
weighted: bool,
mixed: bool,
mixed_B: bool,
use_cache: bool,
cache_algorithm: CacheAlgorithm,
long_segments: bool,
Expand All @@ -2153,7 +2155,7 @@ def test_backward_sgd( # noqa C901
weights_precision,
weighted,
mixed,
False, # mixed_B
mixed_B if not use_cpu else False,
use_cache,
cache_algorithm,
long_segments,
Expand All @@ -2171,6 +2173,7 @@ def test_backward_sgd( # noqa C901
L=st.integers(min_value=1, max_value=4),
weighted=st.booleans(),
mixed=st.booleans(),
mixed_B=st.booleans(),
use_cache=st.booleans(),
cache_algorithm=st.sampled_from(CacheAlgorithm),
)
Expand All @@ -2188,6 +2191,7 @@ def test_backward_sgd_really_long_segments( # noqa C901
L: int,
weighted: bool,
mixed: bool,
mixed_B: bool,
use_cache: bool,
cache_algorithm: CacheAlgorithm,
) -> None:
Expand All @@ -2200,7 +2204,7 @@ def test_backward_sgd_really_long_segments( # noqa C901
SparseType.FP32, # weights_precision
weighted,
mixed,
False, # mixed_B
mixed_B,
use_cache,
cache_algorithm,
True, # long_segments
Expand Down

0 comments on commit 5ef984b

Please sign in to comment.