-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
40 lines (31 loc) · 1.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Handy helper functions for PyTorch model training: checkpointing, CUDA setup, and logging.
"""
import torch
import logging
# Checkpoints
def save_checkpoint(model, model_dir):
    """Save the model's ``state_dict`` to ``model_dir``.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        model_dir: Destination of the checkpoint file. Despite the name this
            is the file path itself, not a directory; it is passed straight
            to ``torch.save``.

    When ``model_dir`` is a filesystem path, any missing parent directories
    are created first so the first checkpoint of a run does not fail just
    because the output folder has not been created yet.
    """
    import os  # local import: keeps this block independent of file-level imports

    # Only path-like targets have a parent directory to create; anything
    # else (e.g. an open file object) is handed to torch.save untouched.
    if isinstance(model_dir, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(model_dir))
        if parent:
            os.makedirs(parent, exist_ok=True)

    torch.save(model.state_dict(), model_dir)
def resume_checkpoint(model, model_dir, device_id):
    """Load a checkpoint written by ``torch.save`` into ``model`` in place.

    Args:
        model: Object exposing ``load_state_dict()`` that receives the weights.
        model_dir: Path of the checkpoint file (passed to ``torch.load``).
        device_id: CUDA device the loaded storages are moved to when CUDA is
            available.

    When CUDA is available every storage is moved to GPU ``device_id``. On a
    machine without CUDA the storages are mapped to the CPU instead, so a
    checkpoint can still be restored there rather than failing inside
    ``storage.cuda``.
    """
    if torch.cuda.is_available():
        # Ensure all storages end up on the requested GPU.
        def map_location(storage, loc):
            return storage.cuda(device=device_id)
    else:
        map_location = 'cpu'

    # NOTE: torch.load unpickles the file; only load checkpoints from
    # sources you trust.
    state_dict = torch.load(model_dir, map_location=map_location)
    model.load_state_dict(state_dict)
# Hyper params
def use_cuda(enabled, device_id=0):
    """Select the active CUDA device when GPU use is requested.

    Args:
        enabled: When falsy the function does nothing.
        device_id: Device passed to ``torch.cuda.set_device`` (default 0).

    Raises:
        AssertionError: If ``enabled`` is truthy but CUDA is not available.
    """
    if not enabled:
        return
    # Raised explicitly instead of via an `assert` statement so the check is
    # still performed when Python runs with -O (which strips asserts).
    if not torch.cuda.is_available():
        raise AssertionError('CUDA is not available')
    torch.cuda.set_device(device_id)
def initLogging(logFilename):
    """Configure root logging to write to a file and echo to the console.

    Args:
        logFilename: Path of the log file; it is opened in ``'w'`` mode, so
            an existing file is truncated.

    The root logger is configured through ``logging.basicConfig`` at DEBUG
    level with a ``%y-%m-%d %H:%M`` timestamp. A console ``StreamHandler`` at
    INFO level, using the same message format with the default timestamp
    style, is attached to the root logger.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so only the first effective call sets up the log file. The
    console handler is attached at most once, so calling this function again
    does not duplicate console output.
    """
    console_name = 'initLogging.console'
    log_format = '%(asctime)s-%(levelname)s-%(message)s'

    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt='%y-%m-%d %H:%M',
        filename=logFilename,
        filemode='w',
    )

    root = logging.getLogger('')
    # A previous call already attached the console handler; adding another
    # would print every message twice.
    if any(handler.get_name() == console_name for handler in root.handlers):
        return

    console = logging.StreamHandler()
    console.set_name(console_name)
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(log_format))
    root.addHandler(console)