fix seed #2744

Open · wants to merge 4 commits into base: develop
86 changes: 43 additions & 43 deletions ppcls/engine/engine.py
file mode changed: 100755 → 100644
Contributor:

"The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"

I don't quite understand this message. Isn't the distributed sampler the only place in a distributed setup that actually requires a random seed?
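
For context, a minimal NumPy sketch (an assumed illustration, not the PaddleClas sampler) of the one hard requirement distributed sampling does impose: every rank must shuffle with the same seed, otherwise the per-rank shards no longer partition the dataset.

import numpy as np

def shard_indices(num_samples, num_ranks, rank, seed):
    rng = np.random.RandomState(seed)      # identical seed on every rank
    order = rng.permutation(num_samples)   # hence identical permutation
    return order[rank::num_ranks]          # disjoint shard per rank

# With a shared seed, the per-rank shards exactly partition the dataset:
shards = [shard_indices(8, num_ranks=2, rank=r, seed=42) for r in range(2)]
assert len(set(shards[0]) | set(shards[1])) == 8

With differing seeds the permutations diverge, so some samples land on both ranks and others on neither — which would explain why the code refuses a None seed under distributed training.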

@@ -63,7 +63,7 @@ def __init__(self, config, mode="train"):

# set seed
seed = self.config["Global"].get("seed", False)
Contributor:

Would it be better to have the `get` here return None directly?

seed = self.config["Global"].get("seed", None)
if seed is not None:
    ...
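
That would also sidestep some truthiness gymnastics. A quick standalone illustration (not part of the PR) of the two pitfalls the current sentinel forces the condition to work around:

# 0 is falsy yet a perfectly valid seed, and bool is a subclass of int.
assert not 0                       # a bare `if seed:` would skip seed=0
assert isinstance(True, int)       # a stray bool would pass the int assert
assert isinstance(True, bool)      # hence the explicit bool exclusion

for seed in (False, 0, 42):
    accepted = bool(seed or not isinstance(seed, bool))
    print(seed, "->", accepted)    # False -> False, 0 -> True, 42 -> True

With `get("seed", None)`, the whole condition collapses to `if seed is not None:`.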

-        if seed or seed == 0:
+        if seed or not isinstance(seed, bool):
             assert isinstance(seed, int), "The 'seed' must be an integer!"
             paddle.seed(seed)
             np.random.seed(seed)
@@ -123,26 +123,6 @@ def __init__(self, config, mode="train"):
"epochs": self.config["Global"]["epochs"]
})

-        # build dataloader
-        if self.mode == 'train':
-            self.train_dataloader = build_dataloader(
-                self.config["DataLoader"], "Train", self.device, self.use_dali)
-            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
-                self.unlabel_train_dataloader = build_dataloader(
-                    self.config["DataLoader"], "UnLabelTrain", self.device,
-                    self.use_dali)
-            else:
-                self.unlabel_train_dataloader = None
-
-            self.iter_per_epoch = len(
-                self.train_dataloader) - 1 if platform.system(
-                ) == "Windows" else len(self.train_dataloader)
-            if self.config["Global"].get("iter_per_epoch", None):
-                # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
-                self.iter_per_epoch = self.config["Global"].get(
-                    "iter_per_epoch")
-                self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq

if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode in ["classification", "adaface"]:
@@ -183,21 +163,6 @@ def __init__(self, config, mode="train"):
else:
self.eval_loss_func = None

-        # build metric
-        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
-                "Metric"] and self.config["Metric"]["Train"]:
-            metric_config = self.config["Metric"]["Train"]
-            if hasattr(self.train_dataloader, "collate_fn"
-                       ) and self.train_dataloader.collate_fn is not None:
-                for m_idx, m in enumerate(metric_config):
-                    if "TopkAcc" in m:
-                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
-                        logger.warning(msg)
-                        metric_config.pop(m_idx)
-            self.train_metric_func = build_metrics(metric_config)
-        else:
-            self.train_metric_func = None

if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode == "classification":
@@ -231,13 +196,6 @@ def __init__(self, config, mode="train"):
[self.model, getattr(self, 'train_loss_func', None)],
self.config["Global"]["pretrained_model"])

-        # build optimizer
-        if self.mode == 'train':
-            self.optimizer, self.lr_sch = build_optimizer(
-                self.config["Optimizer"], self.config["Global"]["epochs"],
-                self.iter_per_epoch // self.update_freq,
-                [self.model, self.train_loss_func])

# AMP training and evaluating
self.amp = "AMP" in self.config and self.config["AMP"] is not None
self.amp_eval = False
@@ -331,6 +289,48 @@ def __init__(self, config, mode="train"):
paddle.seed(int(seed) + dist.get_rank())
np.random.seed(int(seed) + dist.get_rank())
random.seed(int(seed) + dist.get_rank())

+        # build dataloader
+        if self.mode == 'train':
+            self.train_dataloader = build_dataloader(
+                self.config["DataLoader"], "Train", self.device, self.use_dali)
+            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
+                self.unlabel_train_dataloader = build_dataloader(
+                    self.config["DataLoader"], "UnLabelTrain", self.device,
+                    self.use_dali)
+            else:
+                self.unlabel_train_dataloader = None
+
+            self.iter_per_epoch = len(
+                self.train_dataloader) - 1 if platform.system(
+                ) == "Windows" else len(self.train_dataloader)
+            if self.config["Global"].get("iter_per_epoch", None):
+                # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
+                self.iter_per_epoch = self.config["Global"].get(
+                    "iter_per_epoch")
+                self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq
+
+        # build optimizer
+        if self.mode == 'train':
+            self.optimizer, self.lr_sch = build_optimizer(
+                self.config["Optimizer"], self.config["Global"]["epochs"],
+                self.iter_per_epoch // self.update_freq,
+                [self.model, self.train_loss_func])
+
+        # build metric
+        if self.mode == 'train' and "Metric" in self.config and "Train" in self.config[
+                "Metric"] and self.config["Metric"]["Train"]:
+            metric_config = self.config["Metric"]["Train"]
+            if hasattr(self.train_dataloader, "collate_fn"
+                       ) and self.train_dataloader.collate_fn is not None:
+                for m_idx, m in enumerate(metric_config):
+                    if "TopkAcc" in m:
+                        msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
+                        logger.warning(msg)
+                        metric_config.pop(m_idx)
+            self.train_metric_func = build_metrics(metric_config)
+        else:
+            self.train_metric_func = None

# build postprocess for infer
if self.mode == 'infer':
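
Taken together, the diff relocates the dataloader, optimizer, and metric construction so they run after the per-rank seeding (paddle.seed(int(seed) + dist.get_rank()) and friends): the shuffle stays reproducible while augmentation randomness is decorrelated across ranks. Incidentally, the iter_per_epoch // self.update_freq * self.update_freq line floors the iteration count to a multiple of the gradient-accumulation frequency, e.g. 103 // 4 * 4 == 100. A minimal sketch of why the ordering matters, in plain NumPy with a hypothetical make_epoch_order standing in for the real build_dataloader:

import numpy as np

def make_epoch_order(num_samples):
    # stands in for the shuffle a sampler performs once the loader is built
    return np.random.permutation(num_samples)

# New ordering: seed first, build afterwards -> the shuffle is reproducible.
np.random.seed(42)
a = make_epoch_order(6)
np.random.seed(42)
b = make_epoch_order(6)
assert (a == b).all()

# Old ordering: build first, seed afterwards -> the shuffle has already
# consumed unseeded RNG state, so it varies from run to run.
c = make_epoch_order(6)   # drawn from whatever state the RNG happens to hold
np.random.seed(42)        # too late to make `c` reproducible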