Add and use VersionedDeprecationWarning #944

Merged · 6 commits · Feb 6, 2024
2 changes: 1 addition & 1 deletion llmfoundry/__init__.py
@@ -95,4 +95,4 @@
'TiktokenTokenizerWrapper',
]

-__version__ = '0.4.0'
+__version__ = '0.5.0'
10 changes: 6 additions & 4 deletions llmfoundry/callbacks/generate_callback.py
@@ -11,6 +11,8 @@
from composer.callbacks import Generate as ComposerGenerate
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


@@ -20,10 +22,10 @@ def __init__(self, prompts: List[str], batch_log_interval: int,
**kwargs: Any):

warnings.warn(
-            ('Accessing llmfoundry.callbacks.generate_callback.Generate '
-             'is deprecated and will be removed in a future release. '
-             'Please use composer.callbacks.Generate instead.'),
-            DeprecationWarning,
+            VersionedDeprecationWarning(
+                'Accessing llmfoundry.callbacks.generate_callback.Generate '
+                'is deprecated. Please use composer.callbacks.Generate instead.',
+                remove_version='0.5.0',
+            )
)

interval = f'{batch_log_interval}ba'
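Since the llmfoundry wrapper only formats the interval string before delegating to Composer, callers can migrate by constructing the Composer callback directly. A minimal sketch; the prompt text and the max_new_tokens generation kwarg are illustrative assumptions, and the 'ba' interval format mirrors the wrapper above:

from composer.callbacks import Generate

# Equivalent of the deprecated wrapper called with batch_log_interval=100.
generate_callback = Generate(
    prompts=['The capital of France is'],  # illustrative prompts
    interval='100ba',  # replaces batch_log_interval=100
    max_new_tokens=32,  # assumed kwarg, forwarded to the model's generate()
)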
9 changes: 6 additions & 3 deletions llmfoundry/data/packing.py
@@ -11,6 +11,8 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
log = logging.getLogger(__name__)


@@ -433,9 +435,10 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
import warnings

warnings.warn(
-        DeprecationWarning(
-            'Please use scripts/misc/profile_packing.py to profile packing.' +
-            'This script will be removed in later releases.'))
+        VersionedDeprecationWarning(
+            'Please use scripts/misc/profile_packing.py to profile packing.',
+            remove_version='0.5.0',
+        ))

import os
from argparse import ArgumentParser, Namespace
31 changes: 19 additions & 12 deletions llmfoundry/models/layers/attention.py
@@ -16,6 +16,7 @@

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+from llmfoundry.utils.warnings import VersionedDeprecationWarning


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -104,14 +105,16 @@ def scaled_multihead_dot_product_attention(
torch.Tensor]]]:
if multiquery:
warnings.warn(
-            DeprecationWarning(
-                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
+                remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
-            DeprecationWarning(
-                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
+                remove_version='0.5.0',
))
kv_n_heads = n_heads

@@ -249,14 +252,16 @@ def flash_attn_fn(

if multiquery:
warnings.warn(
-            DeprecationWarning(
-                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
+                remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
-            DeprecationWarning(
-                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
+                remove_version='0.5.0',
))
kv_n_heads = n_heads

@@ -422,14 +427,16 @@ def triton_flash_attn_fn(

if multiquery:
warnings.warn(
-            DeprecationWarning(
-                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
+                remove_version='0.5.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
-            DeprecationWarning(
-                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
+            VersionedDeprecationWarning(
+                'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
+                remove_version='0.5.0',
))
kv_n_heads = n_heads

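All three attention implementations now emit the same pair of warnings, so the migration is uniform: pass kv_n_heads explicitly instead of the multiquery flag. A sketch, assuming the positional (query, key, value, n_heads) signature shown in the hunk headers and a batch × sequence × feature tensor layout; the shapes here are illustrative:

import torch

from llmfoundry.models.layers.attention import \
    scaled_multihead_dot_product_attention

n_heads, d_model, seq_len = 16, 128, 8
head_dim = d_model // n_heads
query = torch.randn(1, seq_len, d_model)
key = torch.randn(1, seq_len, head_dim)  # one shared key/value head
value = torch.randn(1, seq_len, head_dim)

# Deprecated spelling (now emits VersionedDeprecationWarning):
#   scaled_multihead_dot_product_attention(query, key, value, n_heads,
#                                          multiquery=True)
# Preferred: state the head count explicitly. kv_n_heads=1 is multiquery
# attention and kv_n_heads=n_heads is standard multi-head attention.
context, attn_weights, past_key_value = scaled_multihead_dot_product_attention(
    query, key, value, n_heads, kv_n_heads=1)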
12 changes: 8 additions & 4 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -21,6 +21,8 @@
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)

+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
}
@@ -159,8 +161,9 @@ def __init__(
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
if verbose is not None:
warnings.warn(
-            DeprecationWarning(
-                'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
+            VersionedDeprecationWarning(
+                'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.',
+                remove_version='0.5.0',
))

if 'name' in kwargs:
@@ -226,8 +229,9 @@ def _validate_config(self) -> None:

if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
warnings.warn(
-            DeprecationWarning(
-                'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.'
+            VersionedDeprecationWarning(
+                'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.',
+                remove_version='0.6.0',
))

if self.attn_config[
27 changes: 27 additions & 0 deletions llmfoundry/utils/warnings.py
@@ -0,0 +1,27 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0


class VersionedDeprecationWarning(DeprecationWarning):
"""A custom deprecation warning class that includes version information.

Attributes:
message (str): The deprecation message describing why the feature is deprecated.
remove_version (str): The version in which the feature will be removed.

Example:
>>> def deprecated_function():
... warnings.warn(
... VersionedDeprecationWarning(
... "Function XYZ is deprecated.",
    ...             remove_version="2.0.0"
... )
... )
...
>>> deprecated_function()
DeprecationWarning: Function XYZ is deprecated. It will be removed in version 2.0.0.
"""

def __init__(self, message: str, remove_version: str) -> None:
super().__init__(message +
f' It will be removed in version {remove_version}.')
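Because the class subclasses DeprecationWarning, existing warning filters keep working and the removal version is appended to the message automatically. A minimal usage sketch (deprecated_function is a hypothetical caller, not part of the PR):

import warnings

from llmfoundry.utils.warnings import VersionedDeprecationWarning


def deprecated_function() -> None:
    # The message the user sees ends with '... removed in version 0.6.0.'
    warnings.warn(
        VersionedDeprecationWarning('Function XYZ is deprecated.',
                                    remove_version='0.6.0'))


# DeprecationWarning is shown by default only in __main__; opt in explicitly.
warnings.filterwarnings('always', category=VersionedDeprecationWarning)
deprecated_function()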
8 changes: 6 additions & 2 deletions scripts/train/train.py
@@ -22,6 +22,8 @@
from omegaconf import OmegaConf as om
from rich.traceback import install

+from llmfoundry.utils.warnings import VersionedDeprecationWarning
+
install()

from transformers import PreTrainedTokenizerBase
@@ -219,8 +221,10 @@ def main(cfg: DictConfig) -> Trainer:
default_value=None)
if eval_gauntlet_config is not None:
warnings.warn(
-            'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
-            DeprecationWarning)
+            VersionedDeprecationWarning(
+                'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`.',
+                remove_version='0.5.0',
+            ))
icl_subset_num_batches: Optional[int] = pop_config(cfg,
'icl_subset_num_batches',
must_exist=False,
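For the train.py change, only the top-level config key moves; the gauntlet body is unchanged. A sketch of the rename against an OmegaConf training config, with an illustrative gauntlet payload:

from omegaconf import OmegaConf

# Deprecated spelling: train.py still reads it, but emits the
# VersionedDeprecationWarning shown above.
old_cfg = OmegaConf.create({'model_gauntlet': {'weighting': 'EQUAL'}})

# Preferred spelling: identical payload under the new key.
new_cfg = OmegaConf.create({'eval_gauntlet': {'weighting': 'EQUAL'}})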