Skip to content

Commit

Permalink
Re-organize Input Combine tests (#2399)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2399

- Re-organize Input Combine tests

Reviewed By: sryap

Differential Revision: D54645104

fbshipit-source-id: 528c66e298020748416e2982dc1137526b6eb4be
  • Loading branch information
q10 authored and facebook-github-bot committed Mar 8, 2024
1 parent 9751a61 commit 8cb1649
Show file tree
Hide file tree
Showing 8 changed files with 314 additions and 142 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/input_combine_ops/input_combine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cuda(
constexpr uint32_t IS_LONG_NUM_BITS = 32;
at::cuda::OptionalCUDAGuard device_guard(device);

// combined_indices and combined_legnths are int tensors
// combined_indices and combined_lengths are int tensors
const auto int_options = at::TensorOptions().dtype(at::kInt).device(
at::kCUDA, at::cuda::current_device());
Tensor combined_indices =
Expand Down
22 changes: 10 additions & 12 deletions fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
Expand All @@ -21,6 +17,10 @@
#include <c10/util/Exception.h>
#include <torch/script.h>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {
Expand Down Expand Up @@ -384,14 +384,12 @@ padding_fused_tbe_input_combine_with_length_cpu(
auto combined_lengths = _cat_int_tensors_with_padding(
lengths_list, total_lengths, pin_memory, batch_size);

if (need_weights) {
return {
std::move(combined_indices),
std::move(combined_lengths),
_cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)};
}
return {combined_indices, combined_lengths, at::empty({0})};
auto combined_per_sample_weights = need_weights
? _cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)
: at::empty({0});

return {combined_indices, combined_lengths, combined_per_sample_weights};
}

} // namespace fbgemm_gpu
Expand Down
6 changes: 3 additions & 3 deletions fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
* LICENSE file in the root directory of this source tree.
*/

#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {
Expand Down
5 changes: 5 additions & 0 deletions fbgemm_gpu/test/combine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
125 changes: 125 additions & 0 deletions fbgemm_gpu/test/combine/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import List, Optional, Tuple

import fbgemm_gpu
import torch

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)


if not open_source:
if torch.version.hip:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip")
else:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")


class TBEInputPrepareReference(torch.nn.Module):
def __init__(self, include_last_offsets: List[bool]) -> None:
super().__init__()
self.include_last_offsets = include_last_offsets

def forward( # noqa C901
self,
indices_list: List[torch.Tensor],
offsets_list: List[torch.Tensor],
per_sample_weights_list: List[torch.Tensor],
batch_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
size = 0
assert len(indices_list) > 0
assert len(indices_list) == len(offsets_list)
assert len(indices_list) == len(per_sample_weights_list)
assert len(indices_list) == len(self.include_last_offsets)
for i in range(len(self.include_last_offsets)):
size += indices_list[i].size(0)
assert indices_list[i].dim() == 1
assert offsets_list[i].dim() == 1
if per_sample_weights_list[i].numel() > 0:
assert per_sample_weights_list[i].dim() == 1
assert indices_list[i].numel() == per_sample_weights_list[i].numel()
combined_indices = torch.empty(
size,
dtype=torch.int32,
device=indices_list[0].device,
)
torch.cat(indices_list, out=combined_indices)
offsets_starts = torch.zeros(
[len(offsets_list) + 1],
dtype=offsets_list[0].dtype,
device=offsets_list[0].device,
)
offsets_accs = torch.zeros(
[len(offsets_list) + 1],
dtype=offsets_list[0].dtype,
device=offsets_list[0].device,
)

for i, include_last_offset in enumerate(self.include_last_offsets):
if include_last_offset:
offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0) - 1
else:
offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0)
offsets_accs[i + 1] = offsets_accs[i] + indices_list[i].size(0)

assert offsets_accs[-1] == combined_indices.size(0)
combined_offsets_size: List[int] = (
[int(offsets_starts[-1].item()) + 1]
if batch_size is None
else [batch_size * len(offsets_list) + 1]
)
combined_offsets = torch.zeros(
combined_offsets_size,
dtype=torch.int32,
device=offsets_list[0].device,
)
if batch_size is None:
for i in range(len(self.include_last_offsets)):
combined_offsets[offsets_starts[i] : offsets_starts[i + 1]] = (
offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
+ offsets_accs[i]
)
else:
for i in range(len(self.include_last_offsets)):
cur_start = batch_size * i
combined_offsets[
cur_start : cur_start + offsets_starts[i + 1] - offsets_starts[i]
] = (
offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
+ offsets_accs[i]
)
cur_start = cur_start + offsets_starts[i + 1] - offsets_starts[i]
for j in range(batch_size - offsets_starts[i + 1] + offsets_starts[i]):
combined_offsets[cur_start + j] = (
indices_list[i].numel() + offsets_accs[i]
)
combined_offsets[-1] = offsets_accs[-1]
per_sample_weights: Optional[torch.Tensor] = None
for i in range(len(self.include_last_offsets)):
if per_sample_weights_list[i].size(0) > 0:
per_sample_weights = torch.ones(
combined_indices.size(0),
dtype=per_sample_weights_list[i].dtype,
device=per_sample_weights_list[i].device,
)
break
if per_sample_weights is not None:
for i in range(len(self.include_last_offsets)):
if per_sample_weights_list[i].size(0) > 0:
# fmt: off
per_sample_weights[offsets_accs[i] : offsets_accs[i + 1]] = (
per_sample_weights_list[i][:]
)
# fmt: on

# indices and offsets are required to be int32 for TBE
return combined_indices, combined_offsets, per_sample_weights
81 changes: 81 additions & 0 deletions fbgemm_gpu/test/combine/empty_weights_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from hypothesis import given, settings

from .common import open_source

if open_source:
# pyre-ignore[21]
from test_utils import cpu_and_maybe_gpu, optests
else:
from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, optests


@optests.generate_opcheck_tests()
class EmptyWeightsTest(unittest.TestCase):
@unittest.skip("Fix is not implemented yet")
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@given(device=cpu_and_maybe_gpu())
@settings(deadline=None)
def test_tbe_input_combine_with_length_empty_weights(
self, device: torch.device
) -> None:
arg0_list = [
[88, 55],
[80, 29],
[2, 85],
[39, 51],
[84, 35],
[12, 6],
[94, 43],
[98, 59],
[19, 68],
[97, 89],
]
arg0 = [torch.tensor(t, dtype=torch.int32, device=device) for t in arg0_list]

arg1_list = [
[1, 2],
[1, 2],
[1, 2],
[1, 2],
[1, 2],
[1, 2],
[1, 2],
[1, 2],
[1, 2],
[1, 2],
]
arg1 = [torch.tensor(t, dtype=torch.int32, device=device) for t in arg1_list]

arg2_list = [
[],
[],
[],
[],
[3.0, 3.0],
[],
[],
[3.0, 3.0],
[3.0, 3.0],
[],
]
arg2 = [torch.tensor(t, dtype=torch.float, device=device) for t in arg2_list]

torch.ops.fbgemm.tbe_input_combine_with_length(
arg0,
arg1,
arg2,
)


if __name__ == "__main__":
unittest.main()
85 changes: 85 additions & 0 deletions fbgemm_gpu/test/combine/failures_dict.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
{
"_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit",
"_version": 1,
"data": {
"fbgemm::padding_fused_tbe_input_combine": {
"InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::padding_fused_tbe_input_combine_with_length": {
"InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int32_with_length": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combine_int64_with_length": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_aot_dispatch_dynamic__test_padding_fused_input_combined_mix_with_length": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_faketensor__test_padding_fused_input_combine_int32_with_length": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_faketensor__test_padding_fused_input_combine_int64_with_length": {
"comment": "",
"status": "xfail"
},
"InputCombineTest.test_faketensor__test_padding_fused_input_combined_mix_with_length": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::tbe_input_combine": {},
"fbgemm::tbe_input_combine_with_length": {
"InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int32_with_length": {
"comment": "",
"status": "xsuccess"
},
"InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_int64_with_length": {
"comment": "",
"status": "xsuccess"
},
"InputCombineTest.test_aot_dispatch_dynamic__test_input_combine_mix_with_length": {
"comment": "",
"status": "xsuccess"
},
"InputCombineTest.test_faketensor__test_input_combine_int32_with_length": {
"comment": "",
"status": "xsuccess"
},
"InputCombineTest.test_faketensor__test_input_combine_int64_with_length": {
"comment": "",
"status": "xsuccess"
},
"InputCombineTest.test_faketensor__test_input_combine_mix_with_length": {
"comment": "",
"status": "xsuccess"
}
}
}
}
Loading

0 comments on commit 8cb1649

Please sign in to comment.