Work around offsets and indices type mismatch in TBE training (#3037)
Summary:
X-link: facebookresearch/FBGEMM#135

Pull Request resolved: #3037

D61524189 added a `bounds_check_indices` call to the `prefetch` function.
It expects `indices` and `offsets` to have the same type, but that is not
always the case.  This diff forces a cast on `indices` and `offsets` when
their types differ.  A long-term solution is to update
`bounds_check_indices` to support `indices` and `offsets` of different
types.
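The workaround described above can be sketched outside FBGEMM. This is a minimal NumPy analogue (the helper name `align_index_types` and the use of `astype` in place of `Tensor.long()` are illustrative, not the actual TBE code): cast both arrays to int64 whenever their dtypes disagree, so the dtypes always match before a bounds check.

```python
import numpy as np

def align_index_types(indices, offsets, force_cast=False):
    # Mirror of the diff's logic: force the cast whenever the dtypes
    # differ, or when the caller already requested it.
    force_cast = indices.dtype != offsets.dtype or force_cast
    if force_cast:
        # Cast both to int64, the NumPy analogue of Tensor.long()
        indices, offsets = indices.astype(np.int64), offsets.astype(np.int64)
    # After this point the dtypes are guaranteed to match, which is
    # what bounds_check_indices currently requires.
    assert indices.dtype == offsets.dtype
    return indices, offsets

# int32 indices paired with int64 offsets trigger the cast
idx, off = align_index_types(
    np.array([0, 1, 2], dtype=np.int32),
    np.array([0, 2, 3], dtype=np.int64),
)
print(idx.dtype, off.dtype)  # int64 int64
```

When the dtypes already match and `force_cast` is false, the inputs pass through untouched, so the common path pays no copy cost.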

Reviewed By: PaulZhang12

Differential Revision: D61801695
sryap authored and facebook-github-bot committed Aug 26, 2024
1 parent 27db8cf commit 324888f
Showing 1 changed file with 7 additions and 4 deletions.
@@ -2638,6 +2638,13 @@ def prepare_inputs(
             offsets, batch_size_per_feature_per_rank
         )

+        # TODO: remove this and add an assert after updating
+        # bounds_check_indices to support different indices type and offset
+        # type
+        force_cast_input_types = (
+            indices.dtype != offsets.dtype or force_cast_input_types
+        )
+
         if force_cast_input_types:
             # Force casting indices and offsets to long
             (indices, offsets) = indices.long(), offsets.long()
@@ -2646,10 +2653,6 @@ def prepare_inputs(
         if per_sample_weights is not None:
             per_sample_weights = per_sample_weights.float()

-        assert (
-            indices.dtype == offsets.dtype
-        ), "Indices and offsets must have the same type"
-
         if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
             torch.ops.fbgemm.bounds_check_indices(
                 self.rows_per_table,
