diff --git a/cuda/private/templates/BUILD.cufft b/cuda/private/templates/BUILD.cufft index 01a420b..856d77f 100644 --- a/cuda/private/templates/BUILD.cufft +++ b/cuda/private/templates/BUILD.cufft @@ -34,3 +34,50 @@ cc_library( ":cufftw_lib", ]), ) + +cc_import( + name = "cufftw_static_a", + static_library = "%{component_name}/%{libpath}/libcufftw_static.a", + target_compatible_with = ["@platforms//os:linux"], +) + +cc_import( + name = "cufft_static_a", + static_library = "%{component_name}/%{libpath}/libcufft_static.a", + target_compatible_with = ["@platforms//os:linux"], +) + +cc_import( + name = "cufft_static_nocallback_a", + static_library = "%{component_name}/%{libpath}/libcufft_static_nocallback.a", + target_compatible_with = ["@platforms//os:linux"], +) + +cc_library( + name = "cufftw_static", + deps = [ + ":%{component_name}_headers", + ] + if_linux([ + ":cufftw_static_a", + ]), +) + +cc_library( + name = "cufft_static", + deps = [ + ":%{component_name}_headers", + ] + if_linux([ + ":cufftw_static_a", + ":cufft_static_a", + ]), +) + +cc_library( + name = "cufft_static_nocallback", + deps = [ + ":%{component_name}_headers", + ] + if_linux([ + ":cufftw_static_a", + ":cufft_static_nocallback_a", + ]), +) diff --git a/cuda/private/templates/registry.bzl b/cuda/private/templates/registry.bzl index 978382b..9fc111e 100644 --- a/cuda/private/templates/registry.bzl +++ b/cuda/private/templates/registry.bzl @@ -4,7 +4,7 @@ REGISTRY = { "nvcc": ["compiler_deps", "nvptxcompiler"], "cccl": ["cub", "thrust"], "cublas": ["cublas"], - "cufft": ["cufft"], + "cufft": ["cufft", "cufft_static"], "cufile": [], "cupti": ["cupti", "nvperf_host", "nvperf_target"], "curand": ["curand"],