-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
35 lines (27 loc) · 871 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging
import hydra
from omegaconf import DictConfig
from omegaconf import OmegaConf # Do not confuse with dataclass.MISSING
# Import the task functions
from src.preprocess import main as preprocess
from src.inference import main as inference
from src.train import main as train
log = logging.getLogger(name=__name__)

# Maps each accepted value of the config's ``mode`` key to the callable that
# implements it; ``main`` calls the selected entry with the config.
# To add a task: import its entry point above and add a keyword entry here.
TASK_REGISTRY = dict(
    preprocess=preprocess,
    train=train,
    inference=inference,
)
@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Run the task selected by ``cfg.mode``.

    Hydra builds ``cfg`` from the ``config`` file under ``conf/`` (see the
    decorator). The value of ``cfg.mode`` is looked up in ``TASK_REGISTRY``
    and the matching task function is called with the config.

    Args:
        cfg: Configuration supplied by Hydra; must provide a ``mode`` key.

    Returns:
        None. If ``mode`` is not a registered task, an error listing the
        available tasks is logged and no task is run.
    """
    # NOTE(review): presumably re-created so tasks receive a copy of Hydra's
    # config rather than the original object -- confirm this is still needed.
    cfg = OmegaConf.create(cfg)
    mode = cfg.mode

    task = TASK_REGISTRY.get(mode)
    if task is None:
        # Validate before announcing the run, and list the valid choices so
        # a mistyped ``mode=...`` override is easy to correct.
        log.error(
            "Task %s not found in task registry (available: %s)",
            mode,
            ", ".join(TASK_REGISTRY),
        )
        # NOTE(review): returning lets the process exit with status 0;
        # consider raising so callers/schedulers notice the failure.
        return

    log.info("Starting %s", mode)
    task(cfg)


if __name__ == "__main__":
    # Called without arguments: the hydra.main decorator supplies ``cfg``.
    main()