Skip to content

Commit

Permalink
Support different placement for momentum and weights (#1787)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1787

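As the title says: add an optional `placement` argument to `construct_split_state` that, when not `None`, overrides the location used for each table, and add a `uvm_non_rowwise_momentum` flag to `SplitTableBatchedEmbeddingBagsCodegen.__init__` that places the non-rowwise `momentum1`/`momentum2` state in `EmbeddingLocation.MANAGED` (UVM). When the flag is off or the optimizer state is rowwise, `placement` stays `None` and the existing placement is used.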

Reviewed By: jiaqizhai, jianyuh

Differential Revision: D46155038

fbshipit-source-id: c344d7302ecbc62522153c29f830c4fba9b908d8
  • Loading branch information
xing-liu authored and facebook-github-bot committed May 27, 2023
1 parent 314acde commit bf1700b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def construct_split_state(
cacheable: bool,
precision: SparseType = SparseType.FP32,
int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
placement: Optional[EmbeddingLocation] = None,
) -> SplitState:
placements: List[EmbeddingLocation] = []
offsets: List[int] = []
Expand All @@ -116,6 +117,7 @@ def construct_split_state(
if precision == SparseType.INT8:
embedding_dim += int8_emb_row_dim_offset
state_size = num_embeddings * embedding_dim if not rowwise else num_embeddings
location = placement if placement is not None else location
if location == EmbeddingLocation.HOST:
placements.append(EmbeddingLocation.HOST)
offsets.append(host_size)
Expand Down Expand Up @@ -206,6 +208,7 @@ def __init__( # noqa C901
pooling_mode: PoolingMode = PoolingMode.SUM,
device: Optional[Union[str, int, torch.device]] = None,
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
uvm_non_rowwise_momentum: bool = False, # place non-rowwise momentum on UVM
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand Down Expand Up @@ -441,16 +444,19 @@ def __init__( # noqa C901
# NOTE: make TorchScript work!
self._register_nonpersistent_buffers("momentum1")
else:
rowwise = optimizer in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ROWWISE_ADAGRAD,
OptimType.EXACT_ROWWISE_WEIGHTED_ADAGRAD,
]
self._apply_split(
construct_split_state(
embedding_specs,
rowwise=optimizer
in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ROWWISE_ADAGRAD,
OptimType.EXACT_ROWWISE_WEIGHTED_ADAGRAD,
],
rowwise=rowwise,
cacheable=False,
placement=EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None,
),
prefix="momentum1",
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
Expand All @@ -464,12 +470,18 @@ def __init__( # noqa C901
OptimType.LAMB,
OptimType.PARTIAL_ROWWISE_LAMB,
):
rowwise = optimizer in (
OptimType.PARTIAL_ROWWISE_ADAM,
OptimType.PARTIAL_ROWWISE_LAMB,
)
self._apply_split(
construct_split_state(
embedding_specs,
rowwise=optimizer
in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB),
rowwise=rowwise,
cacheable=False,
placement=EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None,
),
prefix="momentum2",
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
Expand Down
5 changes: 5 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2393,6 +2393,7 @@ def execute_backward_optimizers_( # noqa C901
pooling_mode: PoolingMode,
use_cpu: bool,
weight_decay_mode: WeightDecayMode = WeightDecayMode.L2,
uvm_non_rowwise_momentum: bool = False,
) -> None:
# NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
assume(not use_cpu or T * B * L * D <= 2048)
Expand Down Expand Up @@ -2572,6 +2573,7 @@ def execute_backward_optimizers_( # noqa C901
],
optimizer=optimizer,
pooling_mode=pooling_mode,
uvm_non_rowwise_momentum=uvm_non_rowwise_momentum,
# pyre-fixme[6]: Expected `CacheAlgorithm` for 5th param but got `float`.
**optimizer_kwargs,
)
Expand Down Expand Up @@ -3003,6 +3005,7 @@ def get_wts_from_counter_adagrad(
else st.just(False)
if (gpu_available and TEST_WITH_ROCM)
else st.just(True),
uvm_non_rowwise_momentum=st.booleans(),
)
@settings(
verbosity=Verbosity.verbose,
Expand All @@ -3024,6 +3027,7 @@ def test_backward_optimizers_adam( # noqa C901
long_segments: bool,
pooling_mode: PoolingMode,
use_cpu: bool,
uvm_non_rowwise_momentum: bool,
) -> None:
self.execute_backward_optimizers_(
T,
Expand All @@ -3037,6 +3041,7 @@ def test_backward_optimizers_adam( # noqa C901
long_segments,
pooling_mode,
use_cpu,
uvm_non_rowwise_momentum=uvm_non_rowwise_momentum,
)

@given(
Expand Down

0 comments on commit bf1700b

Please sign in to comment.