diff --git a/CMakeLists.txt b/CMakeLists.txt index ba09e84..bca6a3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,25 +1,32 @@ -cmake_minimum_required(VERSION 3.12...3.18) -project(kepler_jax LANGUAGES CXX) +cmake_minimum_required(VERSION 3.15...3.26) +project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX) +message(STATUS "Using CMake version: " ${CMAKE_VERSION}) -message(STATUS "Using CMake version " ${CMAKE_VERSION}) - -find_package(Python COMPONENTS Interpreter Development REQUIRED) +# Find pybind11 +set(PYBIND11_NEWPYTHON ON) find_package(pybind11 CONFIG REQUIRED) +# find_package(Python COMPONENTS Interpreter Development REQUIRED) +# find_package(pybind11 CONFIG REQUIRED) + include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) # CPU op library pybind11_add_module(cpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/cpu_ops.cc) -install(TARGETS cpu_ops DESTINATION kepler_jax) +install(TARGETS cpu_ops LIBRARY DESTINATION .) + +# Include the CUDA extensions if possible +include(CheckLanguage) +check_language(CUDA) -if (KEPLER_JAX_CUDA) +if(CMAKE_CUDA_COMPILER) enable_language(CUDA) include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) pybind11_add_module( gpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu ${CMAKE_CURRENT_LIST_DIR}/lib/gpu_ops.cc) - install(TARGETS gpu_ops DESTINATION kepler_jax) + install(TARGETS gpu_ops LIBRARY DESTINATION .) else() message(STATUS "Building without CUDA") endif() diff --git a/pyproject.toml b/pyproject.toml index a82478d..bc15e72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,27 @@ +[project] +name = "kepler_jax" +description = "A simple demonstration of how you can extend JAX with custom C++ and CUDA ops" +readme = "README.md" +authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] +requires-python = ">=3.9" +license = { file = "LICENSE" } +urls = { Homepage = "https://github.com/dfm/extending-jax" } +dependencies = ["jax>=0.4.16"] +dynamic = ["version"] + +[project.optional-dependencies] +test = ["pytest"] + [build-system] -requires = ["setuptools>=42", "wheel", "setuptools_scm[toml]>=3.4", "pybind11>=2.6", "cmake"] -build-backend = "setuptools.build_meta" +requires = ["pybind11>=2.6", "scikit-build-core>=0.5"] +build-backend = "scikit_build_core.build" + +[tool.scikit-build] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["src/kepler_jax/kepler_jax_version.py"] +wheel.install-dir = "kepler_jax" +minimum-version = "0.5" +build-dir = "build/{wheel_tag}" [tool.setuptools_scm] write_to = "src/kepler_jax/kepler_jax_version.py" diff --git a/setup.py b/setup.py deleted file mode 100644 index 7e9a341..0000000 --- a/setup.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python - -import codecs -import os -import subprocess - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - -HERE = os.path.dirname(os.path.realpath(__file__)) - - -def read(*parts): - with codecs.open(os.path.join(HERE, *parts), "rb", "utf-8") as f: - return f.read() - - -# This custom class for building the extensions uses CMake to compile. You -# don't have to use CMake for this task, but I found it to be the easiest when -# compiling ops with GPU support since setuptools doesn't have great CUDA -# support. -class CMakeBuildExt(build_ext): - def build_extensions(self): - # First: configure CMake build - import platform - import sys - import distutils.sysconfig - - import pybind11 - - # Work out the relevant Python paths to pass to CMake, adapted from the - # PyTorch build system - if platform.system() == "Windows": - cmake_python_library = "{}/libs/python{}.lib".format( - distutils.sysconfig.get_config_var("prefix"), - distutils.sysconfig.get_config_var("VERSION"), - ) - if not os.path.exists(cmake_python_library): - cmake_python_library = "{}/libs/python{}.lib".format( - sys.base_prefix, - distutils.sysconfig.get_config_var("VERSION"), - ) - else: - cmake_python_library = "{}/{}".format( - distutils.sysconfig.get_config_var("LIBDIR"), - distutils.sysconfig.get_config_var("INSTSONAME"), - ) - cmake_python_include_dir = distutils.sysconfig.get_python_inc() - - install_dir = os.path.abspath( - os.path.dirname(self.get_ext_fullpath("dummy")) - ) - os.makedirs(install_dir, exist_ok=True) - cmake_args = [ - "-DCMAKE_INSTALL_PREFIX={}".format(install_dir), - "-DPython_EXECUTABLE={}".format(sys.executable), - "-DPython_LIBRARIES={}".format(cmake_python_library), - "-DPython_INCLUDE_DIRS={}".format(cmake_python_include_dir), - "-DCMAKE_BUILD_TYPE={}".format( - "Debug" if self.debug else "Release" - ), - "-DCMAKE_PREFIX_PATH={}".format(pybind11.get_cmake_dir()), - ] - if os.environ.get("KEPLER_JAX_CUDA", "no").lower() == "yes": - cmake_args.append("-DKEPLER_JAX_CUDA=yes") - - os.makedirs(self.build_temp, exist_ok=True) - subprocess.check_call( - ["cmake", HERE] + cmake_args, cwd=self.build_temp - ) - - # Build all the extensions - super().build_extensions() - - # Finally run install - subprocess.check_call( - ["cmake", "--build", ".", "--target", "install"], - cwd=self.build_temp, - ) - - def build_extension(self, ext): - target_name = ext.name.split(".")[-1] - subprocess.check_call( - ["cmake", "--build", ".", "--target", target_name], - cwd=self.build_temp, - ) - - -extensions = [ - Extension( - "kepler_jax.cpu_ops", - ["src/kepler_jax/src/cpu_ops.cc"], - ), -] - -if os.environ.get("KEPLER_JAX_CUDA", "no").lower() == "yes": - extensions.append( - Extension( - "kepler_jax.gpu_ops", - [ - "src/kepler_jax/src/gpu_ops.cc", - "src/kepler_jax/src/cuda_kernels.cc.cu", - ], - ) - ) - - -setup( - name="kepler_jax", - author="Dan Foreman-Mackey", - author_email="foreman.mackey@gmail.com", - url="https://github.com/dfm/extending-jax", - license="MIT", - description=( - "A simple demonstration of how you can extend JAX with custom C++ and " - "CUDA ops" - ), - long_description=read("README.md"), - long_description_content_type="text/markdown", - packages=find_packages("src"), - package_dir={"": "src"}, - include_package_data=True, - install_requires=[ - "jax>=0.4.16", - "jaxlib>=0.4.16" - ], - extras_require={"test": "pytest"}, - ext_modules=extensions, - cmdclass={"build_ext": CMakeBuildExt}, -)