-
Notifications
You must be signed in to change notification settings - Fork 3
/
plots.py
192 lines (155 loc) · 6.19 KB
/
plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""Metrics for generic plotting.
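Functions
---------
savefig(filename, dpi=300)
plot_metrics(history, metric)
plot_metrics_panels(history, settings)
plot_map(x, clim=None, title=None, text=None, cmap='RdGy')
drawOnGlobe(ax, map_proj, data, lats, lons, ...)
add_cyclic_point(data, coord=None, axis=-1)
plot_pits(ax, x_val, onehot_val, model_shash)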
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy as ct
import numpy.ma as ma
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import custom_metrics
# Module-wide matplotlib defaults, applied at import time:
# white figure background and 150 dpi figures.
mpl.rcParams.update({
    "figure.facecolor": "white",
    "figure.dpi": 150,
})
def savefig(filename, dpi=300):
    """Save the current matplotlib figure as both a PNG and a PDF.

    Parameters
    ----------
    filename
        Output path without an extension; ".png" and ".pdf" are appended
        to it with ``+``.
    dpi
        Resolution forwarded to ``plt.savefig`` (default 300).
    """
    save_options = {"bbox_inches": "tight", "dpi": dpi}
    targets = [filename + suffix for suffix in (".png", ".pdf")]
    for target in targets:
        plt.savefig(target, **save_options)
def plot_metrics(history, metric):
    """Plot the training and validation curves of one metric on the current axes.

    Parameters
    ----------
    history
        Object exposing a ``history`` mapping that contains the keys
        ``metric``, ``'val_' + metric`` and ``'val_loss'``
        (presumably a Keras ``History`` -- confirm against callers).
    metric : str
        Name of the metric to draw; also used as the panel title.

    A faint vertical line marks the index at which ``'val_loss'`` is smallest.
    """
    records = history.history
    best_index = np.argmin(records['val_loss'])

    plt.plot(records[metric], label='training')
    validation_key = 'val_' + metric
    plt.plot(records[validation_key], label='validation')

    plt.title(metric)
    marker_style = dict(linewidth=.5, color='gray', alpha=.5)
    plt.axvline(x=best_index, **marker_style)
    plt.legend()
def plot_metrics_panels(history, settings):
    """Draw a 1x4 row of training-history panels on a new 20x4 figure.

    Panels 1 and 2 (loss and the error metric) are mandatory. Panels 3 and 4
    ('interquartile_capture' and 'sign_test') are best-effort: if the history
    does not contain one of the required keys, the remaining optional panels
    are skipped.

    Parameters
    ----------
    history
        Object exposing a ``history`` mapping, as consumed by ``plot_metrics``.
    settings
        Mapping with a ``"network_type"`` entry: ``"reg"`` selects the
        ``"mae"`` metric, ``"shash2"`` selects ``"custom_mae"``.

    Raises
    ------
    NotImplementedError
        If ``settings["network_type"]`` is neither ``"reg"`` nor ``"shash2"``.
    KeyError
        If the loss or error-metric series is missing from the history.
    """
    network_type = settings["network_type"]
    if network_type == "reg":
        error_name = "mae"
    elif network_type == "shash2":
        error_name = "custom_mae"
    else:
        raise NotImplementedError('no such network_type')

    plt.subplots(figsize=(20, 4))

    plt.subplot(1, 4, 1)
    plot_metrics(history, 'loss')
    plt.ylim(0, 10.)

    plt.subplot(1, 4, 2)
    plot_metrics(history, error_name)
    plt.ylim(0, 10)

    # Optional diagnostics: only a missing history key is an expected failure.
    # Anything else (interrupts, genuine plotting bugs) must propagate instead
    # of being silently discarded.
    try:
        plt.subplot(1, 4, 3)
        plot_metrics(history, 'interquartile_capture')

        plt.subplot(1, 4, 4)
        plot_metrics(history, 'sign_test')
    except KeyError:
        pass
def plot_map(x, clim=None, title=None, text=None, cmap='RdGy'):
    """Draw a 2-D field as a pseudocolor image with a colorbar and no ticks.

    Parameters
    ----------
    x
        Field handed to ``plt.pcolor``.
    clim
        Color limits forwarded to ``plt.clim``.
    title
        Title text, right-aligned at font size 15.
    text
        Small monospace annotation placed at the top-left corner of the
        axes, in axes coordinates.
    cmap
        Colormap name (default ``'RdGy'``).
    """
    plt.pcolor(x, cmap=cmap)
    plt.clim(clim)
    plt.colorbar()
    plt.title(title, fontsize=15, loc='right')

    # Hide tick marks on both axes (y first, then x).
    for set_ticks in (plt.yticks, plt.xticks):
        set_ticks([])

    annotation_style = dict(
        fontfamily='monospace',
        fontsize='small',
        va='bottom',
        transform=plt.gca().transAxes,
    )
    plt.text(0.01, 1.0, text, **annotation_style)
def drawOnGlobe(ax, map_proj, data, lats, lons, cmap='coolwarm', vmin=None, vmax=None, inc=None, cbarBool=True, contourMap=[], contourVals = [], fastBool=False, extent='both'):
    """Draw a lat/lon field on a cartopy axes with land outlines.

    Parameters
    ----------
    ax
        Axes to draw on; must provide ``add_feature``, ``pcolor``,
        ``pcolormesh`` and ``contour`` (presumably a cartopy GeoAxes).
    map_proj
        Accepted for interface compatibility; not used by this function.
    data, lats, lons
        Field and its coordinates, interpreted in PlateCarree coordinates.
    cmap
        Colormap for the field.
    vmin, vmax
        Color limits applied to the image after it is drawn.
    inc
        Accepted for interface compatibility; not used by this function.
    cbarBool
        If true, add a horizontal colorbar.
    contourMap
        Optional field drawn as fuchsia contour lines on top of the image.
        It must be an array on the same lat/lon grid with equally spaced
        1-D ``lons`` (requirements of ``add_cyclic_point``).
    contourVals
        Contour levels passed to ``ax.contour``.
    fastBool
        If true use ``pcolormesh``; otherwise ``pcolor`` with ``shading='auto'``.
    extent
        Forwarded to the colorbar as its ``extend`` argument.

    Returns
    -------
    (cb, image)
        The colorbar (``None`` when ``cbarBool`` is false) and the image
        artist returned by ``pcolor``/``pcolormesh``.
    """
    # NOTE: the list defaults for contourMap/contourVals are kept for
    # interface compatibility; they are never mutated here.
    data_crs = ct.crs.PlateCarree()

    land_feature = cfeature.NaturalEarthFeature(
        category='physical',
        name='land',
        scale='50m',
        facecolor='None',
        edgecolor='k',
        linewidth=.5,
    )
    ax.add_feature(land_feature)

    # The image is drawn from the data as given, with no cyclic point added.
    # The earlier cyclic-point computation was overwritten before use, so it
    # has been removed rather than left as dead work.
    if fastBool:
        image = ax.pcolormesh(lons, lats, data, transform=data_crs, cmap=cmap)
    else:
        image = ax.pcolor(lons, lats, data, transform=data_crs, cmap=cmap, shading='auto')

    if np.size(contourMap) != 0:
        # The contour field gains one extra longitude column, so it must be
        # paired with the extended longitudes returned alongside it.
        # Using the original `lons` here made the x/z shapes disagree.
        contourMap_cyc, contour_lons = add_cyclic_point(contourMap, coord=lons)
        ax.contour(contour_lons, lats, contourMap_cyc, contourVals, transform=data_crs, colors='fuchsia')

    if cbarBool:
        cb = plt.colorbar(image, shrink=.45, orientation="horizontal", pad=.02, extend=extent)
        cb.ax.tick_params(labelsize=6)
    else:
        cb = None

    image.set_clim(vmin, vmax)

    return cb, image
def add_cyclic_point(data, coord=None, axis=-1):
    """Append a copy of the first slice along ``axis`` to the end of ``data``.

    Local copy of the cartopy utility of the same name (kept here because
    importing it from cartopy was unreliable).

    Parameters
    ----------
    data
        N-dimensional array to extend.
    coord
        Optional 1-D, equally spaced coordinate array for ``axis``. When
        given, it is extended by one point using its own spacing.
    axis
        Index of the dimension to extend (default: the last one).

    Returns
    -------
    new_data
        Masked array with one extra entry along ``axis``; returned alone
        when ``coord`` is None.
    (new_data, new_coord)
        Returned when ``coord`` is given.

    Raises
    ------
    ValueError
        If ``axis`` is not a dimension of ``data``, or if ``coord`` is not
        1-D, does not match the length of ``data`` along ``axis``, has fewer
        than two points, or is not equally spaced.
    """
    # Validate the axis up front so a bad axis is always reported as a
    # ValueError, including when `coord` is supplied (previously
    # `data.shape[axis]` leaked an IndexError in that case).
    if not -data.ndim <= axis < data.ndim:
        raise ValueError('The specified axis does not correspond to an '
                         'array dimension.')

    new_coord = None
    if coord is not None:
        if coord.ndim != 1:
            raise ValueError('The coordinate must be 1-dimensional.')
        if len(coord) != data.shape[axis]:
            raise ValueError('The length of the coordinate does not match '
                             'the size of the corresponding dimension of '
                             'the data array: len(coord) = {}, '
                             'data.shape[{}] = {}.'.format(
                                 len(coord), axis, data.shape[axis]))
        if len(coord) < 2:
            # The spacing is undefined for a single point, so the coordinate
            # cannot be extended.
            raise ValueError('The coordinate must contain at least two '
                             'points to determine its spacing.')
        delta_coord = np.diff(coord)
        if not np.allclose(delta_coord, delta_coord[0]):
            raise ValueError('The coordinate must be equally spaced.')
        new_coord = ma.concatenate((coord, coord[-1:] + delta_coord[0]))

    # Select the first slice along `axis`, keeping that dimension so it can
    # be concatenated back onto the data.
    slicer = [slice(None)] * data.ndim
    slicer[axis] = slice(0, 1)
    new_data = ma.concatenate((data, data[tuple(slicer)]), axis=axis)

    if coord is None:
        return new_data
    return new_data, new_coord
def plot_pits(ax, x_val, onehot_val, model_shash):
    """Draw the SHASH probability-integral-transform (PIT) histogram on ``ax``.

    Parameters
    ----------
    ax
        Matplotlib axes to draw on; it is also made the current axes.
    x_val, onehot_val, model_shash
        Forwarded unchanged to ``custom_metrics.compute_pit`` as
        ``x_data``, the first positional argument, and ``model_shash``.

    ``compute_pit`` is expected to return ``(bins, hist, D, EDp)`` where
    ``hist[0]`` holds bar heights and ``hist[1]`` the bin edges (presumably
    the output of ``np.histogram`` -- confirm in custom_metrics).
    """
    plt.sca(ax)
    shash_color = 'tab:blue'

    # SHASH PIT statistics
    bins, hist_shash, D_shash, EDp_shash = custom_metrics.compute_pit(
        onehot_val, x_data=x_val, model_shash=model_shash)

    step = bins[1] - bins[0]
    half_step = step / 2
    bar_width = step * .98
    heights, edges = hist_shash[0], hist_shash[1]
    ax.bar(edges[:-1] + half_step, heights,
           width=bar_width, color=shash_color, label='SHASH')

    # Dashed reference line at 0.1 (presumably the uniform level for
    # ten bins -- TODO confirm).
    ax.axhline(y=.1, linestyle='--', color='k', linewidth=2.)

    tick_values = np.around(np.arange(0, .55, .05), 2)
    plt.yticks(tick_values, tick_values)
    ax.set_ylim(0, .25)
    ax.set_xticks(bins, np.around(bins, 1))

    # Summary statistics printed just below the top of the axes.
    d_text = str(np.round(D_shash, 4))
    edp_text = str(np.round(EDp_shash, 3))
    summary = 'SHASH D: ' + d_text + ' (' + edp_text + ')'
    text_y = np.max(ax.get_ylim()) * .99
    plt.text(0., text_y, summary,
             color=shash_color,
             verticalalignment='top',
             fontsize=12)

    ax.set_xlabel('probability integral transform')
    ax.set_ylabel('probability')