remove dummy_load, move gpu_ranks warning out of TrainingConfig
francoishernandez committed Dec 20, 2024
1 parent faa8917 commit 888b9df
Showing 3 changed files with 7 additions and 19 deletions.
10 changes: 7 additions & 3 deletions eole/config/run.py
@@ -15,7 +15,9 @@
 )
 from eole.transforms import get_transforms_cls
 from eole.constants import TransformType
+from eole.utils.logging import logger
 from pydantic import Field, field_validator, model_validator
+import torch
 
 
 # Models below somewhat replicate prior opt behavior to facilitate transition
@@ -82,6 +84,8 @@ def _validate_train_config(self):
                 self.save_data
             ), "-save_data should be set if \
                 want save samples."
+        if torch.cuda.is_available() and not self.training.gpu_ranks:
+            logger.warn("You have a CUDA device, should run with -gpu_ranks")
         return self
 
@@ -129,6 +133,8 @@ def _validate_predict_config(self):
         # TODO: do we really need this _all_transform?
         if self._all_transform is None:
             self._all_transform = self.transforms
+        if torch.cuda.is_available() and not self.gpu_ranks:
+            logger.warn("You have a CUDA device, should run with -gpu_ranks")
         return self
 
     def _update_with_model_config(self):
@@ -151,9 +157,7 @@ def _update_with_model_config(self):
         if os.path.exists(config_path):
             # logic from models.BaseModel.inference_logic
             model_config = build_model_config(config_dict.get("model", {}))
-            training_config = TrainingConfig(
-                **config_dict.get("training", {}), dummy_load=True
-            )
+            training_config = TrainingConfig(**config_dict.get("training", {}))
             training_config.world_size = self.world_size
             training_config.gpu_ranks = self.gpu_ranks
             # retrieve share_vocab from checkpoint config
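For context, a minimal runnable sketch of the pattern these run.py additions follow; TrainCfg and TrainingSection are hypothetical stand-ins for eole's actual entry-point configs, which the diff only shows in part. The point of the move is that the warning now fires in the top-level train/predict validators rather than every time a training sub-config is constructed.

# Hedged sketch, not eole's real classes: the CUDA warning lives in the
# entry-point config's model_validator, so merely instantiating the
# training sub-config (e.g. while loading a checkpoint) stays silent.
import logging

import torch
from pydantic import BaseModel, model_validator

logger = logging.getLogger(__name__)


class TrainingSection(BaseModel):
    gpu_ranks: list[int] = []


class TrainCfg(BaseModel):
    training: TrainingSection = TrainingSection()

    @model_validator(mode="after")
    def _validate_train_config(self):
        # Same check as in the diff, relocated: warn only when validating
        # the full run config, not on every sub-config construction.
        if torch.cuda.is_available() and not self.training.gpu_ranks:
            logger.warning("You have a CUDA device, should run with -gpu_ranks")
        return self


TrainingSection()  # no warning: the sub-config alone is silent
TrainCfg()         # warns, but only on a machine with CUDA available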
12 changes: 0 additions & 12 deletions eole/config/training.py
@@ -273,11 +273,6 @@ class TrainingConfig(
     score_threshold: float = Field(
         default=0.68, description="Threshold to filterout data"
     )
-    dummy_load: bool | None = Field(
-        default=False,
-        description="Ignore some warnings if we are only loading the configuration "
-        "prior to other operations, e.g. in `train_from` context.",
-    )
 
     @computed_field
     @cached_property
@@ -323,13 +318,6 @@ def get_model_path(self):
     def _validate_running_config(self):
         super()._validate_running_config()
         # self._validate_language_model_compatibilities_opts()
-        if (
-            torch.cuda.is_available()
-            and not self.gpu_ranks
-            and self.model_fields_set != set()
-            and not self.dummy_load
-        ):
-            logger.warn("You have a CUDA device, should run with -gpu_ranks")
         if self.world_size < len(self.gpu_ranks):
             raise AssertionError(
                 "parameter counts of -gpu_ranks must be less or equal "
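To see why dummy_load becomes unnecessary, a hedged sketch with simplified names (not eole's real API): once the soft CUDA warning is out of TrainingConfig, building the config from a checkpoint no longer triggers a spurious warning, so no escape-hatch flag is needed; only hard validation errors remain in the class.

# Hedged sketch, names and error message simplified: TrainingConfig keeps
# only hard checks, so it can be instantiated anywhere (checkpoint load,
# train_from) without a dummy_load flag to suppress warnings.
from pydantic import BaseModel, model_validator


class TrainingConfig(BaseModel):
    world_size: int = 1
    gpu_ranks: list[int] = []

    @model_validator(mode="after")
    def _validate_running_config(self):
        # The hard error stays: gpu_ranks cannot exceed world_size.
        if self.world_size < len(self.gpu_ranks):
            raise AssertionError("gpu_ranks must fit within world_size")
        return self


# Rebuilding a config from a saved checkpoint is now warning-free:
TrainingConfig(**{"world_size": 2, "gpu_ranks": [0, 1]})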
4 changes: 0 additions & 4 deletions eole/models/model_saver.py
@@ -52,10 +52,6 @@ def load_checkpoint(model_path):
         config_dict = json.loads(os.path.expandvars(f.read()))
         # drop data to prevent validation issues
         config_dict["data"] = {}
-        if "training" in config_dict.keys():
-            config_dict["training"]["dummy_load"] = True
-        else:
-            config_dict["training"] = {"dummy_load": True}
         _config = TrainConfig(**config_dict)
         checkpoint["config"] = _config
     else:
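Finally, a short sketch of the now-simpler checkpoint path; load_checkpoint_config and the config.json layout are assumptions for illustration, since this hunk shows only a fragment of load_checkpoint. The saved JSON can be fed straight to TrainConfig with no dummy_load key injected.

# Hedged sketch of the simplified loading step (helper name and file
# layout assumed): read the saved config, drop "data" to avoid validation
# issues, and validate directly; no dummy_load key is injected anymore.
import json
import os


def load_checkpoint_config(config_path: str) -> dict:
    with open(config_path) as f:
        config_dict = json.loads(os.path.expandvars(f.read()))
    # drop data to prevent validation issues (same step as the diff above)
    config_dict["data"] = {}
    return config_dict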
