Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ROCm changes #1102

Closed
wants to merge 53 commits into from
Closed
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
59267a9
Hipification of fbgemm for AMD GPUs/CPUs (#4)
jithunnair-amd Jan 25, 2022
a223936
Use SHEFL_SYNC_MACRO to replace __shefl() and __shefl_sync()
liligwu Jan 26, 2022
4610075
Merge pull request #6 from ROCmSoftwarePlatform/rocm4.3/develop
liligwu Jan 26, 2022
a506c52
Change the hipify dependency to hipify_torch (#7)
liligwu Jan 31, 2022
f596bde
IFU, merge from upstream commit c6df576 to main. (#8)
liligwu Feb 14, 2022
0cfb792
Enable `split_table_batched_embeddings_test.py` (#10)
liligwu Mar 2, 2022
f13af44
*Enable use_cache. *Enable split_embedding_inference_converter_test.p…
liligwu Mar 7, 2022
25e5b71
Skip use_cpu.
liligwu Mar 7, 2022
dcbe19f
Enable test_nbit_cache_pipeline and test_cache_miss_counter.
liligwu Mar 7, 2022
fda048e
Enable quantize_ops_test.py
liligwu Mar 7, 2022
00abba1
Merge branch 'main' into use_cache_enabled
liligwu Mar 7, 2022
cf307b6
Remove @skipIfRocm for test_nbit_cache_pipeline and test_cache_miss_c…
liligwu Mar 7, 2022
2d66ea8
*Uncondition use_cache in split_table_batched_embeddings_test.py *Rem…
liligwu Mar 7, 2022
958679b
Merge pull request #11 from ROCmSoftwarePlatform/use_cache_enabled
amathews-amd Mar 8, 2022
e642a48
Fix backward tests and test_cache_pipeline in split_table_batched_emb…
liligwu Mar 8, 2022
d0d294a
A minor change of removing a commented line.
liligwu Mar 8, 2022
146f2df
Remove skipIfRocm import in split_table_batched_embeddings_test.py.
liligwu Mar 8, 2022
eb0cf36
Merge pull request #12 from ROCmSoftwarePlatform/fix_backward
amathews-amd Mar 8, 2022
0c86f2b
*Removed post_hipify logic in setup.py. *Removed two headerfiles that…
liligwu Mar 11, 2022
6e7f13e
Merge pull request #16 from ROCmSoftwarePlatform/remove_post_hipify
amathews-amd Mar 11, 2022
edd3306
Pointing hipify_torch to the newer commit.
liligwu Mar 14, 2022
9a45f4a
Merge pull request #17 from ROCmSoftwarePlatform/pointing_hipify_torc…
amathews-amd Mar 14, 2022
309a3a1
Fixing #include <ATen/CUDAGeneratorImpl.h> by defining NEW_GENERATOR_…
liligwu Mar 16, 2022
358eaf5
Disabling all use_cpu in the tests. (#20)
liligwu Mar 16, 2022
3a915a8
Change py3.8 syntax to py3.7 syntax (#18)
pruthvistony Mar 16, 2022
40928ba
Match upstream setup (#21)
liligwu Mar 31, 2022
69abf78
Enable merge_pooled_embeddings op. in ROCm (#15)
reza-amd Apr 1, 2022
5c0096e
Merge remote-tracking branch 'upstream/main' into IFU-main-2022-04-07
liligwu Apr 14, 2022
bfac874
Fixing test_lxu_cache_lookup in AMD devices where warp_siize=64
liligwu Apr 14, 2022
1cf7e84
* Enabling the specificationn of hip architecture by using PYTORCH_RO…
liligwu Apr 15, 2022
5b33287
*Fixing the unit tests in sparse_ops_test.py. *Fixing the path of Ato…
liligwu Apr 19, 2022
2c514c5
Merge pull request #23 from ROCmSoftwarePlatform/IFU-main-2022-04-07
pruthvistony Apr 19, 2022
0d5a012
Enable use_cpu in the tests.
liligwu Apr 20, 2022
ae14a47
Merge remote-tracking branch 'upstream/main' into IFU-main-2022-04-20
liligwu Apr 20, 2022
1718605
*Taking @skipIfRocm back in the test_utils.py. *Fixing cublasGemmStri…
liligwu Apr 20, 2022
bc902a3
Cleaning up the code.
liligwu Apr 20, 2022
0d95948
Merge pull request #24 from ROCmSoftwarePlatform/IFU-main-2022-04-20
pruthvistony Apr 21, 2022
9a5a33b
Enabling cuda (#25)
liligwu Apr 21, 2022
6490dbc
Enabling cuda (#25)
liligwu Apr 21, 2022
77627ae
Merge branch 'main' of https://github.com/ROCmSoftwarePlatform/FBGEMM…
liligwu Apr 22, 2022
18b48e9
Merge remote-tracking branch 'upstream/main' into IFU-main-2022-05-02
liligwu May 2, 2022
99a70e1
Merge pull request #2 from ROCmSoftwarePlatform/IFU-main-2022-05-02
liligwu May 4, 2022
fed56ff
Merge branch 'main' into rocm_changes
liligwu May 4, 2022
4b39a70
Merge branch 'upstream_main' into rocm_changes
liligwu May 5, 2022
785afb8
Removing building and testing bash scripts.
liligwu May 5, 2022
bbd0ad1
* Addressing the comments in PR review ROCm changes #1102. * Reoganiz…
liligwu May 9, 2022
9db83d8
Minor changes that minimize the difference to upstream.
liligwu May 9, 2022
eabd0a8
A minor change on a blank line.
liligwu May 9, 2022
2038008
Fixing indentation and commented code in CMakeList.txt
liligwu May 10, 2022
0202078
Removing build script.
liligwu May 10, 2022
9cf8856
Addressing the second batch of comments of https://github.com/pytorch…
liligwu May 11, 2022
b885322
* Removing the condition on c++ standard * An indentation correction
liligwu May 12, 2022
0e3dfdb
* Changing the logic of detecting GPU vender, making CUDA as default.…
liligwu May 13, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest
[submodule "third_party/hipify_torch"]
liligwu marked this conversation as resolved.
Show resolved Hide resolved
path = third_party/hipify_torch
url = https://github.com/ROCmSoftwarePlatform/hipify_torch.git
96 changes: 83 additions & 13 deletions fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,27 @@ message("${message_line}")
if(SKBUILD)
message("The project is built using scikit-build")
endif()

if(EXISTS "/usr/bin/nvidia-smi")
message("NVIDIA GPU detected.")
option(USE_CUDA "Use CUDA" ON)
option(USE_ROCM "Use ROCm" OFF)
elseif(EXISTS "/opt/rocm/bin/rocm-smi")
message("AMD GPU detected.")
option(USE_CUDA "Use CUDA" OFF)
option(USE_ROCM "Use ROCm" ON)
else()
message("Unable to detect GPU vendor")
jianyuh marked this conversation as resolved.
Show resolved Hide resolved
message(FATAL_ERROR "")
endif()

if(FBGEMM_CPU_ONLY)
message("Building for CPU-only")
endif()

message("${message_line}")

if(FBGEMM_CPU_ONLY)
if(FBGEMM_CPU_ONLY OR USE_ROCM)
project(
fbgemm_gpu
VERSION 0.0.1
Expand Down Expand Up @@ -55,6 +69,16 @@ set(TORCH_CUDA_OPTIONS
-D__CUDA_NO_BFLOAT16_CONVERSIONS__
-D__CUDA_NO_HALF2_OPERATORS__)


if(USE_ROCM)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" "${THIRDPARTY}/hipify_torch/cmake")
include(Hip)
include(Hipify)

message("${message_line}")
message(STATUS "hip found ${ROCM_FOUND}")
endif()

#
# GENERATED CUDA, CPP and Python code
#
Expand Down Expand Up @@ -147,14 +171,30 @@ set(codegen_dependencies
${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h
)

add_custom_command(
OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files}
${gen_gpu_host_source_files} ${gen_python_files}
COMMAND
if(USE_ROCM)
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py"
"--opensource"
DEPENDS "${codegen_dependencies}")
DEPENDS "${codegen_dependencies}")

set(header_include_dir
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/src
${CMAKE_CURRENT_SOURCE_DIR})

hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR ${header_include_dir})
else()
add_custom_command(
OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files}
${gen_gpu_host_source_files} ${gen_python_files}
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py"
"--opensource"
DEPENDS "${codegen_dependencies}")
endif()

set_source_files_properties(
${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS
Expand Down Expand Up @@ -216,8 +256,12 @@ set_source_files_properties(
PROPERTIES COMPILE_OPTIONS
"-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl")

set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}
${cpp_fbgemm_files_avx512})
if(USE_ROCM)
set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2})
else()
set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}
${cpp_fbgemm_files_avx512})
endif()

set(cpp_fbgemm_files_include_directories
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include
Expand Down Expand Up @@ -312,15 +356,39 @@ else()
set(fbgemm_gpu_sources ${fbgemm_gpu_sources_cpu})
endif()

if(USE_ROCM)
set(abspath_gen_source_files)
foreach(filename_gen_source_file ${gen_source_files})
list(APPEND abspath_gen_source_files "${CMAKE_BINARY_DIR}/${filename_gen_source_file}")
endforeach()
endif()

#
# MODULE
#

add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files}
${cpp_asmjit_files} ${cpp_fbgemm_files})
if(USE_ROCM)
get_hipified_list("${fbgemm_gpu_sources}" fbgemm_gpu_sources)
get_hipified_list("${abspath_gen_source_files}" abspath_gen_source_files)
get_hipified_list("${cpp_fbgemm_files}" cpp_fbgemm_files)

if(NOT FBGEMM_CPU_ONLY)
target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE)
set(FBGEMM_ALL_HIP_FILES ${fbgemm_gpu_sources} ${abspath_gen_source_files} ${cpp_fbgemm_files})
set_source_files_properties(${FBGEMM_ALL_HIP_FILES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
hip_include_directories("${cpp_fbgemm_files_include_directories}")

hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES}
HIPCC_OPTIONS ${HIP_HCC_FLAGS})
target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE})
list(GET TORCH_INCLUDE_DIRS 0 TORCH_PATH)
else()
add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files}
${cpp_asmjit_files} ${cpp_fbgemm_files})
set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES
"${cuda_architectures}")

if(NOT FBGEMM_CPU_ONLY)
target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE)
endif()
endif()

set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "")
Expand All @@ -330,7 +398,9 @@ if(NVML_LIB_PATH)
target_link_libraries(fbgemm_gpu_py ${NVML_LIB_PATH})
endif()
target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS})
set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17)
if(USE_CUDA)
set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17)
endif()
liligwu marked this conversation as resolved.
Show resolved Hide resolved

install(TARGETS fbgemm_gpu_py DESTINATION fbgemm_gpu)

Expand Down
12 changes: 8 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,10 @@ def uvm(
offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
per_sample_weights = None
if weighted:
assert (this_rs_uvm_weights := rs_uvm[2]) is not None
assert (this_rs_gpu_weights := rs_gpu[2]) is not None
this_rs_uvm_weights = rs_uvm[2]
assert this_rs_uvm_weights is not None
this_rs_gpu_weights = rs_gpu[2]
assert this_rs_gpu_weights is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why this change is needed for ROCm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why this change is needed for ROCm?

ROCm is on Python 3.7.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I just don't understand Python, but

  • Does ROCm depend on a specific Python version? In my understanding, it is irrelevant.
  • Does the original grammar depend on a specific Python version (e.g., only Python2, or Python 3.8+, or Python <3.5).
assert (this_rs_uvm_weights := rs_uvm[2]) is not None

Copy link
Contributor Author

@liligwu liligwu May 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

":=" is a Python 3.8 feature, see https://stackoverflow.com/a/26000366
The PyTorch upstream CI jobs for ROCm are executed with python 3.7, so all our release dockers are on 3.7. However, ROCm does not dependent on a specific python version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes maybe we should change this if FBGEMM_GPU should work with Python < 3.8,, but I am not sure about the minimum Python version for FBGEMM_GPU, but @jianyuh do you know the version requirement?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! If the current change works for 3.7+, it looks good to me. Ideally we should use the simple syntax and avoid using the specific Python 3.8+ features.

per_sample_weights = torch.cat(
[this_rs_uvm_weights, this_rs_gpu_weights]
)
Expand Down Expand Up @@ -1634,8 +1636,10 @@ def nbit_uvm(
offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
per_sample_weights = None
if weighted:
assert (this_rs_uvm_weights := rs_uvm[2]) is not None
assert (this_rs_gpu_weights := rs_gpu[2]) is not None
this_rs_uvm_weights = rs_uvm[2]
assert this_rs_uvm_weights is not None
this_rs_gpu_weights = rs_gpu[2]
assert this_rs_gpu_weights is not None
per_sample_weights = torch.cat(
[this_rs_uvm_weights, this_rs_gpu_weights]
)
Expand Down
176 changes: 176 additions & 0 deletions fbgemm_gpu/cmake/Hip.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
set(FBGEMM_HAVE_HIP FALSE)

IF(NOT DEFINED ENV{ROCM_PATH})
SET(ROCM_PATH /opt/rocm)
ELSE()
SET(ROCM_PATH $ENV{ROCM_PATH})
ENDIF()

# HIP_PATH
IF(NOT DEFINED ENV{HIP_PATH})
SET(HIP_PATH ${ROCM_PATH}/hip)
ELSE()
SET(HIP_PATH $ENV{HIP_PATH})
ENDIF()

IF(NOT EXISTS ${HIP_PATH})
return()
ENDIF()

# HCC_PATH
IF(NOT DEFINED ENV{HCC_PATH})
SET(HCC_PATH ${ROCM_PATH}/hcc)
ELSE()
SET(HCC_PATH $ENV{HCC_PATH})
ENDIF()

# HSA_PATH
IF(NOT DEFINED ENV{HSA_PATH})
SET(HSA_PATH ${ROCM_PATH}/hsa)
ELSE()
SET(HSA_PATH $ENV{HSA_PATH})
ENDIF()

# ROCBLAS_PATH
IF(NOT DEFINED ENV{ROCBLAS_PATH})
SET(ROCBLAS_PATH ${ROCM_PATH}/rocblas)
ELSE()
SET(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
ENDIF()

# ROCSPARSE_PATH
IF(NOT DEFINED ENV{ROCSPARSE_PATH})
SET(ROCSPARSE_PATH ${ROCM_PATH}/rocsparse)
ELSE()
SET(ROCSPARSE_PATH $ENV{ROCSPARSE_PATH})
ENDIF()

# ROCFFT_PATH
IF(NOT DEFINED ENV{ROCFFT_PATH})
SET(ROCFFT_PATH ${ROCM_PATH}/rocfft)
ELSE()
SET(ROCFFT_PATH $ENV{ROCFFT_PATH})
ENDIF()

# HIPSPARSE_PATH
IF(NOT DEFINED ENV{HIPSPARSE_PATH})
SET(HIPSPARSE_PATH ${ROCM_PATH}/hipsparse)
ELSE()
SET(HIPSPARSE_PATH $ENV{HIPSPARSE_PATH})
ENDIF()

# THRUST_PATH
IF(NOT DEFINED ENV{THRUST_PATH})
SET(THRUST_PATH ${ROCM_PATH}/include)
ELSE()
SET(THRUST_PATH $ENV{THRUST_PATH})
ENDIF()
liligwu marked this conversation as resolved.
Show resolved Hide resolved

# HIPRAND_PATH
IF(NOT DEFINED ENV{HIPRAND_PATH})
SET(HIPRAND_PATH ${ROCM_PATH}/hiprand)
ELSE()
SET(HIPRAND_PATH $ENV{HIPRAND_PATH})
ENDIF()

# ROCRAND_PATH
IF(NOT DEFINED ENV{ROCRAND_PATH})
SET(ROCRAND_PATH ${ROCM_PATH}/rocrand)
ELSE()
SET(ROCRAND_PATH $ENV{ROCRAND_PATH})
ENDIF()

# MIOPEN_PATH
IF(NOT DEFINED ENV{MIOPEN_PATH})
SET(MIOPEN_PATH ${ROCM_PATH}/miopen)
ELSE()
SET(MIOPEN_PATH $ENV{MIOPEN_PATH})
ENDIF()

# Add HIP to the CMAKE Module Path
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})

# Disable Asserts In Code (Can't use asserts on HIP stack.)
ADD_DEFINITIONS(-DNDEBUG)
ADD_DEFINITIONS(-DUSE_ROCM)

IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH})
SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a)
ELSE()
SET(FBGEMM_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH})
ENDIF()

# Find the HIP Package
find_package(HIP)

IF(HIP_FOUND)
set(FBGEMM_HAVE_HIP TRUE)

if(HIP_COMPILER STREQUAL clang)
set(hip_library_name amdhip64)
else()
set(hip_library_name hip_hcc)
endif()
message("HIP library name: ${hip_library_name}")

find_package(hip REQUIRED)
find_package(rocBLAS REQUIRED)
find_package(hipFFT REQUIRED)
find_package(hipRAND REQUIRED)
find_package(rocRAND REQUIRED)
find_package(hipSPARSE REQUIRED)
find_package(OpenMP REQUIRED)
find_package(rocPRIM REQUIRED)

set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
FIND_LIBRARY(FBGEMM_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib)

list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
# list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
list(APPEND HIP_CXX_FLAGS -D__HIP_NO_BFLOAT16_CONVERSIONS__=1)
list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF2_OPERATORS__=1)
list(APPEND HIP_CXX_FLAGS -mavx2)
list(APPEND HIP_CXX_FLAGS -mf16c)
list(APPEND HIP_CXX_FLAGS -mfma)
list(APPEND HIP_CXX_FLAGS -std=c++17)

set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS})
# Ask hcc to generate device code during compilation so we can use
# host linker to link.
list(APPEND HIP_HCC_FLAGS -fno-gpu-rdc)
list(APPEND HIP_HCC_FLAGS -Wno-defaulted-function-deleted)
foreach(fbgemm_rocm_arch ${FBGEMM_ROCM_ARCH})
list(APPEND HIP_HCC_FLAGS --amdgpu-target=${fbgemm_rocm_arch})
endforeach()

set(hip_DIR ${HIP_PATH}/lib/cmake/hip)
set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs)
set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr)
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)
set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand)
set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas)
set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft)
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)
set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim)
set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub)
set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust)
set(ROCclr_DIR ${ROCM_PATH}/rocclr/lib/cmake/rocclr)
set(ROCRAND_INCLUDE ${ROCRAND_PATH}/include)
set(ROCM_SMI_INCLUDE ${ROCM_PATH}/rocm_smi/include)

set(FBGEMM_HIP_INCLUDE ${ROCM_PATH}/include ${FBGEMM_HIP_INCLUDE})
set(FBGEMM_HIP_INCLUDE ${hip_INCLUDE_DIRS} $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}> $<INSTALL_INTERFACE:include> ${FBGEMM_HIP_INCLUDE})

hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE})

list (APPEND CMAKE_PREFIX_PATH ${HIP_PATH} ${ROCM_PATH})
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})

ELSE()
message(FATAL_ERROR "Not able to find HIP installation.")
ENDIF()
5 changes: 5 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_
C10_CUDA_KERNEL_LAUNCH_CHECK();
int shared_kb = max_shared_bytes >> 10;
// V100: 64 KB; A100: 96 KB.
#ifndef __HIP_PLATFORM_HCC__
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
#else
// MI100 has independent shared mem and L1
int used_shared_kb = shared_kb;
#endif
int used_shared_bytes = used_shared_kb << 10;

Tensor linear_indices, linear_indices_sorted;
Expand Down
Loading