-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
105 lines (76 loc) · 3.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/python3
import os
import glob
import argparse
import tensorflow as tf
from osgeo import gdal
def get_codings(description_file):
    """Get lists of label codes and names and an id-to-code mapping.

    Each non-blank line of the description file is parsed with
    ``parse_label_code`` (format ``label_value,label_name``). Blank or
    whitespace-only lines, e.g. a trailing empty line, are skipped.

    :param description_file: path to the txt file with labels and their names
    :return: list of label codes, list of label names, id2code dictionary
        mapping the zero-based position of each label in the file to its
        code; an empty file yields ``([], [], {})``
    """
    # Use a context manager so the handle is closed even if parsing fails;
    # previously the open file was left for the garbage collector.
    with open(description_file) as desc:
        parsed = [parse_label_code(line) for line in desc if line.strip()]

    # Build the two lists explicitly instead of zip(*...), which cannot be
    # unpacked into two names when the file contains no labels.
    label_codes = [code for code, _ in parsed]
    label_names = [name for _, name in parsed]
    id2code = dict(enumerate(label_codes))
    return label_codes, label_names, id2code
def get_nr_of_bands(data_dir):
    """Get number of bands in the first *image.tif raster in a directory.

    Matching paths are sorted by name so that "first" is deterministic;
    ``glob.glob`` itself returns them in arbitrary, filesystem-dependent
    order.

    :param data_dir: directory with images for training or detection
    :return: number of raster bands (GDAL ``RasterCount``) of that image
    :raises FileNotFoundError: if no ``*image.tif`` file exists in data_dir
    :raises RuntimeError: if GDAL returns no dataset for the selected image
    """
    images = sorted(glob.glob(os.path.join(data_dir, '*image.tif')))
    if not images:
        raise FileNotFoundError(
            f'No *image.tif file found in directory {data_dir}')

    dataset_image = gdal.Open(images[0], gdal.GA_ReadOnly)
    if dataset_image is None:
        # gdal.Open returns None on failure when GDAL exceptions are disabled
        raise RuntimeError(f'GDAL could not open raster {images[0]}')
    nr_bands = dataset_image.RasterCount
    # Dropping the last reference releases the GDAL dataset handle.
    dataset_image = None
    return nr_bands
def parse_label_code(line):
    """Parse a line of the label description file into a code and a name.

    Expected format: ``label_value,label_name``. Only the first comma is
    treated as the separator, so label names may themselves contain commas.
    Surrounding whitespace of the whole line (including the newline) is
    stripped.

    :param line: line in the txt file
    :return: tuple with an integer label code, a string label name
    :raises ValueError: if the line contains no comma or the code part is
        not an integer
    """
    code, name = line.strip().split(',', 1)
    return int(code), name
def print_device_info():
    """Print the GPUs TensorFlow sees, where a sample op runs, and eager mode.

    Each report consists of a header line followed by the probed value.
    """
    # The probes are lambdas so each TF query runs only after its header has
    # been printed, which keeps the output interleaving as before.
    probes = (
        ('Available GPUs:', lambda: tf.config.list_physical_devices('GPU')),
        ('Device name:', lambda: tf.random.uniform((1, 1)).device),
        ('TF executing eagerly:', lambda: tf.executing_eagerly()),
    )
    for header, probe in probes:
        print(header)
        print(probe())
def str2bool(string_val):
    """Convert a boolean-looking string into a real ``bool``.

    Intended as an argparse ``type=`` converter. With ``type=bool`` argparse
    calls ``bool()`` on the raw string, and every non-empty string is truthy,
    so ``--force_dataset_generation False`` would be read as True.

    :param string_val: a case-insensitive string such as 'true'/'false',
        'yes'/'no', 't'/'f', 'y'/'n', '1'/'0', or a bool (returned as-is)
    :return: the corresponding boolean value
    :raises argparse.ArgumentTypeError: if the string is not recognised
    """
    if isinstance(string_val, bool):
        return string_val

    truthy_words = ('true', 'yes', 't', 'y', '1')
    falsy_words = ('false', 'no', 'f', 'n', '0')
    normalized = string_val.lower()

    if normalized in truthy_words:
        return True
    if normalized in falsy_words:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def model_replace_nans(weights):
    """Return the model weights with NaN values replaced by zeros.

    Works around a suspected TF bug: a model trained on GPUs can contain NaN
    weights, and running predict() on CPUs with such weights ends up
    returning zeros everywhere. Cleaning the weights fixes this.

    Layers without any NaN are passed through as the very same objects.
    Layers containing NaNs are replaced by ``np.nan_to_num(layer)``, which
    with its defaults also turns +/-inf in that layer into the largest
    finite values of the dtype. A message is printed via ``tf.print`` for
    every such layer.

    :param weights: iterable of per-layer weight arrays of a model
    :return: new list with the (possibly cleaned) weight arrays, in order
    """
    import numpy as np

    def _sanitize(layer):
        # Untouched layers keep their identity; only NaN layers are copied.
        if not np.isnan(layer).any():
            return layer
        cleaned = np.nan_to_num(layer)
        tf.print('NaN values found in the weights -> they are changed '
                 'to zeros')
        return cleaned

    return [_sanitize(layer) for layer in weights]