Refactor generate_vbe_metadata (#3087)
Summary:
X-link: facebookresearch/FBGEMM#178

Pull Request resolved: #3087

Moves `generate_vbe_metadata` into
`fbgemm_gpu.split_table_batched_embeddings_ops_training_common`. This is
preparation for enabling VBE in SSD-TBE.
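
For call sites, only the import path changes; a minimal sketch mirroring the import added in the diff below:

```python
# New import path after this refactor; call signatures are unchanged.
from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
    generate_vbe_metadata,
    is_torchdynamo_compiling,
)
```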

Reviewed By: q10

Differential Revision: D62215222

fbshipit-source-id: 4db9d0c097f3f9b7aaf25c49ef777cf33a5c718d
sryap authored and facebook-github-bot committed Sep 9, 2024
1 parent 826128b commit adb9d0f
Showing 2 changed files with 171 additions and 136 deletions.
162 changes: 26 additions & 136 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -45,6 +45,10 @@
RecordCacheMetrics,
SplitState,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
generate_vbe_metadata,
is_torchdynamo_compiling,
)

try:
if torch.version.hip:
@@ -62,23 +66,6 @@
pass


try:
try:
from torch.compiler import is_compiling

def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
# at least one test fails if we import is_compiling as a different name
return is_compiling()

except Exception:
# torch.compiler.is_compiling is not available in torch 1.10
from torch._dynamo import is_compiling as is_torchdynamo_compiling
except Exception:

def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
return False


DEFAULT_ASSOC = 32 if torch.version.hip is None else 64
INT8_EMB_ROW_DIM_OFFSET = 8

@@ -334,125 +321,6 @@ def apply_split_helper(
)


def generate_vbe_metadata(
offsets: Tensor,
batch_size_per_feature_per_rank: Optional[List[List[int]]],
optimizer: OptimType,
pooling_mode: PoolingMode,
feature_dims_cpu: Tensor,
device: torch.device,
) -> invokers.lookup_args.VBEMetadata:
"""
Generate VBE metadata based on batch_size_per_feature_per_rank.
Metadata includes:
1) B_offsets - A tensor that contains batch size offsets for each
feature
2) output_offsets_feature_rank - A tensor that contains output
offsets for each feature
    3) B_offsets_rank_per_feature - A tensor that contains batch size
                                    offsets for each rank, for each
                                    feature
4) max_B - The maximum batch size for all features
5) max_B_feature_rank - The maximum batch size for all ranks and
features
6) output_size - The output size (number of elements)
"""
if batch_size_per_feature_per_rank is not None:
assert optimizer in (
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_SGD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.NONE,
), "Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD and ENSEMBLE_ROWWISE_ADAGRAD only"
assert (
pooling_mode != PoolingMode.NONE
), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
# TODO: Add input check
zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)

# Create B offsets
total_batch_size_per_feature = torch.tensor(
batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
).sum(dim=1)

max_B = total_batch_size_per_feature.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B)
torch._check(max_B < offsets.numel())

Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
B_offsets = Bs.cumsum(dim=0).to(torch.int)

# Create output offsets
B_feature_rank = torch.tensor(
batch_size_per_feature_per_rank,
device="cpu",
dtype=torch.int64,
)
max_B_feature_rank = B_feature_rank.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B_feature_rank)
torch._check(max_B_feature_rank <= offsets.size(0))
output_sizes_feature_rank = B_feature_rank.transpose(
0, 1
) * feature_dims_cpu.view(1, -1)
output_offsets_feature_rank = torch.concat(
[
zero_tensor.to(torch.int64),
output_sizes_feature_rank.flatten().cumsum(dim=0),
]
)
output_size = output_offsets_feature_rank[-1].item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(output_size)

# TODO: Support INT8 output
# B_offsets_rank_per_feature is for rank and (b, t) mapping
B_offsets_rank_per_feature = (
torch.tensor(
[
[0] + batch_size_per_feature
for batch_size_per_feature in batch_size_per_feature_per_rank
],
device="cpu",
dtype=torch.int32,
)
.cumsum(dim=1)
.to(torch.int)
)

B_offsets = B_offsets.to(device, non_blocking=True)
output_offsets_feature_rank = output_offsets_feature_rank.to(
device, non_blocking=True
)
B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
device, non_blocking=True
)

# TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=B_offsets,
output_offsets_feature_rank=output_offsets_feature_rank,
B_offsets_rank_per_feature=B_offsets_rank_per_feature,
# pyre-ignore
max_B=max_B,
# pyre-ignore
max_B_feature_rank=max_B_feature_rank,
# pyre-ignore
output_size=output_size,
)
else:
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=None,
output_offsets_feature_rank=None,
B_offsets_rank_per_feature=None,
max_B=-1,
max_B_feature_rank=-1,
output_size=-1,
)
return vbe_metadata


# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
@@ -1379,6 +1247,17 @@ def _generate_vbe_metadata(
) -> invokers.lookup_args.VBEMetadata:
# Blocking D2H copy, but only runs at first call
self.feature_dims = self.feature_dims.cpu()
if batch_size_per_feature_per_rank is not None:
assert self.optimizer in (
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_SGD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.NONE,
), (
"Variable batch size TBE support is enabled for "
"OptimType.EXACT_ROWWISE_ADAGRAD and "
"ENSEMBLE_ROWWISE_ADAGRAD only"
)
return generate_vbe_metadata(
offsets,
batch_size_per_feature_per_rank,
@@ -3043,6 +2922,17 @@ def _generate_vbe_metadata(
) -> invokers.lookup_args.VBEMetadata:
# Blocking D2H copy, but only runs at first call
self.feature_dims = self.feature_dims.cpu()
if batch_size_per_feature_per_rank is not None:
assert self.optimizer in (
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_SGD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.NONE,
), (
"Variable batch size TBE support is enabled for "
"OptimType.EXACT_ROWWISE_ADAGRAD and "
"ENSEMBLE_ROWWISE_ADAGRAD only"
)
return generate_vbe_metadata(
offsets,
batch_size_per_feature_per_rank,
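
Both `_generate_vbe_metadata` wrappers above reduce to the same pattern: keep the per-class optimizer check, then delegate to the shared helper. A condensed, hypothetical sketch of that pattern (the trailing arguments are inferred from the helper's signature in the new module below, not copied from the truncated hunks):

```python
# Hypothetical condensation of the two wrapper hunks above; OptimType and
# generate_vbe_metadata are imported as in the module.
def _generate_vbe_metadata(self, offsets, batch_size_per_feature_per_rank):
    # Blocking D2H copy, but only runs at first call
    self.feature_dims = self.feature_dims.cpu()
    if batch_size_per_feature_per_rank is not None:
        # Per-class optimizer validation stays in the wrapper
        assert self.optimizer in (
            OptimType.EXACT_ROWWISE_ADAGRAD,
            OptimType.EXACT_SGD,
            OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
            OptimType.NONE,
        ), "Variable batch size TBE is not supported for this optimizer"
    # Shared metadata construction now lives in the common module
    return generate_vbe_metadata(
        offsets,
        batch_size_per_feature_per_rank,
        self.optimizer,
        self.pooling_mode,
        self.feature_dims,
        offsets.device,
    )
```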
145 changes: 145 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py
@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode
from torch import Tensor

try:
try:
from torch.compiler import is_compiling

def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
# at least one test fails if we import is_compiling as a different name
return is_compiling()

except Exception:
# torch.compiler.is_compiling is not available in torch 1.10
from torch._dynamo import is_compiling as is_torchdynamo_compiling
except Exception:

def is_torchdynamo_compiling() -> bool: # type: ignore[misc]
return False


def generate_vbe_metadata(
offsets: Tensor,
batch_size_per_feature_per_rank: Optional[List[List[int]]],
optimizer: OptimType,
pooling_mode: PoolingMode,
feature_dims_cpu: Tensor,
device: torch.device,
) -> invokers.lookup_args.VBEMetadata:
"""
Generate VBE metadata based on batch_size_per_feature_per_rank.
Metadata includes:
1) B_offsets - A tensor that contains batch size offsets for each
feature
2) output_offsets_feature_rank - A tensor that contains output
offsets for each feature
    3) B_offsets_rank_per_feature - A tensor that contains batch size
                                    offsets for each rank, for each
                                    feature
4) max_B - The maximum batch size for all features
5) max_B_feature_rank - The maximum batch size for all ranks and
features
6) output_size - The output size (number of elements)
"""
if batch_size_per_feature_per_rank is not None:
assert (
pooling_mode != PoolingMode.NONE
), "Variable batch size TBE support is not enabled for PoolingMode.NONE"
# TODO: Add input check
zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)

# Create B offsets
total_batch_size_per_feature = torch.tensor(
batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
).sum(dim=1)

max_B = total_batch_size_per_feature.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B)
torch._check(max_B < offsets.numel())

Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
B_offsets = Bs.cumsum(dim=0).to(torch.int)

# Create output offsets
B_feature_rank = torch.tensor(
batch_size_per_feature_per_rank,
device="cpu",
dtype=torch.int64,
)
max_B_feature_rank = B_feature_rank.max().item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(max_B_feature_rank)
torch._check(max_B_feature_rank <= offsets.size(0))
output_sizes_feature_rank = B_feature_rank.transpose(
0, 1
) * feature_dims_cpu.view(1, -1)
output_offsets_feature_rank = torch.concat(
[
zero_tensor.to(torch.int64),
output_sizes_feature_rank.flatten().cumsum(dim=0),
]
)
output_size = output_offsets_feature_rank[-1].item()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
torch._check_is_size(output_size)

# TODO: Support INT8 output
# B_offsets_rank_per_feature is for rank and (b, t) mapping
B_offsets_rank_per_feature = (
torch.tensor(
[
[0] + batch_size_per_feature
for batch_size_per_feature in batch_size_per_feature_per_rank
],
device="cpu",
dtype=torch.int32,
)
.cumsum(dim=1)
.to(torch.int)
)

B_offsets = B_offsets.to(device, non_blocking=True)
output_offsets_feature_rank = output_offsets_feature_rank.to(
device, non_blocking=True
)
B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
device, non_blocking=True
)

# TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=B_offsets,
output_offsets_feature_rank=output_offsets_feature_rank,
B_offsets_rank_per_feature=B_offsets_rank_per_feature,
# pyre-ignore
max_B=max_B,
# pyre-ignore
max_B_feature_rank=max_B_feature_rank,
# pyre-ignore
output_size=output_size,
)
else:
vbe_metadata = invokers.lookup_args.VBEMetadata(
B_offsets=None,
output_offsets_feature_rank=None,
B_offsets_rank_per_feature=None,
max_B=-1,
max_B_feature_rank=-1,
output_size=-1,
)
return vbe_metadata
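
To make the docstring above concrete, here is a minimal, standalone sketch (not part of the commit; all inputs are hypothetical) reproducing the `B_offsets` and `output_offsets_feature_rank` computations for a two-feature, two-rank input:

```python
import torch

# Hypothetical input: 2 features x 2 ranks of per-rank batch sizes.
batch_size_per_feature_per_rank = [[2, 3], [1, 4]]
feature_dims_cpu = torch.tensor([8, 16])  # embedding dim per feature

# B_offsets: cumulative total batch size per feature.
total_B = torch.tensor(
    batch_size_per_feature_per_rank, dtype=torch.int32
).sum(dim=1)
B_offsets = (
    torch.concat([torch.zeros(1, dtype=torch.int32), total_B])
    .cumsum(dim=0)
    .to(torch.int)
)
print(B_offsets)  # tensor([ 0,  5, 10], dtype=torch.int32)

# output_offsets_feature_rank: per-(rank, feature) output sizes
# (batch size * embedding dim), flattened and cumulative-summed.
B_feature_rank = torch.tensor(batch_size_per_feature_per_rank, dtype=torch.int64)
output_sizes = B_feature_rank.transpose(0, 1) * feature_dims_cpu.view(1, -1)
output_offsets_feature_rank = torch.concat(
    [torch.zeros(1, dtype=torch.int64), output_sizes.flatten().cumsum(dim=0)]
)
print(output_offsets_feature_rank)  # tensor([  0,  16,  32,  56, 120])
print(output_offsets_feature_rank[-1].item())  # output_size = 120
```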
