-
Notifications
You must be signed in to change notification settings - Fork 99
/
CellStates.jl
147 lines (125 loc) · 4.18 KB
/
CellStates.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
This can be used as a CellField as long as one evaluates it
on the stored CellPoint.
"""
struct CellState{T,P<:CellPoint} <: CellField
  # The CellPoint the state values are attached to.
  points::P
  # Storage built by `_init_values` from `get_data(points)`.
  # NOTE(review): this field is declared with the abstract type `AbstractArray`,
  # so the concrete storage type is not encoded in `CellState{T,P}`.
  values::AbstractArray

  # Build a state whose entries of type `T` are left uninitialized.
  function CellState{T}(::UndefInitializer,points::CellPoint) where T
    storage = _init_values(T,get_data(points))
    return new{T,typeof(points)}(points,storage)
  end

  # Build a state with every entry set to `v`; `T` is taken as `typeof(v)`.
  function CellState(v::Number,points::CellPoint)
    storage = _init_values(v,get_data(points))
    return new{typeof(v),typeof(points)}(points,storage)
  end
end
# Allocate an uninitialized `Array` with element type `T` and the same size
# (and hence the same number of dimensions) as the array of points `x`.
function _init_values(::Type{T},x::AbstractArray{<:Point}) where T
  return Array{T}(undef,size(x))
end
# For an array whose entries are vectors of points, allocate one uninitialized
# `Vector{T}` per entry of `x`, each with the length of the corresponding entry.
function _init_values(::Type{T},x::AbstractArray{<:AbstractVector{<:Point}}) where T
  return collect(Vector{T}(undef,length(pts)) for pts in x)
end
# Create an array with the same size as the array of points `x`, with every
# entry set to the value `v`.
function _init_values(v::Number,x::AbstractArray{<:Point})
  dims = size(x)
  return fill(v,dims)
end
# For an array whose entries are vectors of points, create one vector per entry
# of `x`, with the length of that entry and every component set to `v`.
function _init_values(v::Number,x::AbstractArray{<:AbstractVector{<:Point}})
  return collect(fill(v,length(pts)) for pts in x)
end
# `get_data` is deliberately not supported for a CellState: calling it always
# ends in `@unreachable` with the explanatory message below.
function get_data(f::CellState)
  msg = """\n
  get_data cannot be called on a CellState
  If you see this error message it is likely that you are trying to perform
  an operation that does not make sense for a CellState. In most cases,
  to do the wanted operation, you would need first to project the CellState
  to a FESpace (e.g. via a L2 projection).
  """
  @unreachable msg
end
# The triangulation of a CellState is the one of the CellPoint it stores.
function get_triangulation(f::CellState)
  return get_triangulation(f.points)
end
# A CellState type inherits the DomainStyle of its CellPoint type `P`.
function DomainStyle(::Type{CellState{T,P}}) where {T,P}
  return DomainStyle(P)
end
# Return the CellPoint stored in the given CellState.
function _get_cell_points(a::CellState)
  return a.points
end
# Evaluate a CellState at a CellPoint. The stored values are returned as they
# are when `x` compares equal (`==`) to the CellPoint held by `f`; any other
# CellPoint ends in `@unreachable` with the message below. `cache` is not used.
function evaluate!(cache,f::CellState,x::CellPoint)
  if !(f.points == x)
    @unreachable """\n
    It is not possible to evaluate the given CellState on the given CellPoint.
    a CellState can only be evaluated at the CellPoint it was created from.
    If you want to evaluate at another location, you would need first to project the CellState
    to a FESpace (e.g. via a L2 projection).
    """
  end
  return f.values
end
# Convenience constructor: build an uninitialized CellState with entries of
# type `T` on the CellPoint obtained from `get_cell_points(a)`.
function CellState{T}(::UndefInitializer,a) where T
  return CellState{T}(undef,get_cell_points(a))
end
# Convenience constructor: build a CellState filled with the value `v` on the
# CellPoint obtained from `get_cell_points(a)`.
function CellState(v::Number,a)
  return CellState(v,get_cell_points(a))
end
# Update state values point by point with a user-defined `updater` function.
#
# - At least one of the fields in `f` has to be a CellState, and all the
#   CellState objects have to hold the very same (`===`) CellPoint `x`.
# - Every field in `f` is evaluated at `x`. For each point, `updater` is called
#   with the values of all the fields at that point, in the order given in `f`.
# - The result of `updater` is split with `first_and_tail`: the first item is a
#   Bool telling whether to update, the rest are the new states (Numbers).
# - When the flag is true, the i-th new state is written in the point entry of
#   the values of the i-th field in `f`.
#   NOTE(review): the new states are matched with the leading entries of `f`,
#   so presumably the CellState arguments are expected to come first -- confirm.
#
# Returns `nothing`.
function update_state!(updater::Function,f::CellField...)
  state_ids = findall(map(fi->fi isa CellState,f))
  @assert !isempty(state_ids) """\n
  At least one CellState object has to be given to the update_state! function
  """
  cell_states = f[state_ids]
  x = first(cell_states).points
  @assert all(s->s.points===x,cell_states) """\n
  All the CellState objects given to the update_state! function need to be
  defined on the same CellPoint.
  """
  fx = map(fi->evaluate(fi,x),f)
  if num_cells(x) > 0
    # Trial call of the updater on the first point of the first cell, only to
    # validate the shape of what the user-defined function returns.
    first_cell_vals = map(first,fx)
    first_point_vals = map(first,first_cell_vals)
    need_to_update, states = first_and_tail(updater(first_point_vals...))
    @assert isa(need_to_update,Bool) && isa(states,Tuple{Vararg{Number}}) """\n
    Wrong return value of the user-defined updater Function. The signature is
    need_to_update, state = first_and_tail(updater(args...))
    where need_to_update is a Bool telling if we need to update_state and
    states is a Tuple of Number objects with the new states.
    See `first_and_tail` for further details.
    """
    msg = """\n
    The number of new states given by the updater Function does not match
    the number of CellState objects given as arguments in the update_state! function.
    """
    # NOTE(review): the check accepts fewer new states than CellState objects,
    # although the message talks about a mismatch -- confirm this is intended.
    @check length(states) <= length(cell_states) msg
  end
  caches = map(array_cache,fx)
  pts = get_data(x)
  pts_cache = array_cache(pts)
  _update_state_variables!(updater,caches,fx,pts_cache,pts)
  return nothing
end
# Kernel of `update_state!`: loop over the cells of `x` and over the points of
# each cell, call `updater` with the point values of all the arrays in `fx`,
# and store the new states whenever the updater asks for it.
# `caches` and `cache_x` are the caches used by `getindex!` on `fx` and `x`.
# Kept as a separate `@noinline` function (presumably a function barrier).
@noinline function _update_state_variables!(updater,caches,fx,cache_x,x)
  for cell in 1:length(x)
    cell_vals = map((c,f)->getindex!(c,f,cell),caches,fx)
    cell_pts = getindex!(cache_x,x,cell)
    for q in 1:length(cell_pts)
      point_vals = map(v->v[q],cell_vals)
      need_to_update, states = first_and_tail(updater(point_vals...))
      # Nothing to write at this point unless the updater requests it.
      need_to_update || continue
      _update_states!(cell_vals,q,states,Val{length(states)}())
    end
  end
  return nothing
end
# Recursive step: write the N-th new state at point `q`, then continue with the
# states N-1, N-2, ... until the `Val{0}` method ends the recursion.
function _update_states!(b,q,states,::Val{N}) where N
  _update_state!(b,q,states,Val{N}())
  _update_states!(b,q,states,Val{N-1}())
  return nothing
end
# End of the recursion over the new states: nothing is left to be written.
_update_states!(b,q,states,::Val{0}) = nothing
# Write the N-th new state in the entry `q` of the N-th array in `b`.
function _update_state!(b,q,states,::Val{N}) where N
  target = b[N]
  target[q] = states[N]
  return nothing
end