Skip to content

Commit

Permalink
Add OptimType.NONE in SplitTBE (defuse bwd and optim) (#1819)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1819

This diff is the **backend** part

This diff introduces `OptimType.NONE`.  Unlike other `OptimType`s,
`OptimType.NONE` does not perform the optimizer step during SplitTBE's
backward pass.  With `OptimType.NONE`, SplitTBE deduplicates output
gradients in the backward pass and generates a sparse gradient tensor
(PyTorch's `sparse_coo_tensor`) for the device's weight (FQN:
`weights_dev`).

Currently, `OptimType.NONE` only supports the case where the embedding
dimensions of all embedding tables are identical.

Differential Revision: D44392172

fbshipit-source-id: b1264e5a5032ebad051d5c5b739dd9ffec1d8a92
  • Loading branch information
sryap authored and facebook-github-bot committed Jun 12, 2023
1 parent a60ccaa commit 91ecd67
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 182 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ set(GPU_ONLY_OPTIMIZERS
partial_rowwise_lamb
lars_sgd
rowwise_adagrad_with_weight_decay
approx_rowwise_adagrad_with_weight_decay)
approx_rowwise_adagrad_with_weight_decay
none)

set(DEPRECATED_OPTIMIZERS
approx_sgd
Expand Down
19 changes: 19 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,6 +1641,23 @@ def backward_dense() -> None:
)


def none_optimizer() -> None:
generate(
optimizer="none",
dense=False,
args=make_args(
[
(INT, "total_hash_size"),
(INT, "total_unique_indices"),
]
),
# Generate only GPU code
has_cpu_support=False,
has_gpu_support=True,
has_vbe_support=False,
)


def gen__init__py() -> None:
template = env.get_template("__init__.template")
src_py = template.render()
Expand Down Expand Up @@ -1670,6 +1687,8 @@ def emb_codegen(
rowwise_adagrad_with_counter()
rowwise_weighted_adagrad()
sgd()
none_optimizer()

gen__init__py()


Expand Down
Loading

0 comments on commit 91ecd67

Please sign in to comment.