Fixes to lion8b test for torch 2.1 (#649)
dakinggg authored Oct 7, 2023
1 parent d3c3305 commit 7fb084a
Showing 3 changed files with 300 additions and 240 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-gpu.yaml
@@ -24,7 +24,7 @@ jobs:
         markers: 'gpu'
         pytest_command: 'coverage run -m pytest'
       - name: 'gpu-2.0.1'
-        container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
+        container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
         markers: 'gpu'
         pytest_command: 'coverage run -m pytest'
       - name: 'gpu-2.1.0'
9 changes: 8 additions & 1 deletion llmfoundry/optim/lion8b.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 
 import torch
+from packaging import version
 
 
 class DecoupledLionW_8bit(torch.optim.Optimizer):
@@ -53,7 +54,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
             by retaining information across optimizer steps.
 
     Raises:
-        NotImplemenetedError - If any of `quantize`, `compress_state_dict`,
+        NotImplementedError - If any of `quantize`, `compress_state_dict`,
            or `error_correction` are `True` and either a) there is no CUDA
            device, or b) step() is executed on a non-CUDA parameter.
     """
@@ -67,6 +68,12 @@ def __init__(self,
                  compress_state_dict: bool = False,
                  error_correction: bool = False,
                  _fused: bool = True): # XXX this flag is mostly for testing...
+        if version.parse(torch.__version__) >= version.parse(
+                '2.1.0') and error_correction:
+            raise RuntimeError(
+                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
+            )
+
         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
         if not 0.0 <= betas[0] <= 1.0:
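
For readers of this diff, here is a minimal, hypothetical sketch (not part of the commit) of how the new version guard evaluates; the helper name _error_correction_supported is invented for illustration, and only torch and packaging are assumed to be installed.

import torch
from packaging import version

def _error_correction_supported(torch_version: str) -> bool:
    # packaging understands local build tags, so '2.1.0+cu121' still
    # compares as >= '2.1.0'; naive string comparison would not.
    return version.parse(torch_version) < version.parse('2.1.0')

print(_error_correction_supported('2.0.1+cu118'))      # True: guard does not fire
print(_error_correction_supported('2.1.0+cu121'))      # False: __init__ raises if error_correction=True
print(_error_correction_supported(torch.__version__))  # depends on the installed torch

Parsing the version rather than comparing raw strings also keeps releases like '2.10.0' vs '2.9.0' ordered correctly.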