-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathmanual_seml_sweep.py
38 lines (31 loc) · 1.34 KB
/
manual_seml_sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pathlib import Path
from pprint import pprint
from seml.config import generate_configs, read_config
from chemCPA.experiments_run import ExperimentWrapper
if __name__ == "__main__":
    # Manually run a single seml experiment locally, bypassing the slurm
    # submission pipeline: load the yaml config the same way seml does,
    # pick the first generated config, and drive ExperimentWrapper end-to-end.
    exp = ExperimentWrapper(init_all=False)

    # this is how seml loads the config file internally
    config_path = Path("manual_run.yaml")
    if not config_path.exists():
        # NOTE: `assert` is stripped under `python -O`, so validate with an
        # explicit exception instead of `assert config_path.exists()`.
        raise FileNotFoundError(f"config file not found: {config_path}")
    seml_config, slurm_config, experiment_config = read_config(str(config_path))

    # we take the first config generated
    configs = generate_configs(experiment_config)
    if len(configs) > 1:
        print("Careful, more than one config generated from the yaml file")
    args = configs[0]
    pprint(args)

    exp.seed = 1337

    # loads the dataset splits
    exp.init_dataset(**args["dataset"])
    exp.init_drug_embedding(embedding=args["model"]["embedding"])
    exp.init_model(
        hparams=args["model"]["hparams"],
        additional_params=args["model"]["additional_params"],
        load_pretrained=args["model"]["load_pretrained"],
        append_ae_layer=args["model"]["append_ae_layer"],
        enable_cpa_mode=args["model"]["enable_cpa_mode"],
        pretrained_model_path=args["model"]["pretrained_model_path"],
        pretrained_model_hashes=args["model"]["pretrained_model_hashes"],
    )
    # setup the torch DataLoader
    exp.update_datasets()
    exp.train(**args["training"])