Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add and use VersionedDeprecationWarning #944

Merged
merged 6 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions llmfoundry/callbacks/generate_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from composer.callbacks import Generate as ComposerGenerate
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from llmfoundry.utils.warnings import VersionedDeprecationWarning

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


Expand All @@ -20,10 +22,10 @@ def __init__(self, prompts: List[str], batch_log_interval: int,
**kwargs: Any):

warnings.warn(
('Accessing llmfoundry.callbacks.generate_callback.Generate '
'is deprecated and will be removed in a future release. '
'Please use composer.callbacks.Generate instead.'),
DeprecationWarning,
VersionedDeprecationWarning('Accessing llmfoundry.callbacks.generate_callback.Generate ' + \
'is deprecated and will be removed in a future release. Please use composer.callbacks.Generate instead.',
irenedea marked this conversation as resolved.
Show resolved Hide resolved
after_version='0.3.0',
)
)

interval = f'{batch_log_interval}ba'
Expand Down
8 changes: 6 additions & 2 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.warnings import VersionedDeprecationWarning

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -433,9 +435,11 @@ def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
import warnings

warnings.warn(
DeprecationWarning(
VersionedDeprecationWarning(
'Please use scripts/misc/profile_packing.py to profile packing.' +
'This script will be removed in later releases.'))
'This script will be removed in later releases.',
irenedea marked this conversation as resolved.
Show resolved Hide resolved
after_version='0.3.0',
))

import os
from argparse import ArgumentParser, Namespace
Expand Down
31 changes: 19 additions & 12 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.utils.warnings import VersionedDeprecationWarning


def is_flash_v2_installed(v2_version: str = '2.0.0'):
Expand Down Expand Up @@ -104,14 +105,16 @@ def scaled_multihead_dot_product_attention(
torch.Tensor]]]:
if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
after_version='0.2.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
after_version='0.2.0',
))
kv_n_heads = n_heads

Expand Down Expand Up @@ -249,14 +252,16 @@ def flash_attn_fn(

if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
after_version='0.2.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
after_version='0.2.0',
))
kv_n_heads = n_heads

Expand Down Expand Up @@ -422,14 +427,16 @@ def triton_flash_attn_fn(

if multiquery:
warnings.warn(
DeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.'
VersionedDeprecationWarning(
'The direct use of the multiquery arg is deprecated. Setting kv_n_heads=1 automatically. Please set kv_n_heads=1 explicitly to remove this warning.',
after_version='0.2.0',
))
kv_n_heads = 1
elif kv_n_heads is None:
warnings.warn(
DeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.'
VersionedDeprecationWarning(
'Not specifying a value for the kv_n_heads arg is deprecated. Setting kv_n_heads=n_heads automatically. Please set kv_n_heads=n_heads explicitly to remove this warning.',
after_version='0.2.0',
))
kv_n_heads = n_heads

Expand Down
14 changes: 9 additions & 5 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
# isort: off
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY

from llmfoundry.utils.warnings import VersionedDeprecationWarning # type: ignore (see note)

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
Expand Down Expand Up @@ -159,8 +161,9 @@ def __init__(
self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
if verbose is not None:
warnings.warn(
DeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
VersionedDeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.',
after_version='0.2.0',
))

if 'name' in kwargs:
Expand Down Expand Up @@ -226,8 +229,9 @@ def _validate_config(self) -> None:

if self.attn_config['attn_impl'] == 'flash' and is_flash_v1_installed():
warnings.warn(
DeprecationWarning(
'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.'
VersionedDeprecationWarning(
'Support for Flash Attention v1 is deprecated. Please upgrade to Flash Attention v2.4.2. To install Flash Attention v2.4.2, please run `pip install -e ".[gpu-flash2]"` from the root directory of the llm-foundry repository.',
after_version='0.4.0',
))

if self.attn_config[
Expand Down
27 changes: 27 additions & 0 deletions llmfoundry/utils/warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0


class VersionedDeprecationWarning(DeprecationWarning):
irenedea marked this conversation as resolved.
Show resolved Hide resolved
"""A custom deprecation warning class that includes version information.

Attributes:
message (str): The deprecation message describing why the feature is deprecated.
after_version (str): The version after which the feature will be deprecated.
irenedea marked this conversation as resolved.
Show resolved Hide resolved
It will be removed after two releases.

Example:
>>> def deprecated_function():
... warnings.warn(
... VersionedDeprecationWarning(
... "Function XYZ is deprecated.",
... after_version="2.0.0"
... )
... )
...
>>> deprecated_function()
DeprecationWarning: After version 2.0.0: Function XYZ is deprecated.
"""

def __init__(self, message: str, after_version: str) -> None:
super().__init__(f'After version {after_version}:' + message)
8 changes: 6 additions & 2 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from omegaconf import OmegaConf as om
from rich.traceback import install

from llmfoundry.utils.warnings import VersionedDeprecationWarning

install()

from transformers import PreTrainedTokenizerBase
Expand Down Expand Up @@ -219,8 +221,10 @@ def main(cfg: DictConfig) -> Trainer:
default_value=None)
if eval_gauntlet_config is not None:
warnings.warn(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
DeprecationWarning)
VersionedDeprecationWarning(
'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
irenedea marked this conversation as resolved.
Show resolved Hide resolved
after_version='0.2.0',
))
icl_subset_num_batches: Optional[int] = pop_config(cfg,
'icl_subset_num_batches',
must_exist=False,
Expand Down
Loading