-
-
Notifications
You must be signed in to change notification settings - Fork 891
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add cut_cross_entropy * fix: add to input * fix: remove from setup.py * feat: refactor into an integration * chore: ignore lint * feat: add test for cce * fix: set max_steps for liger test * chore: Update base model following suggestion Co-authored-by: Wing Lian <[email protected]> * chore: update special_tokens following suggestion Co-authored-by: Wing Lian <[email protected]> * chore: remove with_temp_dir following comments * fix: plugins aren't loaded * chore: update quotes in error message * chore: lint * chore: lint * feat: enable FA on test * chore: refactor get_pytorch_version * fix: lock cce commit version * fix: remove subclassing UT * fix: downcast even if not using FA and config check * feat: add test to check different attentions * feat: add install to CI * chore: refactor to use parametrize for attention * fix: pytest not detecting test * feat: handle torch lower than 2.4 * fix args/kwargs to match docs * use release version cut-cross-entropy==24.11.4 * fix quotes * fix: use named params for clarity for modal builder * fix: handle install from pip * fix: test check only top level module install * fix: re-add import check * uninstall existing version if no transformers submodule in cce * more dataset fixtures into the cache --------- Co-authored-by: Wing Lian <[email protected]> Co-authored-by: Wing Lian <[email protected]>
- Loading branch information
1 parent
f073af6
commit 4078f37
Showing
19 changed files
with
705 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Script to output the correct installation command for cut-cross-entropy.""" | ||
import importlib.util | ||
import sys | ||
|
||
try: | ||
import torch | ||
except ImportError as exc: | ||
raise ImportError("Install torch via `pip install torch`") from exc | ||
from packaging.version import Version as V | ||
|
||
v = V(torch.__version__) | ||
|
||
# no cut-cross-entropy support for torch < 2.4.0 | ||
if v < V("2.4.0"): | ||
print("") | ||
sys.exit(0) | ||
|
||
cce_spec = importlib.util.find_spec("cut_cross_entropy") | ||
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers") | ||
|
||
UNINSTALL_PREFIX = "" | ||
if cce_spec and not cce_spec_transformers: | ||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && " | ||
|
||
print( | ||
UNINSTALL_PREFIX | ||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"' | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.