Skip to content

Commit

Permalink
Fix zero seed (#766)
Browse files Browse the repository at this point in the history
* Change the behavior of seed 0 from randomized to fixed

* Add a warning message to notify the behavior change of the seed zero

* Update a comment
  • Loading branch information
tanemaki authored Dec 8, 2022
1 parent 764f97d commit 2cc43a7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
9 changes: 9 additions & 0 deletions anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ def get_configurable_parameters(
# keep track of the original config file because it will be modified
config_original: DictConfig = config.copy()

# if the seed value is 0, notify a user that the behavior of the seed value zero has been changed.
if config.project.get("seed") == 0:
warn(
"The seed value is now fixed to 0. "
"Up to v0.3.7, the seed was not fixed when the seed value was set to 0. "
"If you want to use the random seed, please select `None` for the seed value "
"(`null` in the YAML file) or remove the `seed` key from the YAML file."
)

# Dataset Configs
if "format" not in config.dataset.keys():
config.dataset.format = "mvtec"
Expand Down
2 changes: 1 addition & 1 deletion tools/hpo/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_args():
model_config = get_configurable_parameters(model_name=args.model, config_path=args.model_config)
hpo_config = OmegaConf.load(args.sweep_config)

if model_config.project.seed != 0:
if model_config.project.get("seed") is not None:
seed_everything(model_config.project.seed)

# check hpo config structure to see whether it adheres to comet or wandb format
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def train():
warnings.filterwarnings("ignore")

config = get_configurable_parameters(model_name=args.model, config_path=args.config)
if config.project.seed:
if config.project.get("seed") is not None:
seed_everything(config.project.seed)

datamodule = get_datamodule(config)
Expand Down

0 comments on commit 2cc43a7

Please sign in to comment.