-
Notifications
You must be signed in to change notification settings - Fork 2
/
DisplayCallback.m
167 lines (149 loc) · 6.38 KB
/
DisplayCallback.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
classdef DisplayCallback
    %DISPLAYCALLBACK Callback for displaying a cross section of the
    % current solution (or residual) of a GridSim-based simulation.
    %
    % Presumably the simulation calls PREPARE once (passing itself) and
    % then CALL for every update it wants displayed.
    % NOTE(review): the simulation-side protocol (how this object is
    % registered with the Sim and when PREPARE/CALL are invoked) is
    % defined outside this file; confirm against the Sim documentation.
    %
    % Construct the object with name-value options, for example:
    %   cb = DisplayCallback('component', 3, 'cross_section', {0.5, ':', ':'})
    % displays the 3rd component of a vector field, over the cross section
    % at the middle of the x range, showing all y and z.
    %
    % Options (public properties):
    %   component         index of the field component to display: one
    %                     element for vector data, two for tensor data,
    %                     ignored for scalar data. Missing entries
    %                     default to 1.
    %   cross_section     cell array with one entry per spatial dimension:
    %                     ':' or [] to take all data along that dimension,
    %                     or a fraction between 0 (first sample) and 1
    %                     (last sample) where the cross section is taken.
    %                     Missing entries default to 0.5. Exactly 1 or 2
    %                     dimensions must take all data.
    %   show_convergence  when true (default), shows a plot of ‖Δ𝜓‖^2,
    %                     normalized to its first value, to monitor
    %                     convergence.
    %   plot_residual     when true, displays the residual instead of the
    %                     current estimate (default false).
    %
    % The data are always passed through grid.crop before display, and
    % only the real part (after applying Tr) is shown.
    properties (SetAccess = private)
        % Grid object with scaling and cropping information (from sim.grid)
        grid Grid
        % scaling operator/matrix (from sim.Tr): we want to plot the
        % solution in non-scaled form
        Tr
        imageplot % true for 2-dimensional cross sections, false otherwise
        coord1    % coordinates for 'x' axis
        coord2    % coordinates for 'y' axis (image plots only)
        label1    % label for 'x' axis
        label2    % label for 'y' axis (image plots only)
        selection % subsref structure to access cross section to display
    end
    properties
        % for tensor data: 2-element index of component to display.
        % for vector data: 1-element index of component to display.
        % for scalar data: ignored
        % missing values: default to 1
        component { mustBePositive, mustBeInteger } = []
        % select data region to display. For each spatial dimension, either
        % indicate ':' or [] (display all data along that dimension) or a
        % fraction between 0 (first 'row') and 1 (last 'row'). Fractions
        % are rounded to the nearest sample and clamped to the valid range.
        % missing values: use 0.5 for all other dimensions.
        cross_section (:,1) = {':', ':'}
        % true to draw a graph of magnitude of each update step
        show_convergence (1,1) logical = true
        % true to display residual, instead of current estimate
        plot_residual (1,1) logical = false;
    end
    methods
        function obj = DisplayCallback(opt)
            % DISPLAYCALLBACK Construct the callback from name-value
            % options. Any publicly settable property of this class may
            % be passed.
            arguments
                opt.?DisplayCallback
            end
            % presumably copies each supplied option to the property of
            % the same name (copy_properties is defined elsewhere)
            obj = copy_properties(obj, opt);
        end

        function obj = prepare(obj, sim)
            % PREPARE Store the grid and scaling operator of SIM and
            % precompute the subsref selection, axis coordinates and axis
            % labels that CALL uses.
            %
            % Returns the updated object (DisplayCallback is a value
            % class, so the caller must keep the returned copy).
            % Raises an error when the cross section does not leave
            % exactly 1 or 2 dimensions along which all data are taken.
            arguments
                obj
                sim GridSim % This feedback function can only be used with grid-based simulations!
            end
            obj.grid = sim.grid;
            obj.Tr = sim.Tr;

            % convert component and cross_section to a structure that can
            % be used in subsref: the first valuedim subscripts select the
            % field component, the remaining ones the spatial position.
            indices = cell(length(obj.grid.N_u), 1);

            % component selection (missing values: 1)
            valuedim = length(obj.grid.N_components);
            obj.component = extend(obj.component, ones(valuedim, 1));
            for i = 1:valuedim
                indices{i} = obj.component(i);
            end

            % roi cross section
            dims = []; % spatial dimensions along which all data are taken
            for i = 1:obj.grid.N_dim
                % missing values default to 0.5
                if i > length(obj.cross_section)
                    obj.cross_section{i} = 0.5;
                end
                pos = obj.cross_section{i};
                % Test for ':' (or []) explicitly: `pos == ':'` would compare
                % numeric input with the character code of ':' (58), and
                % yields an empty (false) condition for [].
                if isempty(pos) || strcmp(pos, ':')
                    indices{i + valuedim} = ':';
                    dims = [dims i]; %#ok<AGROW>
                else
                    % map fraction 0..1 to sample index 1..N, clamped
                    N = numel(obj.grid.coordinates(i));
                    indices{i + valuedim} = min(max(round(pos * (N-1) + 1), 1), N);
                end
            end
            obj.selection.type = '()';
            obj.selection.subs = indices;

            %% find out in what dimensions the cross section was taken
            if length(dims) > 2 || length(dims) < 1
                error('DisplayCallback: Cross section must be 1 or 2 dimensional');
            end

            %% Prepare labels and coordinates
            % A 2-D selection in which one dimension holds a single sample
            % is shown as a line plot along the other dimension. The x
            % coordinates must come from the dimension that actually has
            % data, otherwise plot() receives vectors of different length.
            if length(dims) == 2 && ~all(obj.grid.N(dims) > 1)
                extended = dims(obj.grid.N(dims) > 1);
                if isempty(extended)
                    dims = dims(1); % both singleton: a single point
                else
                    dims = extended;
                end
            end
            obj.imageplot = length(dims) == 2;
            if obj.imageplot
                obj.label2 = sprintf("y [%s]", obj.grid.pixel_unit(dims(2)));
                obj.coord2 = obj.grid.coordinates(dims(2));
            end
            obj.label1 = sprintf("x [%s]", obj.grid.pixel_unit(dims(1)));
            obj.coord1 = obj.grid.coordinates(dims(1));
        end

        function call(obj, u, r, state)
            % CALL Draw the (optional) convergence history and the selected
            % cross section of the current estimate U, or of the residual
            % R when plot_residual is true.
            %
            % Reads state.residuals and state.residual_its (only when
            % show_convergence is true) and state.iteration.
            if obj.show_convergence && ~isempty(state.residuals)
                subplot(2, 1, 1);
                % residuals normalized to the first recorded value
                semilogy(state.residual_its, state.residuals / state.residuals(1));
                xlabel('Iteration');
                ylabel('‖Δ𝜓‖^2 (normalized)');
                subplot(2, 1, 2);
            end
            if obj.plot_residual
                u = r;
                tlabel = "r [%d]";
            else
                tlabel = "u [%d]";
            end
            % crop, convert to non-scaled form, and extract the cross
            % section; squeeze removes the singleton (selected) dimensions
            u = obj.grid.crop(u);
            u = real(fieldmultiply(obj.Tr, u));
            u = squeeze(subsref(u, obj.selection));
            if obj.imageplot
                % imagesc expects rows along y, hence the transpose
                imagesc(obj.coord1, obj.coord2, u.');
                xlabel(obj.label1);
                ylabel(obj.label2);
                colorbar;
                axis image;
            else
                plot(obj.coord1, u);
                xlabel(obj.label1);
            end
            title(sprintf(tlabel, state.iteration));
            % replace the title by a red warning when NaNs are present
            nancount = sum(isnan(u(:)));
            if nancount > 0
                title(sprintf("\\color{red}{%d NaNs encountered}", nancount));
            end
            drawnow();
        end
    end
end