Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace setup.py with pyproject.toml #8744

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions fix_pyproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os
from collections.abc import MutableMapping

import tomlkit
from tomlkit.items import Array

# Directory (relative to the repo root) that contains this package's
# version.py; also used as the distribution name throughout this script.
project_name = 'torch_xla'

# Torch version this build currently targets.
_current_torch_version = '2.6.0'
# The following should be updated after each new torch release.
_latest_torch_version_on_pypi = '2.6.0'

# Pinned libtpu version used by the [tpu] optional-dependency extra.
_libtpu_version = '0.0.8'

def load_version_module(pkg_path):
  """Import ``<pkg_path>/version.py`` directly from disk and return it.

  The file is loaded under the module name ``'version'`` via importlib's
  file-location machinery, bypassing the normal package import path.

  Args:
    pkg_path: directory containing a ``version.py`` file.

  Returns:
    The executed module object.
  """
  version_file = os.path.join(pkg_path, 'version.py')
  mod_spec = importlib.util.spec_from_file_location('version', version_file)
  assert mod_spec is not None
  loaded = importlib.util.module_from_spec(mod_spec)
  assert mod_spec.loader is not None
  mod_spec.loader.exec_module(loaded)
  return loaded

# Load torch_xla/version.py and derive all build metadata from it.
_version_module = load_version_module(project_name)
# Full version string for this build (may carry a .dev/nightly suffix).
__version__ = _version_module._get_version_for_build()
# NOTE(review): the name and comment say "JAX" but this script targets
# torch_xla — presumably inherited from JAX's build tooling; confirm and
# consider renaming to something like _release_version.
_jax_version = _version_module._version  # JAX version, with no .dev suffix.
# NOTE(review): _cmdclass is never used in this file — confirm whether a
# later setup step consumes it before removing.
_cmdclass = _version_module._get_cmdclass(project_name)
# Oldest torch release this package claims compatibility with.
_minimum_torch_version = _version_module._minimum_torch_version

# Rewrite the dependency pins in pyproject.toml in place, substituting the
# torch/libtpu versions computed above. tomlkit is used (rather than tomllib)
# so that comments and formatting in the file are preserved.
with open('pyproject.toml', 'r') as f:
  data = tomlkit.load(f)

project = data['project']
assert isinstance(project, MutableMapping)
dependencies = project['dependencies']
assert isinstance(dependencies, Array)
od = project['optional-dependencies']
assert isinstance(od, MutableMapping)

# The first entry of [project].dependencies is expected to be the torch pin;
# fail loudly if the file was reordered.
assert isinstance(dependencies[0], str)
assert dependencies[0].startswith('torch')
dependencies[0] = f'torch >={_minimum_torch_version}, <={_jax_version}'

od['minimum-torch'] = [f'torch=={_minimum_torch_version}']
od['ci'] = [f'torch=={_latest_torch_version_on_pypi}']
od['tpu'] = [
    f'torch>={_current_torch_version},<={_jax_version}',
    f'libtpu=={_libtpu_version}',
    'requests',  # necessary for jax.distributed.initialize
]


def _cuda12_pip_requirements():
  """Return a fresh requirement list for the pip-installed CUDA 12 extras.

  A new list is built on every call so each extra gets its own tomlkit
  array, matching the original behavior of writing three separate lists.
  """
  return [
      f'torch=={_current_torch_version}',
      f'jax-cuda12-plugin[with_cuda]>={_current_torch_version},<={_jax_version}',
  ]


# 'cuda', 'cuda12', and the deprecated alias 'cuda12_pip' are intentionally
# identical; defining the list once avoids the three copies drifting apart.
od['cuda'] = _cuda12_pip_requirements()
od['cuda12'] = _cuda12_pip_requirements()
od['cuda12_pip'] = _cuda12_pip_requirements()

# Variant that does not pull the CUDA pip wheels, for preinstalled CUDA.
od['cuda12_local'] = [
    f'torch=={_current_torch_version}',
    f'jax-cuda12-plugin=={_current_torch_version}',
]

# ROCm support for ROCm 6.0 and above.
od['rocm'] = [
    f'torch=={_current_torch_version}',
    f'jax-rocm60-plugin>={_current_torch_version},<={_jax_version}',
]

with open('pyproject.toml', 'w') as f:
  tomlkit.dump(data, f)
86 changes: 86 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2018 The Google Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "torch_xla"
dynamic = ["version"]
# NOTE(review): this description is JAX's tagline, not torch_xla's —
# confirm it should not describe torch_xla instead.
description = "Differentiate, compile, and transform Numpy code."
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.10"
authors = [{ name = "torch_xla team", email = "[email protected]" }]
classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
# NOTE(review): the version bounds and extras below are placeholders that
# fix_pyproject.py rewrites at build time; the pinned values here (e.g.
# torch==0.5.0 in [ci]) need not match a real release.
dependencies = [
"torch >=2.6.0, <=2.6.1",
"ml_dtypes>=0.4.0",
"numpy>=1.25",
"numpy>=1.26.0; python_version>='3.12'",
"opt_einsum",
"scipy>=1.11.1",
]

[project.optional-dependencies]
# Used only for CI builds that install torch_xla from github HEAD.
ci = ["torch==0.5.0"]
# A CPU-only torch_xla doesn't require any extras, but we keep this extra around for compatibility.
cpu = []
# NOTE(review): fix_pyproject.py writes these extras using the package name
# "jax-cuda12-plugin", not "torch_xla-cuda12-plugin" — confirm which name
# is intended and make the two files agree.
cuda = ["torch==2.6.0", "torch_xla-cuda12-plugin[with_cuda]>=0.5.0,<=0.5.1"]
cuda12 = ["torch==2.6.0", "torch_xla-cuda12-plugin[with_cuda]>=0.5.0,<=0.5.1"]
# Target that does not depend on the CUDA pip wheels, for those who want to use a preinstalled CUDA.
cuda12_local = ["torch==2.6.0", "torch_xla-cuda12-plugin==0.5.0"]
# Deprecated alias for cuda12, kept to avoid breaking users who wrote cuda12_pip in their CI.
cuda12_pip = ["torch==2.6.0", "torch_xla-cuda12-plugin[with_cuda]>=0.5.0,<=0.5.1"]
# Developer tooling; not needed at runtime.
dev = [
"mypy>=1.15",
"pre-commit>=4.1",
"pytest>=8.3",
"ruff>=0.9.7",
"setuptools>=75.8",
"tomlkit>=0.13.2",
]
# For automatic bootstrapping distributed jobs in Kubernetes
k8s = ["kubernetes"]
# Minimum torch version; used in testing.
minimum-torch = ["torch==2.6.0"]
# ROCm support for ROCm 6.0 and above.
rocm = ["torch==2.6.0", "torch_xla-rocm60-plugin>=0.5.0,<=0.5.1"]
# Cloud TPU VM torch can be installed via:
# $ pip install "torch_xla[tpu]" -f https://storage.googleapis.com/torch_xla-releases/libtpu_releases.html
tpu = ["torch>=2.6.0,<=2.6.1", "libtpu==0.0.8", "requests"]

[project.urls]
homepage = "https://github.com/pytorch/xla"
repository = "https://github.com/pytorch/xla"

# Version is read from torch_xla/version.py by setuptools at build time.
[tool.setuptools.dynamic]
version = {attr = "torch_xla.version.__version__"}

[tool.setuptools.packages.find]
exclude = ["examples", "torch_xla/src/internal_test_util"]

# Ship the py.typed marker and type stubs with the wheel.
[tool.setuptools.package-data]
torch_xla = ['py.typed', "*.pyi", "**/*.pyi"]

[tool.mypy]
show_error_codes = true
disable_error_code = "attr-defined, name-defined, annotation-unchecked"
Loading
Loading