Skip to content

Commit

Permalink
feat: handle torch lower than 2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 authored and winglian committed Dec 1, 2024
1 parent 817b0fe commit 614c589
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
5 changes: 4 additions & 1 deletion scripts/cutcrossentropy_install.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import sys

try:
import torch
Expand All @@ -8,8 +9,10 @@

v = V(torch.__version__)

# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
raise RuntimeError(f"Torch = {v} too old!")
print("")
sys.exit(0)

print(
'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
Expand Down
19 changes: 15 additions & 4 deletions tests/e2e/integrations/test_cut_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

Expand Down Expand Up @@ -60,8 +61,13 @@ def test_llama_w_cce(self, min_cfg, temp_dir):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

@pytest.mark.parametrize(
"attention_type",
Expand All @@ -79,5 +85,10 @@ def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

0 comments on commit 614c589

Please sign in to comment.