From c29d13315c606689ed335ff97ffa330042358092 Mon Sep 17 00:00:00 2001
From: Silvio Traversaro <silvio@traversaro.it>
Date: Sun, 19 Jan 2025 23:23:56 -0800
Subject: [PATCH] PR #20288: cuda_root_path: Find cuda libraries when installed
 with conda packages

Imported from GitHub PR https://github.com/openxla/xla/pull/20288

This fix emerged when looking in solving https://github.com/jax-ml/jax/issues/24604 . In a nutshell, the official cuda package for conda (both in the `conda-forge` and `nvidia` conda channels) install the CUDA libraries in a different location with respect to PyPI packages, so the logic to find them needs to be augmented to be able to find the CUDA libraries when installed from conda packages.

I did not tested this with a tensorflow build, but probably this will also help in solving https://github.com/tensorflow/tensorflow/issues/56927 .

xref: https://github.com/conda-forge/tensorflow-feedstock/pull/408
xref: https://github.com/conda-forge/jaxlib-feedstock/pull/288
Copybara import of the project:

--
a2ce85cf9df1ede3f3c1843ede55d4c76673910e by Silvio Traversaro <silvio@traversaro.it>:

cuda_root_path: Find cuda libraries when installed with conda packages

Merging this change closes #20288

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/20288 from traversaro:fixloadcudaconda a2ce85cf9df1ede3f3c1843ede55d4c76673910e
PiperOrigin-RevId: 717411600
---
 xla/tsl/platform/default/cuda_root_path.cc | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/xla/tsl/platform/default/cuda_root_path.cc b/xla/tsl/platform/default/cuda_root_path.cc
index 578d8b05c70e68..60c7dabf3ea5b1 100644
--- a/xla/tsl/platform/default/cuda_root_path.cc
+++ b/xla/tsl/platform/default/cuda_root_path.cc
@@ -76,6 +76,17 @@ std::vector<std::string> CandidateCudaRoots() {
     // Also add the path to the copy of libdevice.10.bc that we include within
     // the Python wheel.
     roots.emplace_back(io::JoinPath(dir, "cuda"));
+
+    // In case cuda was installed with nvidia's official conda packages, we also
+    // include the root prefix of the environment in the candidate roots dir,
+    // we assume that the lib binaries are either in the python package's root
+    // dir or in a 'python' subdirectory, as done by the previous for. python
+    // packages on non-Windows platforms are installed in
+    // $CONDA_PREFIX/lib/python3.12/site-packages/pkg_name, so if we want
+    // to add $CONDA_PREFIX to the candidate roots dirs we need to add
+    // ../../../..
+    for (auto path : {"../../../..", "../../../../.."})
+      roots.emplace_back(io::JoinPath(dir, path));
   }
 #endif  // defined(PLATFORM_POSIX) && !defined(__APPLE__)