Skip to content

Commit

Permalink
Support different placement for momentum and weights (pytorch#1787)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1787

As title

Differential Revision: D46155038

fbshipit-source-id: 58a29f28a405196af72e232e187ed5e8ee75dbff
  • Loading branch information
xing-liu authored and facebook-github-bot committed May 26, 2023
1 parent f9883cd commit 1077429
Showing 1 changed file with 20 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def construct_split_state(
cacheable: bool,
precision: SparseType = SparseType.FP32,
int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET,
enforce_placement: Optional[EmbeddingLocation] = None,
) -> SplitState:
placements: List[EmbeddingLocation] = []
offsets: List[int] = []
Expand All @@ -116,6 +117,7 @@ def construct_split_state(
if precision == SparseType.INT8:
embedding_dim += int8_emb_row_dim_offset
state_size = num_embeddings * embedding_dim if not rowwise else num_embeddings
location = enforce_placement if enforce_placement is not None else location
if location == EmbeddingLocation.HOST:
placements.append(EmbeddingLocation.HOST)
offsets.append(host_size)
Expand Down Expand Up @@ -206,6 +208,7 @@ def __init__( # noqa C901
pooling_mode: PoolingMode = PoolingMode.SUM,
device: Optional[Union[str, int, torch.device]] = None,
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
uvm_non_rowwise_momentum: bool = False, # place non-rowwise momentum on UVM
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand Down Expand Up @@ -441,16 +444,19 @@ def __init__( # noqa C901
# NOTE: make TorchScript work!
self._register_nonpersistent_buffers("momentum1")
else:
rowwise = optimizer in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ROWWISE_ADAGRAD,
OptimType.EXACT_ROWWISE_WEIGHTED_ADAGRAD,
]
self._apply_split(
construct_split_state(
embedding_specs,
rowwise=optimizer
in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ROWWISE_ADAGRAD,
OptimType.EXACT_ROWWISE_WEIGHTED_ADAGRAD,
],
rowwise=rowwise,
cacheable=False,
enforce_placement=EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None,
),
prefix="momentum1",
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
Expand All @@ -464,12 +470,18 @@ def __init__( # noqa C901
OptimType.LAMB,
OptimType.PARTIAL_ROWWISE_LAMB,
):
rowwise = optimizer in (
OptimType.PARTIAL_ROWWISE_ADAM,
OptimType.PARTIAL_ROWWISE_LAMB,
)
self._apply_split(
construct_split_state(
embedding_specs,
rowwise=optimizer
in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB),
rowwise=rowwise,
cacheable=False,
enforce_placement=EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None,
),
prefix="momentum2",
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
Expand Down

0 comments on commit 1077429

Please sign in to comment.