SymIntify torchrec variable batch size path (#2394)
Summary:

Variable Batch (VB) parameters are SymInts under Dynamo tracing.

SymInt does not support bit shifts, so the adjust_info_B_num_bits logic is skipped in the Dynamo case (when SymInts arrive in the kernel) and the values are defaulted.
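
The defaulting boils down to the following sketch (Python for illustration only, not the kernel code; the constant 26 and the mask layout come from the added C++ lines below, and is_symbolic stands in for max_B_.is_symbolic()):

```python
DEFAULT_INFO_B_NUM_BITS = 26  # default bit width for the batch index, per the diff


def info_b_bits_and_mask(is_symbolic: bool) -> tuple:
    """Keep the fixed defaults when max_B is symbolic; only concrete ints go
    through adjust_info_B_num_bits, which bit-shifts on its inputs."""
    info_B_num_bits = DEFAULT_INFO_B_NUM_BITS
    info_B_mask = (1 << info_B_num_bits) - 1  # low 26 bits set: 0x3FFFFFF
    if not is_symbolic:
        # the real code would call adjust_info_B_num_bits(max_B, T) here
        pass
    return info_B_num_bits, info_B_mask


assert info_b_bits_and_mask(True) == (26, 0x3FFFFFF)
```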

In fbcode/deeplearning/fbgemm/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py, the list comprehension with Python sum() is replaced by a tensor sum (.sum(dim=1)).
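
Standalone, the before/after of that change looks like this (the nested list is a made-up example; rows are features, columns are ranks):

```python
import torch

# hypothetical batch sizes: 3 features x 2 ranks
batch_size_per_feature_per_rank = [[2, 3], [4, 1], [5, 5]]

# before: Python-level sum inside a list comprehension
old = torch.tensor(
    [sum(batch_sizes) for batch_sizes in batch_size_per_feature_per_rank],
    device="cpu",
    dtype=torch.int32,
)

# after: build one tensor and reduce with a tensor op, as in the diff
new = torch.tensor(
    batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
).sum(dim=1)

# same values; note the tensor sum promotes the result dtype to int64
assert old.tolist() == new.tolist() == [5, 5, 10]
```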

Removing int() conversions for SymInt values.
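
Each int() removal follows the same pattern; with a toy tensor in place of the real batch-size tensor:

```python
import torch

total_batch_size_per_feature = torch.tensor([5, 5, 10])  # toy stand-in

# before: int() forces a concrete Python int, which under Dynamo means
# guarding on (specializing to) this particular batch size
max_B_old = int(total_batch_size_per_feature.max().item())

# after: .item() alone is enough in eager mode and can stay symbolic
# while tracing, so the captured graph is not tied to one batch size
max_B_new = total_batch_size_per_feature.max().item()

assert max_B_old == max_B_new == 10
```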

Adding torch._check() calls for the VB parameters.
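
The check pattern from the Python hunks, pulled into a hypothetical helper (check_vb_bounds is not in the diff) so it runs standalone:

```python
import torch

try:
    from torch._dynamo import is_compiling as is_torchdynamo_compiling
except Exception:

    def is_torchdynamo_compiling() -> bool:
        return False


def check_vb_bounds(max_B: int, offsets: torch.Tensor) -> None:
    # Mirror the gate used in the diff: only emit the hints while Dynamo is
    # tracing, never under TorchScript.
    if not torch.jit.is_scripting() and is_torchdynamo_compiling():
        torch._check_is_size(max_B)             # non-negative, usable as a size
        torch._check(max_B <= offsets.size(0))  # upper bound the compiler can rely on


offsets = torch.arange(12)
check_vb_bounds(10, offsets)  # no-op in plain eager; records bounds when traced
```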

Reviewed By: ezyang

Differential Revision: D54554735
Ivan Kobzarev authored and facebook-github-bot committed Mar 6, 2024
1 parent 50ab9bd commit 0cf36ec
Showing 4 changed files with 35 additions and 11 deletions.
8 changes: 6 additions & 2 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -202,8 +202,12 @@ class {{ autograd_func }} :
const auto uvm_cache_stats_ = uvm_cache_stats
.value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt)));

- // TODO: don't guard here
- auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__));
+ int32_t info_B_num_bits = 26;
+ uint32_t info_B_mask = (1u << info_B_num_bits) - 1;
+ if (!max_B_.is_symbolic()) {
+   // TODO: don't guard here
+   auto [info_B_num_bits, info_B_mask] = adjust_info_B_num_bits(max_B_.guard_int(__FILE__, __LINE__), T.guard_int(__FILE__, __LINE__));
+ }

{%- if vbe %}
static auto generate_vbe_metadata_op =
@@ -78,7 +78,7 @@ Tensor
{%- if vbe %}
const Tensor& vbe_row_output_offsets,
const Tensor& vbe_b_t_map,
- const int64_t vbe_output_size,
+ const c10::SymInt vbe_output_size,
const int64_t info_B_num_bits, // int32_t
const int64_t info_B_mask_int64, // uint32_t
{%- endif %}
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -747,7 +747,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
{%- if vbe %}
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
- " int vbe_output_size, "
+ " SymInt vbe_output_size, "
" int info_B_num_bits, "
" int info_B_mask_int64, "
{%- endif %}
34 changes: 27 additions & 7 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -56,6 +56,15 @@
except Exception:
pass


+ try:
+     from torch._dynamo import is_compiling as is_torchdynamo_compiling
+ except Exception:
+
+     def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
+         return False
+
+
DEFAULT_ASSOC = 32 if torch.version.hip is None else 64
INT8_EMB_ROW_DIM_OFFSET = 8

@@ -944,11 +953,14 @@ def forward( # noqa: C901

# Create B offsets
total_batch_size_per_feature = torch.tensor(
- [sum(batch_sizes) for batch_sizes in batch_size_per_feature_per_rank],
- device="cpu",
- dtype=torch.int32,
- )
- max_B = int(total_batch_size_per_feature.max().item())
+ batch_size_per_feature_per_rank, dtype=torch.int32, device="cpu"
+ ).sum(dim=1)
+
+ max_B = total_batch_size_per_feature.max().item()
+ if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+     torch._check_is_size(max_B)
+     torch._check(max_B <= offsets.size(0))
+
Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
B_offsets = Bs.cumsum(dim=0).to(torch.int)

@@ -958,7 +970,10 @@ def forward( # noqa: C901
device="cpu",
dtype=torch.int64,
)
- max_B_feature_rank = int(B_feature_rank.max().item())
+ max_B_feature_rank = B_feature_rank.max().item()
+ if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+     torch._check_is_size(max_B_feature_rank)
+     torch._check(max_B_feature_rank <= offsets.size(0))
# D->H only once
self.feature_dims = self.feature_dims.cpu()
output_sizes_feature_rank = B_feature_rank.transpose(
@@ -970,7 +985,9 @@ def forward( # noqa: C901
output_sizes_feature_rank.flatten().cumsum(dim=0),
]
)
- output_size = int(output_offsets_feature_rank[-1].item())
+ output_size = output_offsets_feature_rank[-1].item()
+ if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+     torch._check_is_size(output_size)

# TODO: Support INT8 output
# B_offsets_rank_per_feature is for rank and (b, t) mapping
@@ -1000,8 +1017,11 @@ def forward( # noqa: C901
B_offsets=B_offsets,
output_offsets_feature_rank=output_offsets_feature_rank,
B_offsets_rank_per_feature=B_offsets_rank_per_feature,
+ # pyre-ignore
max_B=max_B,
+ # pyre-ignore
max_B_feature_rank=max_B_feature_rank,
+ # pyre-ignore
output_size=output_size,
)
else:
