-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathhubconf.py
91 lines (75 loc) · 3.88 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from model.main import DepthNetModule
# Each entrypoint below builds one HybridDepth variant; the variants differ in
# training dataset and focal-stack size (number of images per focal stack).
def HybridDepth_NYU5(pretrained=False, **kwargs):
    """
    Build the HybridDepth variant for NYU Depth V2 with a 5-image focal stack.

    Args:
        pretrained (bool): When True, download the release checkpoint and load
            its weights into the freshly constructed network.
        **kwargs: Forwarded unchanged to the DepthNetModule constructor.

    Returns:
        DepthNetModule: The constructed network. It is put into eval mode only
        when pretrained weights were loaded; otherwise it is returned exactly
        as constructed.
    """
    net = DepthNetModule(**kwargs)
    if not pretrained:
        return net

    weights_url = "https://github.com/cake-lab/HybridDepth/releases/download/v2.0/NYUBest5-DFV-Trained.ckpt"
    # Map tensors onto the CPU so the checkpoint loads without requiring a GPU.
    ckpt = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
    # Some .ckpt files (e.g. PyTorch Lightning checkpoints) nest the weights
    # under a 'state_dict' key; otherwise the object is already the state dict.
    if 'state_dict' in ckpt:
        weights = ckpt['state_dict']
    else:
        weights = ckpt
    net.load_state_dict(weights)
    net.eval()
    return net
def HybridDepth_NYU10(pretrained=False, **kwargs):
    """
    Build the HybridDepth variant for NYU Depth V2 with a 10-image focal stack.

    Args:
        pretrained (bool): When True, download the release checkpoint and load
            its weights into the freshly constructed network.
        **kwargs: Forwarded unchanged to the DepthNetModule constructor.

    Returns:
        DepthNetModule: The constructed network. It is put into eval mode only
        when pretrained weights were loaded; otherwise it is returned exactly
        as constructed.
    """
    net = DepthNetModule(**kwargs)
    if not pretrained:
        return net

    weights_url = "https://github.com/cake-lab/HybridDepth/releases/download/v2.0/NYUBest10-DFV-Trained.ckpt"
    # Map tensors onto the CPU so the checkpoint loads without requiring a GPU.
    ckpt = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
    # Unwrap a nested 'state_dict' entry when present; otherwise use as-is.
    if 'state_dict' in ckpt:
        weights = ckpt['state_dict']
    else:
        weights = ckpt
    net.load_state_dict(weights)
    net.eval()
    return net
def HybridDepth_DDFF5(pretrained=False, **kwargs):
    """
    Build the HybridDepth variant for the DDFF dataset with a 5-image focal stack.

    Args:
        pretrained (bool): When True, download the release checkpoint
            (``DDFF12.ckpt``) and load its weights into the freshly
            constructed network.
        **kwargs: Forwarded unchanged to the DepthNetModule constructor.

    Returns:
        DepthNetModule: The constructed network. It is put into eval mode only
        when pretrained weights were loaded; otherwise it is returned exactly
        as constructed.
    """
    net = DepthNetModule(**kwargs)
    if not pretrained:
        return net

    weights_url = "https://github.com/cake-lab/HybridDepth/releases/download/v2.0/DDFF12.ckpt"
    # Map tensors onto the CPU so the checkpoint loads without requiring a GPU.
    ckpt = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
    # Unwrap a nested 'state_dict' entry when present; otherwise use as-is.
    if 'state_dict' in ckpt:
        weights = ckpt['state_dict']
    else:
        weights = ckpt
    net.load_state_dict(weights)
    net.eval()
    return net
def HybridDepth_NYU_PretrainedDFV5(pretrained=False, **kwargs):
    """
    Build the DFV-pretrained HybridDepth variant for NYU Depth V2 with a
    5-image focal stack.

    Args:
        pretrained (bool): When True, download the release checkpoint
            (``NyuBest5.ckpt``) and load its weights into the freshly
            constructed network.
        **kwargs: Forwarded unchanged to the DepthNetModule constructor.

    Returns:
        DepthNetModule: The constructed network. It is put into eval mode only
        when pretrained weights were loaded; otherwise it is returned exactly
        as constructed.
    """
    net = DepthNetModule(**kwargs)
    if not pretrained:
        return net

    # NOTE(review): this "PretrainedDFV" entrypoint loads NyuBest5.ckpt, while
    # HybridDepth_NYU5 loads NYUBest5-DFV-Trained.ckpt — confirm the two
    # checkpoint URLs are not swapped.
    weights_url = "https://github.com/cake-lab/HybridDepth/releases/download/v2.0/NyuBest5.ckpt"
    # Map tensors onto the CPU so the checkpoint loads without requiring a GPU.
    ckpt = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu')
    # Unwrap a nested 'state_dict' entry when present; otherwise use as-is.
    if 'state_dict' in ckpt:
        weights = ckpt['state_dict']
    else:
        weights = ckpt
    net.load_state_dict(weights)
    net.eval()
    return net