-
Notifications
You must be signed in to change notification settings - Fork 11
/
utils.py
63 lines (54 loc) · 1.79 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import contextlib
from pathlib import Path

import hydra
import xarray as xr
from hydra.utils import instantiate, get_class, call

import hydra_main
def coords_to_dim(ds, dims=('time',), drop=('x',)):
    """Promote columns to index levels, drop other levels, and rebuild a Dataset.

    The input is flattened with ``ds.to_dataframe()`` (presumably an xarray
    Dataset - confirm at call sites). Each name in ``dims`` is moved from
    the frame's columns into its index as an extra, appended level. The
    index levels named in ``drop`` are then removed, discarding their
    values. The result is converted back with ``xr.Dataset.from_dataframe``,
    so every remaining index level becomes a dimension.

    Args:
        ds: object providing ``to_dataframe()``.
        dims: column names to turn into index levels / dimensions.
        drop: index level names to discard before rebuilding.

    Returns:
        The ``xr.Dataset`` built from the re-indexed frame.
    """
    frame = ds.to_dataframe()
    for name in dims:
        # append=True keeps the existing levels and adds ``name`` after them.
        frame = frame.set_index(name, append=True)
    frame = frame.reset_index(level=drop, drop=True)
    return xr.Dataset.from_dataframe(frame)
def reindex(ds, dims=('time', 'lat', 'lon')):
    """Rebuild a Dataset whose dimensions are exactly the levels in ``dims``.

    ``ds.to_dataframe()`` is flattened with ``reset_index()``, so every
    former index level becomes an ordinary column behind a fresh integer
    index. The first name in ``dims`` then replaces that integer index, and
    each later name is appended as a further level, in the order given.
    The frame is converted back with ``xr.Dataset.from_dataframe``.

    With an empty ``dims``, the integer index from ``reset_index()`` is kept.

    Args:
        ds: object providing ``to_dataframe()`` (presumably an xarray
            Dataset - confirm at call sites).
        dims: column names to use, in order, as the new index levels.

    Returns:
        The ``xr.Dataset`` built from the re-indexed frame.
    """
    frame = ds.to_dataframe().reset_index()
    keep_existing = False  # first level replaces the default integer index
    for name in dims:
        frame = frame.set_index(name, append=keep_existing)
        keep_existing = True
    return xr.Dataset.from_dataframe(frame)
def get_cfg(xp_cfg, overrides=None):
    """Compose the ``main`` Hydra config for experiment ``xp_cfg``.

    The override list always starts with ``xp=<xp_cfg>``, ``file_paths=jz``
    and ``entrypoint=train``. Any caller-supplied ``overrides`` are appended
    after those three entries.

    Initialization: Hydra is initialized from the ``hydra_config``
    directory, resolved against the current working directory, for the
    duration of the call. If Hydra is already initialized globally,
    ``hydra.initialize_config_dir`` raises ``ValueError``; in that case the
    existing global instance is used for composition instead.

    Only the initialization step is guarded. A ``ValueError`` raised while
    composing the config propagates to the caller unchanged; it is not
    swallowed and retried.

    Args:
        xp_cfg: name of the config to select in the ``xp`` group.
        overrides: optional iterable of extra Hydra override strings.

    Returns:
        The config object returned by ``hydra.compose``.
    """
    extra = list(overrides) if overrides is not None else []

    def _compose():
        # Compose against whichever Hydra instance is currently initialized.
        return hydra.compose(
            config_name='main',
            overrides=[
                f'xp={xp_cfg}',
                'file_paths=jz',
                'entrypoint=train',
                *extra,
            ],
        )

    config_dir = str(Path('hydra_config').absolute())
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(hydra.initialize_config_dir(config_dir))
        except ValueError:
            # Hydra is already initialized (e.g. by an earlier call or an
            # enclosing app); reuse the existing global instance.
            pass
        return _compose()
def get_model(xp_cfg, ckpt, dm=None, add_overrides=None):
    """Load the model for experiment ``xp_cfg`` from checkpoint ``ckpt``.

    The config comes from ``get_cfg``. The Lightning-module class is looked
    up from ``cfg.lit_mod_cls``. A ``FourDVarNetHydraRunner`` is built from
    ``cfg.params``, the datamodule and that class, and its ``_get_model``
    loads the checkpoint.

    Args:
        xp_cfg: experiment config name forwarded to ``get_cfg``.
        ckpt: checkpoint handed to ``runner._get_model``; presumably a file
            path - confirm against ``FourDVarNetHydraRunner``.
        dm: optional pre-built datamodule. When ``None``, one is
            instantiated from ``cfg.datamodule``; this function does not
            call ``setup()`` on it.
        add_overrides: optional list of extra Hydra override strings.

    Returns:
        Whatever ``runner._get_model(ckpt)`` returns.
    """
    cfg = get_cfg(xp_cfg, [] if add_overrides is None else [] + add_overrides)
    # Resolve the class before building the datamodule (original order).
    model_cls = get_class(cfg.lit_mod_cls)
    datamodule = dm if dm is not None else instantiate(cfg.datamodule)
    runner = hydra_main.FourDVarNetHydraRunner(cfg.params, datamodule, model_cls)
    return runner._get_model(ckpt)
def get_dm(xp_cfg, setup=True, add_overrides=None):
    """Instantiate the datamodule configured for experiment ``xp_cfg``.

    Args:
        xp_cfg: experiment config name forwarded to ``get_cfg``.
        setup: when truthy, call ``setup()`` on the datamodule before
            returning it.
        add_overrides: optional list of extra Hydra override strings.

    Returns:
        The object produced by ``instantiate(cfg.datamodule)``.
    """
    extra = [] if add_overrides is None else [] + add_overrides
    datamodule = instantiate(get_cfg(xp_cfg, extra).datamodule)
    if not setup:
        return datamodule
    datamodule.setup()
    return datamodule