Skip to content

Commit

Permalink
[Bug Fix] Include python training apis when enable_training is enabled (
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored and rui-ren committed Feb 3, 2023
1 parent 6aef31a commit 40786b1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
69 changes: 38 additions & 31 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,41 +520,48 @@ def finalize_options(self):
if not enable_training:
classifiers.extend(["Operating System :: Microsoft :: Windows", "Operating System :: MacOS"])

if enable_training:
if enable_training or enable_training_apis:
packages.append("onnxruntime.training")
if enable_training:
packages.extend(
[
"onnxruntime.training.amp",
"onnxruntime.training.experimental",
"onnxruntime.training.experimental.gradient_graph",
"onnxruntime.training.optim",
"onnxruntime.training.torchdynamo",
"onnxruntime.training.ortmodule",
"onnxruntime.training.ortmodule.experimental",
"onnxruntime.training.ortmodule.experimental.json_config",
"onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
"onnxruntime.training.ortmodule.torch_cpp_extensions",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.utils.data",
]
)

package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
"*.cpp",
"*.cu",
"*.cuh",
"*.h",
]

packages.extend(
[
"onnxruntime.training",
"onnxruntime.training.amp",
"onnxruntime.training.experimental",
"onnxruntime.training.experimental.gradient_graph",
"onnxruntime.training.optim",
"onnxruntime.training.torchdynamo",
"onnxruntime.training.ortmodule",
"onnxruntime.training.ortmodule.experimental",
"onnxruntime.training.ortmodule.experimental.json_config",
"onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule",
"onnxruntime.training.ortmodule.torch_cpp_extensions",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.utils.data",
"onnxruntime.training.api",
"onnxruntime.training.onnxblock",
"onnxruntime.training.onnxblock.loss",
"onnxruntime.training.onnxblock.optim",
]
)
if enable_training_apis:
packages.append("onnxruntime.training.api")
packages.append("onnxruntime.training.onnxblock")
packages.append("onnxruntime.training.onnxblock.loss")
packages.append("onnxruntime.training.onnxblock.optim")
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"]
package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [
"*.cpp",
"*.cu",
"*.cuh",
"*.h",
]

requirements_file = "requirements-training.txt"
# with training, we want to follow this naming convention:
# stable:
Expand Down
5 changes: 5 additions & 0 deletions tools/ci_build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,11 @@ def main():
if args.use_gdk:
args.test = False

# enable_training is a higher level flag that enables all training functionality.
if args.enable_training:
args.enable_training_apis = True
args.enable_training_ops = True

configs = set(args.config)

# setup paths and directories
Expand Down

0 comments on commit 40786b1

Please sign in to comment.