-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompute_features.py
46 lines (44 loc) · 1.92 KB
/
compute_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-
import argparse
import clip
import os
import utils as uti
import datasets as dts
def get_arguments():
    """Parse the command-line options of the feature-computation script.

    Recognised options:
        --data_root_path   string, no default (``None`` when omitted).
        -d / --datasets    zero or more dataset names; ``['dtd']`` when the
                           flag is absent, ``[]`` when given without values.
        --backbone         backbone name, ``'vit_b16'`` by default.
        --root_cache_path  string, ``None`` by default.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--data_root_path', type=str)
    cli.add_argument(
        '-d',
        '--datasets',
        nargs="*",
        default=['dtd'],
        help='List of datasets for which to compute features.',
    )
    cli.add_argument(
        '--backbone',
        type=str,
        default='vit_b16',
        help='Name of the backbone to use. Examples : vit_b16 or rn101.',
    )
    cli.add_argument(
        '--root_cache_path',
        type=str,
        default=None,
        help='Path where the cached features and targets will be stored. Defaults to data_root_path/{dataset}/cache internally.',
    )
    return cli.parse_args()
def main():
    """Compute CLIP features for every dataset requested on the command line.

    The CLIP model is loaded once; then, for each dataset name, a fresh
    configuration dict is built, the cache directory
    ``<base>/<uti.datasets[name]>/cache`` is created (``<base>`` being
    ``--root_cache_path`` when given, otherwise ``--data_root_path``), the
    dataloaders are built and ``uti.get_all_features`` is run on them.

    Raises:
        ValueError: if ``--data_root_path`` was not supplied.
    """
    args = get_arguments()
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would defer the failure to os.path.join below.
    if args.data_root_path is None:
        raise ValueError('--data_root_path is required')

    # NOTE(review): uti.backbones / uti.datasets are presumably dicts mapping
    # CLI names to internal identifiers; unknown names fail on lookup.
    backbone = uti.backbones[args.backbone]
    print('========== Loading Clip Model')
    clip_model, preprocess = clip.load(backbone)

    base_cache_dir = (
        args.root_cache_path
        if args.root_cache_path is not None
        else args.data_root_path
    )

    for dataset_name in args.datasets:
        print('\n******* dataset : ', dataset_name)

        # Build the config from scratch for each dataset so that per-dataset
        # flags (e.g. 'load_cache' for imagenet) do not leak into the
        # datasets processed afterwards.
        cfg = {'backbone': backbone}
        if dataset_name == 'imagenet':
            cfg['load_cache'] = True
        dataset_key = uti.datasets[dataset_name]
        cfg['dataset'] = dataset_key
        cfg['root_path'] = args.data_root_path
        cfg['shots'] = 0
        cfg['load_pre_feat'] = False

        cache_dir = os.path.join(base_cache_dir, dataset_key, 'cache')
        os.makedirs(cache_dir, exist_ok=True)
        cfg['cache_dir'] = cache_dir
        print(cfg['cache_dir'])

        print('Computing Features...')
        train_loader, val_loader, test_loader, dataset = dts.get_all_dataloaders(
            cfg, preprocess, dirichlet=None)
        # The return value is discarded; presumably the call is made for its
        # caching side effect into cfg['cache_dir'] — TODO confirm.
        uti.get_all_features(
            cfg, train_loader, val_loader, test_loader, dataset, clip_model)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()