Skip to content

Commit

Permalink
Patch xla to find conda cuda libraries and add dependency on required…
Browse files Browse the repository at this point in the history
… cuda packages
  • Loading branch information
traversaro committed Dec 8, 2024
1 parent 266fbcd commit 8c12eaf
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
7 changes: 7 additions & 0 deletions recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ requirements:
- ml_dtypes >=0.2.0
- __cuda # [cuda_compiler_version != "None"]
- cuda-nvcc-tools # [(cuda_compiler_version or "").startswith("12")]
# Workaround for https://github.com/conda-forge/jaxlib-feedstock/pull/288#issuecomment-2511925904
- libcublas-dev # [(cuda_compiler_version or "").startswith("12")]
- libcusolver-dev # [(cuda_compiler_version or "").startswith("12")]
- libcurand-dev # [(cuda_compiler_version or "").startswith("12")]
- cuda-cupti-dev # [(cuda_compiler_version or "").startswith("12")]
- libcufft-dev # [(cuda_compiler_version or "").startswith("12")]
- libcusparse-dev # [(cuda_compiler_version or "").startswith("12")]
run_constrained:
- jax >={{ version }}

Expand Down
31 changes: 30 additions & 1 deletion recipe/patches/0002-Consolidated-build-fixes-for-XLA.patch
Original file line number Diff line number Diff line change
Expand Up @@ -494,11 +494,39 @@ index 0000000..a7f7f0b
++}
+
+ } // namespace xla
diff --git a/third_party/xla/0006-Add-conda-cuda-path.patch b/third_party/xla/0006-Add-conda-cuda-path.patch
new file mode 100644
index 0000000..a7f7f0b
--- /dev/null
+++ b/third_party/xla/0006-Add-conda-cuda-path.patch
@@ -0,0 +1,22 @@
+diff --git a/third_party/tsl/tsl/platform/default/cuda_root_path.cc b/third_party/tsl/tsl/platform/default/cuda_root_path.cc
+index ca6da0e553..1d8a9450c0 100644
+--- a/third_party/tsl/tsl/platform/default/cuda_root_path.cc
++++ b/third_party/tsl/tsl/platform/default/cuda_root_path.cc
+@@ -75,6 +75,17 @@ std::vector<std::string> CandidateCudaRoots() {
+ // Also add the path to the copy of libdevice.10.bc that we include within
+ // the Python wheel.
+ roots.emplace_back(io::JoinPath(dir, "cuda"));
++
++ // In case cuda was installed with nvidia's official conda packages, we also
++ // include the root prefix of the environment in the candidate roots dir,
++ // we assume that the lib binaries are either in the python package's root dir
++ // or in a 'python' subdirectory, as done by the previous for.
++ // python packages on non-Windows platforms are installed in
++ // $CONDA_PREFIX/lib/python3.12/site-packages/pkg_name, so if we want
++ // to add $CONDA_PREFIX to the candidate roots dirs we need to add
++ // ../../../..
++ for (auto path : {"../../../..", "../../../../.."})
++ roots.emplace_back(io::JoinPath(dir, path));
+ }
+ #endif // defined(PLATFORM_POSIX) && !defined(__APPLE__)
+
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 37861d9..45fbcde 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -30,6 +30,13 @@ def repo():
@@ -30,6 +30,14 @@ def repo():
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
Expand All @@ -508,6 +536,7 @@ index 37861d9..45fbcde 100644
+ "//third_party/xla:0003-Omit-usage-of-StrFormat.patch",
+ "//third_party/xla:0004-Add-missing-bits-absl-systemlib.patch",
+ "//third_party/xla:0005-Check-whether-absl-log-is-already-initialized.patch",
+ "//third_party/xla:0006-Add-conda-cuda-path.patch",
+ ],
)

Expand Down

0 comments on commit 8c12eaf

Please sign in to comment.