Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
q10 committed Feb 12, 2024
1 parent c366391 commit 2513ba4
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 56 deletions.
42 changes: 0 additions & 42 deletions .github/scripts/fbgemm_gpu_test.bash
Original file line number Diff line number Diff line change
Expand Up @@ -143,35 +143,6 @@ run_fbgemm_gpu_tests () {
__configure_fbgemm_gpu_test_cuda
fi

# # Enable ROCM testing if specified
# if [ "$fbgemm_variant" == "rocm" ]; then
# echo "[TEST] Set environment variables for ROCm testing ..."
# # shellcheck disable=SC2086
# print_exec conda env config vars set ${env_prefix} FBGEMM_TEST_WITH_ROCM=1
# # shellcheck disable=SC2086
# print_exec conda env config vars set ${env_prefix} HIP_LAUNCH_BLOCKING=1
# fi

# # These are either non-tests or currently-broken tests in both FBGEMM_GPU and FBGEMM_GPU-CPU
# local files_to_skip=(
# ./ssd_split_table_batched_embeddings_test.py
# )

# if [ "$fbgemm_variant" == "cpu" ]; then
# # These tests have non-CPU operators referenced in @given
# local ignored_tests=(
# ./uvm/copy_test.py
# ./uvm/uvm_test.py
# )
# elif [ "$fbgemm_variant" == "rocm" ]; then
# local ignored_tests=(
# # https://github.com/pytorch/FBGEMM/issues/1559
# ./batched_unary_embeddings_test.py
# )
# else
# local ignored_tests=()
# fi

echo "[TEST] Installing pytest ..."
# shellcheck disable=SC2086
(exec_with_retries 3 conda install ${env_prefix} -y pytest expecttest) || return 1
Expand Down Expand Up @@ -204,19 +175,6 @@ run_fbgemm_gpu_tests () {
done
}

# for test_file in $all_test_files; do
# if echo "${files_to_skip[@]}" | grep "${test_file}"; then
# echo "[TEST] Skipping test file known to be broken: ${test_file}"
# elif echo "${ignored_tests[@]}" | grep "${test_file}"; then
# echo "[TEST] Skipping test file: ${test_file}"
# elif run_python_test "${env_name}" "${test_file}"; then
# echo ""
# else
# return 1
# fi
# done
# }


################################################################################
# FBGEMM_GPU Test Bulk-Combination Functions
Expand Down
6 changes: 0 additions & 6 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
gpu_unavailable,
gradcheck,
optests,
skipIfRocm,
symint_vector_unsupported,
use_cpu_strategy,
)
Expand All @@ -46,7 +45,6 @@
gpu_unavailable,
gradcheck,
optests,
skipIfRocm,
symint_vector_unsupported,
use_cpu_strategy,
)
Expand Down Expand Up @@ -1630,7 +1628,6 @@ def test_jagged_dense_dense_elementwise_add_jagged_output_dynamic_shape(

assert output.size() == output_ref.size()

@skipIfRocm()
@settings(
verbosity=Verbosity.verbose,
max_examples=20,
Expand Down Expand Up @@ -2370,7 +2367,6 @@ def test_jagged_softmax(

torch.testing.assert_close(values.grad, values_ref.grad)

@skipIfRocm()
@given(
B=st.integers(10, 512),
M=st.integers(1, 32),
Expand Down Expand Up @@ -2669,7 +2665,6 @@ def test_jagged_slice_errors(
)

@unittest.skipIf(*gpu_unavailable)
@skipIfRocm()
@given(
B=st.integers(min_value=100, max_value=200),
F=st.integers(min_value=50, max_value=100),
Expand Down Expand Up @@ -2774,7 +2769,6 @@ def test_jagged_unique_indices(
self.assertTrue((output_start <= pos) and (pos < output_end))

@unittest.skipIf(*gpu_unavailable)
@skipIfRocm()
@given(
B=st.integers(min_value=100, max_value=200),
F=st.integers(min_value=50, max_value=100),
Expand Down
2 changes: 0 additions & 2 deletions fbgemm_gpu/test/permute_pooled_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@
gpu_unavailable,
on_arm_platform,
optests,
skipIfRocm,
)
else:
from fbgemm_gpu.test.test_utils import (
cpu_and_maybe_gpu,
gpu_unavailable,
on_arm_platform,
optests,
skipIfRocm,
)

typed_gpu_unavailable: Tuple[bool, str] = gpu_unavailable
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/test/sparse/index_select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_available, skipIfRocm
from test_utils import gpu_available
else:
import fbgemm_gpu.sparse_ops # noqa: F401, E402
from fbgemm_gpu.test.test_utils import gpu_available, skipIfRocm
from fbgemm_gpu.test.test_utils import gpu_available


class IndexSelectTest(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/test/sparse/pack_segments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_available, skipIfRocm
from test_utils import gpu_available
else:
from fbgemm_gpu.test.test_utils import gpu_available, skipIfRocm
from fbgemm_gpu.test.test_utils import gpu_available


def get_n_rand_num_summing_to_k(n: int, k: int) -> np.ndarray:
Expand Down
44 changes: 42 additions & 2 deletions fbgemm_gpu/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def use_cpu_strategy() -> st.SearchStrategy[bool]:
def skipIfRocm(reason: str = "Test currently doesn't work on the ROCm stack") -> Any:
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def skipIfRocmDecorator(fn: Callable) -> Any:
def decorator(fn: Callable) -> Any:
@wraps(fn)
# pyre-fixme[3]: Return annotation cannot be `Any`.
def wrapper(*args: Any, **kwargs: Any) -> Any:
Expand All @@ -196,7 +196,47 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:

return wrapper

return skipIfRocmDecorator
return decorator


# pyre-fixme[3]: Return annotation cannot be `Any`.
def skipIfRocmLessThan(min_version: int) -> Any:
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def decorator(testfn: Callable) -> Any:
@wraps(testfn)
# pyre-fixme[3]: Return annotation cannot be `Any`.
def wrapper(*args: Any, **kwargs: Any) -> Any:
ROCM_VERSION_FILEPATH = "/opt/rocm/.info/version-dev"
if TEST_WITH_ROCM:
# Fail if ROCm version file is missing.
if not os.path.isfile(ROCM_VERSION_FILEPATH):
raise AssertionError(
f"ROCm version file {ROCM_VERSION_FILEPATH} is missing!"
)

# Parse the version number from the file.
with open(ROCM_VERSION_FILEPATH, "r") as file:
version = file.read().strip()
version = version.replace("-", "").split(".")
version = (
int(version[0]) * 10000 + int(version[1]) * 100 + int(version[2])
)

# Fail if ROCm version is less than the minimum version.
if version < min_version:
raise unittest.SkipTest(
f"Skip the test since the ROCm version is less than {min_version}"
)
else:
testfn(*args, **kwargs)

else:
testfn(*args, **kwargs)

return wrapper

return decorator


def symint_vector_unsupported() -> Tuple[bool, str]:
Expand Down

0 comments on commit 2513ba4

Please sign in to comment.