Skip to content

Commit

Permalink
Improve FSDP + QLora (#25) [release]
Browse files Browse the repository at this point in the history
* Fix simple issues

* Default eval_accumulation_steps

* Filter model warnings

* Support FSDP activation checkpointing

* Refactor training code

* Reduce CPU RAM usage by compute_metrics

* Fix tests

* Fix tests

* Update RAM

* Add test_StrWithSeed
  • Loading branch information
AjayP13 authored Apr 30, 2024
1 parent 7c4ab24 commit 1312423
Show file tree
Hide file tree
Showing 25 changed files with 1,235 additions and 1,059 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "DataDreamer"
version = "0.31.0"
version = "0.32.0"
description = "Prompt. Generate Synthetic Data. Train & Align Models."
license = "MIT"
authors= [
Expand Down
2 changes: 1 addition & 1 deletion scripts/.cluster/slurm/_sbatch_config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#SBATCH --output=.cluster/slurm/.last_job/submission.out
#SBATCH --ntasks 1
#SBATCH --cpus-per-task 16
#SBATCH --mem=10G
#SBATCH --mem=30G
#SBATCH --gpus=2

# Source the user's bashrc
Expand Down
10 changes: 9 additions & 1 deletion src/_cachable/_cachable.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _notify_adaptive_batch_sizing(model_logger: Logger, progress_state: dict[str
class _StrWithSeed(str):
seed: Any

def __new__(cls, value: str, seed: "Any | _StrWithSeed"):
def __new__(cls, value: str, seed: "Any | _StrWithSeed" = None):
obj = str.__new__(cls, value)
obj.seed = seed.seed if isinstance(seed, _StrWithSeed) else seed
return obj
Expand All @@ -75,6 +75,14 @@ def __eq__(self, __value: object) -> bool:
def __hash__(self):
return hash((self.seed, str(self)))

def __getstate__(self):
state = {"str": str(self), "seed": self.seed}

return state

def __setstate__(self, state):
self.seed = state["seed"]

@staticmethod
def total_per_input_seeds(inputs: list["str | _StrWithSeed"]) -> int:
return sum(
Expand Down
4 changes: 4 additions & 0 deletions src/embedders/sentence_transformers_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from ..utils.hf_model_utils import (
convert_dtype,
filter_model_warnings,
get_model_max_context_length,
get_tokenizer,
)
Expand Down Expand Up @@ -122,6 +123,9 @@ def model(self) -> SentenceTransformer:
# torch._dynamo.config.suppress_errors = True
# model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down
13 changes: 4 additions & 9 deletions src/llms/hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
HF_TRANSFORMERS_CITATION,
PEFT_CITATION,
convert_dtype,
filter_model_warnings,
get_attn_implementation,
get_config,
get_model_max_context_length,
Expand Down Expand Up @@ -273,6 +274,9 @@ def model(self) -> PreTrainedModel:
torch._dynamo.config.suppress_errors = True
model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down Expand Up @@ -323,15 +327,6 @@ def count_tokens(self, value: str) -> int:
Returns:
The number of tokens in the string.
"""
pass
"""_summary_
Args:
value (_type_): _description_
Returns:
_type_: _description_
"""
return len(self.tokenizer.encode(value))

@torch.no_grad()
Expand Down
4 changes: 4 additions & 0 deletions src/llms/petals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..utils.arg_utils import AUTO, Default
from ..utils.background_utils import RunIfTimeout
from ..utils.fs_utils import safe_fn
from ..utils.hf_model_utils import filter_model_warnings
from ..utils.import_utils import ignore_hivemind_warnings, ignore_transformers_warnings
from .hf_transformers import HFTransformers

Expand Down Expand Up @@ -161,6 +162,9 @@ def model(self) -> PreTrainedModel:
# torch._dynamo.config.suppress_errors = True
# model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down
4 changes: 4 additions & 0 deletions src/task_models/hf_classification_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HF_TRANSFORMERS_CITATION,
PEFT_CITATION,
convert_dtype,
filter_model_warnings,
get_config,
get_model_max_context_length,
get_tokenizer,
Expand Down Expand Up @@ -152,6 +153,9 @@ def model(self) -> PreTrainedModel:
# torch._dynamo.config.suppress_errors = True
# model = torch.compile(model)

# Filter any warnings from the model
filter_model_warnings()

# Finished loading
log_if_timeout.stop(
partial(
Expand Down
22 changes: 21 additions & 1 deletion src/tests/llms/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
from time import sleep
from types import GeneratorType

import dill
import psutil
import pytest
import torch
from flaky import flaky
from sortedcontainers import SortedDict

from ... import DataDreamer
from ..._cachable._cachable import _is_primitive
from ..._cachable._cachable import _is_primitive, _StrWithSeed
from ...llms import (
AI21,
VLLM,
Expand Down Expand Up @@ -338,6 +339,25 @@ def test_is_primitive(self):
assert _is_primitive({"foo": 5})
assert not _is_primitive({"foo": object()})

def test_StrWithSeed(self):
seed_a = _StrWithSeed("hello", seed=1)
seed_b = _StrWithSeed("hello", seed=2)
seed_c = _StrWithSeed("hello", seed=1)
assert (
isinstance(seed_a, str)
and isinstance(seed_b, str)
and isinstance(seed_c, str)
)
assert seed_a.seed == 1
assert seed_b.seed == 2
assert seed_c.seed == 1
assert str(seed_a) == "hello"
assert str(seed_b) == "hello"
assert str(seed_c) == "hello"
assert hash(seed_a) != hash(seed_b)
assert hash(seed_a) == hash(seed_c)
assert hash(seed_a) == hash(dill.loads(dill.dumps(seed_c)))

def test_check_temperature_and_top_p(self):
assert _check_temperature_and_top_p(
temperature=0.3,
Expand Down
5 changes: 3 additions & 2 deletions src/tests/trainers/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
TrainHFPPO,
TrainSentenceTransformer,
)
from ...trainers._train_hf_base import CustomDataCollatorWithPadding
from ...utils.arg_utils import AUTO
from ...utils.hf_model_utils import get_orig_model, is_bnb_quantized
from ...utils.hf_training_utils import CustomDataCollatorWithPadding
from ...utils.import_utils import ignore_transformers_warnings

with ignore_transformers_warnings():
Expand Down Expand Up @@ -422,12 +422,13 @@ def test_fsdp_peft(self, qlora, create_datadreamer, mocker):
validation_output=val_dataset.output["outputs"],
epochs=1,
batch_size=8,
gradient_checkpointing=qlora,
)
assert data_collator_spy.call_count == 0
trainer_path = cast(str, trainer._output_folder_path)
with open(os.path.join(trainer_path, "fingerprint.json"), "r") as f:
assert (
json.load(f) == "ce4179deefbddefd" if qlora else "6b385aca0ce684b3"
json.load(f) == "42a7bd193f804a4a" if qlora else "6b385aca0ce684b3"
)
assert train_result is trainer
assert (
Expand Down
2 changes: 1 addition & 1 deletion src/tests/trainers/test_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
TrainSentenceTransformer,
TrainSetFitClassifier,
)
from ...trainers._train_hf_base import CustomDataCollatorWithPadding
from ...utils.fs_utils import clear_dir
from ...utils.hf_model_utils import get_orig_model, validate_peft_config
from ...utils.hf_training_utils import CustomDataCollatorWithPadding
from ...utils.import_utils import ignore_transformers_warnings

with ignore_transformers_warnings():
Expand Down
5 changes: 4 additions & 1 deletion src/tests/utils/test_device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,7 @@ def test_get_device_env_variables(self):
get_device_env_variables([0, 2, 999999, 0, 1, -1, -1])
with pytest.raises(AssertionError):
get_device_env_variables([0, 2, 0, 1])
assert get_device_env_variables([0, 2, 1]) == {"CUDA_VISIBLE_DEVICES": "6,3,4"}
assert get_device_env_variables([0, 2, 1]) == {
"CUDA_VISIBLE_DEVICES": "6,3,4",
"NCCL_P2P_DISABLE": "1",
}
Loading

0 comments on commit 1312423

Please sign in to comment.