diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
index bb99278880..004df55186 100644
--- a/.github/workflows/tests-nightly.yml
+++ b/.github/workflows/tests-nightly.yml
@@ -55,6 +55,7 @@ jobs:
           pip3 install --upgrade pip
           pip3 install --upgrade packaging
           pip3 install -U -e .
+          python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
       - name: Run tests
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 249dfd9e4c..dd4c95bbe4 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -72,6 +72,7 @@ jobs:
           pip3 show torch
           pip3 install -U -e .
           python scripts/unsloth_install.py | sh
+          python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
 
       - name: Run tests
diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 65553d60b5..da28b391fc 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -38,6 +38,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     fi
 
 RUN python scripts/unsloth_install.py | sh
+RUN python scripts/cutcrossentropy_install.py | sh
 
 # So we can test the Docker image
 RUN pip install -r requirements-dev.txt -r requirements-tests.txt
diff --git a/cicd/tests.py b/cicd/tests.py
index 812ef7b426..f3dbaef105 100644
--- a/cicd/tests.py
+++ b/cicd/tests.py
@@ -40,6 +40,7 @@ cicd_image = (
     Image.from_dockerfile(
         pathlib.Path(temp_dir) / "Dockerfile",
+        context_mount=None,
         force_build=True,
         gpu="A10G",
     )
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 173a508792..88b871ea94 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -27,6 +27,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     fi
 
 RUN python scripts/unsloth_install.py | sh
+RUN python scripts/cutcrossentropy_install.py | sh
 
 # So we can test the Docker image
 RUN pip install pytest
diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
new file mode 100644
index 0000000000..3816e58143
--- /dev/null
+++ b/scripts/cutcrossentropy_install.py
@@ -0,0 +1,28 @@
+"""Script to output the correct installation command for cut-cross-entropy."""
+import importlib.util
+import sys
+
+try:
+    import torch
+except ImportError as exc:
+    raise ImportError("Install torch via `pip install torch`") from exc
+from packaging.version import Version as V
+
+v = V(torch.__version__)
+
+# no cut-cross-entropy support for torch < 2.4.0
+if v < V("2.4.0"):
+    print("")
+    sys.exit(0)
+
+cce_spec = importlib.util.find_spec("cut_cross_entropy")
+cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
+
+UNINSTALL_PREFIX = ""
+if cce_spec and not cce_spec_transformers:
+    UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
+
+print(
+    UNINSTALL_PREFIX
+    + 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
+)
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 2ef78e07d8..4572cfa5c8 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -27,7 +27,6 @@
 from transformers.utils.import_utils import _is_package_available
 
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils.chat_templates import (
@@ -38,6 +37,7 @@
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
+    prepare_plugins,
     validate_config,
 )
 from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
@@ -426,11 +426,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 
     cfg.axolotl_config_path = config
 
-    if cfg.get("plugins"):
-        plugin_manager = PluginManager.get_instance()
-        for plugin_name in cfg["plugins"]:
-            plugin_manager.register(plugin_name)
-
     try:
         device_props = torch.cuda.get_device_properties("cuda")
         gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -449,6 +444,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
             },
         )
 
+    prepare_plugins(cfg)
+
     prepare_optim_env(cfg)
 
     prepare_opinionated_env(cfg)
diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py
index 24c0b04123..57f014bd6a 100644
--- a/src/axolotl/core/trainers/trl.py
+++ b/src/axolotl/core/trainers/trl.py
@@ -40,7 +40,7 @@ def train(
                 query_tensors,
                 return_prompt=False,
                 generate_ref_response=True,
-                **generation_kwargs
+                **generation_kwargs,
             )
             batch["response"] = self.tokenizer.batch_decode(response_tensors)
             batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
diff --git a/src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md b/src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md
new file mode 100644
index 0000000000..03d1cbfb00
--- /dev/null
+++ b/src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md
@@ -0,0 +1,325 @@
+Acknowledgements
+
+Portions of this Cut Cross Entropy Software may utilize the following copyrighted
+material, the use of which is hereby acknowledged.
+
+
+------
+
+
+PyTorch
+
+ From PyTorch:
+
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ From Caffe2:
+
+ Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+ All contributions by Facebook:
+ Copyright (c) 2016 Facebook Inc.
+
+ All contributions by Google:
+ Copyright (c) 2015 Google Inc.
+ All rights reserved.
+
+ All contributions by Yangqing Jia:
+ Copyright (c) 2015 Yangqing Jia
+ All rights reserved.
+
+ All contributions by Kakao Brain:
+ Copyright 2019-2020 Kakao Brain
+
+ All contributions by Cruise LLC:
+ Copyright (c) 2022 Cruise LLC.
+ All rights reserved.
+
+ All contributions by Arm:
+ Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
+
+ All contributions from Caffe:
+ Copyright(c) 2013, 2014, 2015, the respective contributors
+ All rights reserved.
+
+ All other contributions:
+ Copyright(c) 2015, 2016 the respective contributors
+ All rights reserved.
+
+ Caffe2 uses a copyright model similar to Caffe: each contributor holds
+ copyright over their contributions to Caffe2. The project versioning records
+ all such contribution and copyright details. If a contributor wants to further
+ mark their specific copyright on a particular contribution, they should
+ indicate their copyright solely in the commit message of the change when it is
+ committed.
+
+ All rights reserved.
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + +Triton + + /* + * Copyright 2018-2020 Philippe Tillet + * Copyright 2020-2022 OpenAI + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + + +Transformers + + Copyright 2018- The Hugging Face team. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/src/axolotl/integrations/cut_cross_entropy/LICENSE b/src/axolotl/integrations/cut_cross_entropy/LICENSE new file mode 100644 index 0000000000..7ab90b7d9c --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/LICENSE @@ -0,0 +1,47 @@ +Copyright (C) 2024 Apple Inc. All Rights Reserved. + +IMPORTANT: This Apple software is supplied to you by Apple +Inc. ("Apple") in consideration of your agreement to the following +terms, and your use, installation, modification or redistribution of +this Apple software constitutes acceptance of these terms. If you do +not agree with these terms, please do not use, install, modify or +redistribute this Apple software. + +In consideration of your agreement to abide by the following terms, and +subject to these terms, Apple grants you a personal, non-exclusive +license, under Apple's copyrights in this original Apple software (the +"Apple Software"), to use, reproduce, modify and redistribute the Apple +Software, with or without modifications, in source and/or binary forms; +provided that if you redistribute the Apple Software in its entirety and +without modifications, you must retain this notice and the following +text and disclaimers in all such redistributions of the Apple Software. +Neither the name, trademarks, service marks or logos of Apple Inc. may +be used to endorse or promote products derived from the Apple Software +without specific prior written permission from Apple. Except as +expressly stated in this notice, no other rights or licenses, express or +implied, are granted by Apple herein, including but not limited to any +patent rights that may be infringed by your derivative works or by other +works in which the Apple Software may be incorporated. + +The Apple Software is provided by Apple on an "AS IS" basis. APPLE +MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION +THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS +FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND +OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. + +IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, +MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED +AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), +STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + + +------------------------------------------------------------------------------- +SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY: + +The Cut Cross Entropy software includes a number of subcomponents with separate +copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md. 
+-------------------------------------------------------------------------------
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
new file mode 100644
index 0000000000..c67d7440b9
--- /dev/null
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -0,0 +1,10 @@
+# Cut Cross Entropy
+
+### Usage
+
+```yaml
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+cut_cross_entropy: true
+```
diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py
new file mode 100644
index 0000000000..97517bccdb
--- /dev/null
+++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py
@@ -0,0 +1,83 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for the Plugin for Cut Cross Entropy integration with Axolotl.
+
+Cut Cross Entropy is an optimized implementation of cross entropy loss
+from Apple's ML team.
+"""
+import importlib
+import logging
+
+import torch
+
+from axolotl.integrations.base import BasePlugin
+from axolotl.utils import get_pytorch_version
+
+from ...utils.distributed import zero_only
+from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401
+
+LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
+
+_CCE_INSTALL_MESSAGE = (
+    "Please install cut_cross_entropy with transformers support using "
+    '`pip install "cut-cross-entropy[transformers]==24.11.4"`'
+)
+
+
+class CutCrossEntropyPlugin(BasePlugin):
+    """
+    Plugin for Cut Cross Entropy integration with Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"
+
+    def _check_requirements(self):
+        """Check if all requirements are met."""
+        # Check PyTorch version
+
+        major, minor, _ = get_pytorch_version()
+        if (major, minor) < (2, 4):
+            raise ImportError(
+                "Cut Cross Entropy requires PyTorch >= 2.4.0. "
+                f"Current version: {torch.__version__}"
+            )
+
+        # Check if cut_cross_entropy is installed
+        cce_spec = importlib.util.find_spec("cut_cross_entropy")
+        if cce_spec is None:
+            raise ImportError(_CCE_INSTALL_MESSAGE)
+
+        cce_spec_transformers = importlib.util.find_spec(
+            "cut_cross_entropy.transformers"
+        )
+        if cce_spec_transformers is None:
+            raise ImportError(_CCE_INSTALL_MESSAGE)
+
+    def pre_model_load(self, cfg):
+        """Apply cut cross entropy before model loading if enabled."""
+        if cfg.cut_cross_entropy:
+            self._check_requirements()
+
+            from cut_cross_entropy.transformers import cce_patch
+
+            with zero_only():
+                LOG.info(
+                    f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
+                )
+
+            # The patch checks model_type internally
+            cce_patch(cfg.model_config_type)
diff --git a/src/axolotl/integrations/cut_cross_entropy/args.py b/src/axolotl/integrations/cut_cross_entropy/args.py
new file mode 100644
index 0000000000..9a364e2d3e
--- /dev/null
+++ b/src/axolotl/integrations/cut_cross_entropy/args.py
@@ -0,0 +1,42 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Module for handling Cut Cross Entropy input arguments.
+"""
+import logging
+from typing import Optional
+
+from pydantic import BaseModel, model_validator
+
+LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
+
+
+class CutCrossEntropyArgs(BaseModel):
+    """
+    Input args for Cut Cross Entropy.
+    """
+
+    cut_cross_entropy: Optional[bool] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_dtype_is_half(cls, data):
+        if not (data.get("bf16") or data.get("fp16")):
+            raise ValueError(
+                "Cut Cross Entropy requires fp16/bf16 training for backward pass. "
+                "Please set `bf16` or `fp16` to `True`."
+            )
+
+        return data
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index 91545009ad..4602054471 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -1,7 +1,11 @@
 """
 Basic utils for Axolotl
 """
+
 import importlib.util
+import re
+
+import torch
 
 
 def is_mlflow_available():
@@ -10,3 +14,23 @@ def is_mlflow_available():
 
 def is_comet_available():
     return importlib.util.find_spec("comet_ml") is not None
+
+
+# pylint: disable=duplicate-code
+def get_pytorch_version() -> tuple[int, int, int]:
+    """
+    Get Pytorch version as a tuple of (major, minor, patch).
+    """
+    torch_version = torch.__version__
+    version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
+
+    if not version_match:
+        raise ValueError("Invalid version format")
+
+    major, minor, patch = version_match.groups()
+    major, minor = int(major), int(minor)
+    patch = int(patch) if patch is not None else 0  # Default patch to 0 if not present
+    return major, minor, patch
+
+
+# pylint: enable=duplicate-code
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 422ed78efb..468bd6e7f8 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -7,6 +7,7 @@
 from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.import_utils import is_torch_npu_available
 
+from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config.models.input.v0_4_1 import (
@@ -264,3 +265,14 @@ def validate_config(
     return DictDefault(
         dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
+
+
+def prepare_plugins(cfg):
+    """
+    Prepare the plugins for the configuration
+    """
+
+    if cfg.get("plugins"):
+        plugin_manager = PluginManager.get_instance()
+        for plugin_name in cfg["plugins"]:
+            plugin_manager.register(plugin_name)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index dc366f7870..3da5cc0ddf 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1086,14 +1086,17 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
 
         self.prepare_model(qlora_fsdp)
 
-        # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
-            LOG.info(
-                "converting modules to %s for flash attention", self.cfg.torch_dtype
-            )
+        should_convert = (
+            # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+            # convert them back to fp16/bf16 for flash-attn compatibility.
+            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
+        )
+
+        if should_convert:
+            LOG.info("Converting modules to %s", self.cfg.torch_dtype)
             self.convert_embedding_modules_dtype(
-                embedding_modules,
+                embedding_modules=embedding_modules,
                 dist_dtype=self.cfg.torch_dtype,
                 before_kbit_train_or_finetune=False,
             )
diff --git a/tests/conftest.py b/tests/conftest.py
index 4479e676f4..2fc985d3ad 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -51,6 +51,22 @@ def download_mlabonne_finetome_100k_dataset():
     snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
 
 
+@pytest.fixture
+def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
+    # download the dataset
+    snapshot_download(
+        "argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
+    )
+
+
+@pytest.fixture
+def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
+    # download the dataset
+    snapshot_download(
+        "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
+    )
+
+
 @pytest.fixture
 def temp_dir():
     # Create a temporary directory
diff --git a/tests/e2e/integrations/liger.py b/tests/e2e/integrations/liger.py
index bb4574dff3..455c3d2818 100644
--- a/tests/e2e/integrations/liger.py
+++ b/tests/e2e/integrations/liger.py
@@ -7,7 +7,7 @@
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
-from axolotl.utils.config import normalize_config
+from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault
 
 from ..utils import with_temp_dir
@@ -54,8 +54,10 @@ def test_llama_wo_flce(self, temp_dir):
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
+                "max_steps": 10,
             }
         )
+        prepare_plugins(cfg)
        normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -99,8 +101,10 @@ def test_llama_w_flce(self, temp_dir):
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
+                "max_steps": 10,
             }
         )
+        prepare_plugins(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
new file mode 100644
index 0000000000..82801eedce
--- /dev/null
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -0,0 +1,94 @@
+"""
+Simple end-to-end test for Cut Cross Entropy integration
+"""
+
+from pathlib import Path
+
+import pytest
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils import get_pytorch_version
+from axolotl.utils.config import normalize_config, prepare_plugins
+from axolotl.utils.dict import DictDefault
+
+# pylint: disable=duplicate-code
+
+
+@pytest.fixture()
+def min_cfg(temp_dir):
+    return {
+        "base_model": "HuggingFaceTB/SmolLM2-135M",
+        "plugins": [
+            "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
+        ],
+        "cut_cross_entropy": True,
+        "sequence_len": 1024,
+        "val_set_size": 0.1,
+        "special_tokens": {
+            "pad_token": "<|endoftext|>",
+        },
+        "datasets": [
+            {
+                "path": "mhenrichsen/alpaca_2k_test",
+                "type": "alpaca",
+            },
+        ],
+        "num_epochs": 1,
+        "micro_batch_size": 8,
+        "gradient_accumulation_steps": 1,
+        "learning_rate": 0.00001,
+        "optimizer": "adamw_torch",
+        "output_dir": temp_dir,
+        "lr_scheduler": "cosine",
+        "save_safetensors": True,
+        "max_steps": 10,
+        "bf16": "auto",
+    }
+
+
+class TestCutCrossEntropyIntegration:
+    """
+    e2e tests for cut_cross_entropy integration with Axolotl
+    """
+
+    # pylint: disable=redefined-outer-name
+    def test_llama_w_cce(self, min_cfg, temp_dir):
+        cfg = DictDefault(min_cfg)
+        prepare_plugins(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        major, minor, _ = get_pytorch_version()
+        if (major, minor) < (2, 4):
+            with pytest.raises(ImportError):
+                train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        else:
+            train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+            assert (Path(temp_dir) / "model.safetensors").exists()
+
+    @pytest.mark.parametrize(
+        "attention_type",
+        ["flash_attention", "sdp_attention", "xformers_attention"],
+    )
+    def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
+        cfg = DictDefault(
+            min_cfg
+            | {
+                attention_type: True,
+            }
+        )
+        prepare_plugins(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        major, minor, _ = get_pytorch_version()
+        if (major, minor) < (2, 4):
+            with pytest.raises(ImportError):
+                train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        else:
+            train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+            assert (Path(temp_dir) / "model.safetensors").exists()