Skip to content

Commit

Permalink
feat: add cut_cross_entropy (#2091)
Browse files Browse the repository at this point in the history
* feat: add cut_cross_entropy

* fix: add to input

* fix: remove from setup.py

* feat: refactor into an integration

* chore: ignore lint

* feat: add test for cce

* fix: set max_steps for liger test

* chore: Update base model following suggestion

Co-authored-by: Wing Lian <[email protected]>

* chore: update special_tokens following suggestion

Co-authored-by: Wing Lian <[email protected]>

* chore: remove with_temp_dir following comments

* fix: plugins aren't loaded

* chore: update quotes in error message

* chore: lint

* chore: lint

* feat: enable FA on test

* chore: refactor get_pytorch_version

* fix: lock cce commit version

* fix: remove subclassing UT

* fix: downcast even if not using FA and config check

* feat: add test to check different attentions

* feat: add install to CI

* chore: refactor to use parametrize for attention

* fix: pytest not detecting test

* feat: handle torch lower than 2.4

* fix args/kwargs to match docs

* use release version cut-cross-entropy==24.11.4

* fix quotes

* fix: use named params for clarity for modal builder

* fix: handle install from pip

* fix: test check only top level module install

* fix: re-add import check

* uninstall existing version if no transformers submodule in cce

* more dataset fixtures into the cache

---------

Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
3 people authored and bursteratom committed Dec 4, 2024
1 parent f073af6 commit 4078f37
Show file tree
Hide file tree
Showing 19 changed files with 705 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ jobs:
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ jobs:
pip3 show torch
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
Expand Down
1 change: 1 addition & 0 deletions cicd/Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi

RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh

# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
Expand Down
1 change: 1 addition & 0 deletions cicd/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)
Expand Down
1 change: 1 addition & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi

RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh

# So we can test the Docker image
RUN pip install pytest
Expand Down
28 changes: 28 additions & 0 deletions scripts/cutcrossentropy_install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys

try:
import torch
except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V

v = V(torch.__version__)

# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
print("")
sys.exit(0)

cce_spec = importlib.util.find_spec("cut_cross_entropy")
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")

UNINSTALL_PREFIX = ""
if cce_spec and not cce_spec_transformers:
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "

print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)
9 changes: 3 additions & 6 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from transformers.utils.import_utils import _is_package_available

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
Expand All @@ -38,6 +37,7 @@
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
Expand Down Expand Up @@ -426,11 +426,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

cfg.axolotl_config_path = config

if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)

try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
Expand All @@ -449,6 +444,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
},
)

prepare_plugins(cfg)

prepare_optim_env(cfg)

prepare_opinionated_env(cfg)
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/core/trainers/trl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def train(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs
**generation_kwargs,
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
Expand Down
Loading

0 comments on commit 4078f37

Please sign in to comment.