StatProfilerHTML.jl report
Generated on Thu, 21 Dec 2023 12:59:22
File source code
Line Exclusive Inclusive Code
1 module ArrayInterface
2
3 using LinearAlgebra
4 using SparseArrays
5 using SuiteSparse
6
# Use `Base.@assume_effects` when this Julia provides it; otherwise define a minimal
# stand-in with the same call syntax.
@static if isdefined(Base, Symbol("@assume_effects"))
    using Base: @assume_effects
else
    # Fallback: a lone `:total` effect is approximated with `Base.@pure`; any other effect
    # specification is dropped and the annotated expression is returned unchanged.
    macro assume_effects(args...)
        ex = args[end]
        if length(args) == 2 && args[1] === QuoteNode(:total)
            return esc(:(Base.@pure $(ex)))
        end
        return esc(ex)
    end
end
# Strip every type parameter from `T`, e.g. `Vector{Int}` -> `Array`. Declared `:total` so
# the compiler is allowed to constant-fold it.
@assume_effects :total function __parameterless_type(T)
    return Base.typename(T).wrapper
end
parameterless_type(x::Type) = __parameterless_type(x)
parameterless_type(x) = parameterless_type(typeof(x))
23
# Lazy transpose/adjoint of a vector `V` (resp. matrix `M`) whose element type is `T`.
const VecAdjTrans{T, V <: AbstractVector{T}} = Union{
    Transpose{<:Any, V},
    Adjoint{<:Any, V},
}
const MatAdjTrans{T, M <: AbstractMatrix{T}} = Union{
    Transpose{<:Any, M},
    Adjoint{<:Any, M},
}
# Upper / lower triangular views of `M`, with or without an implicit unit diagonal.
const UpTri{T, M} = Union{
    UpperTriangular{T, M},
    UnitUpperTriangular{T, M},
}
const LoTri{T, M} = Union{
    LowerTriangular{T, M},
    UnitLowerTriangular{T, M},
}
28
"""
    ArrayInterface.map_tuple_type(f, T::Type{<:Tuple})

Return a tuple whose `i`-th element is `f(fieldtype(T, i))`, i.e. `f` applied to each field
type of the tuple type `T`.

# Examples

```julia
julia> ArrayInterface.map_tuple_type(sizeof, Tuple{Int8, Int16, Int64})
(1, 2, 8)

julia> ArrayInterface.map_tuple_type(eltype, Tuple{Vector{Int}, Vector{Float64}})
(Int64, Float64)
```
"""
function map_tuple_type end
@inline function map_tuple_type(f, @nospecialize(T::Type))
    # `Val` makes the output length a compile-time constant whenever `T` is inferable.
    return ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}())
end
46
"""
    ArrayInterface.flatten_tuples(t::Tuple) -> Tuple

Splice every field of `t` that is itself a tuple into the result. Flattening is a single
level deep: only direct fields of `t` are expanded.

# Examples

```julia
julia> ArrayInterface.flatten_tuples((1, ()))
(1,)

julia> ArrayInterface.flatten_tuples((1, (2, 3)))
(1, 2, 3)

julia> ArrayInterface.flatten_tuples((1, (2, (3,))))
(1, 2, (3,))

```
"""
function flatten_tuples end
function flatten_tuples(t::Tuple)
    positions = _new_field_positions(t)
    nt = length(t)
    return ntuple(Val{nfields(positions)}()) do k
        outer, inner = getfield(positions, k)
        # `outer` stores how many fields of `t` follow the source field, so the source
        # field's own index is `nt - outer`.
        field = getfield(t, nt - outer)
        # `inner === 0` marks a non-tuple field that is copied through unchanged.
        @inbounds inner === 0 ? field : getfield(field, inner)
    end
end
_new_field_positions(::Tuple{}) = ()
@nospecialize
# One `(fields remaining after this one, j)` pair per output element; `j` indexes into a
# tuple field, or is `0` for a non-tuple field.
function _new_field_positions(x::Tuple)
    return (_fl1(x, first(x))..., _new_field_positions(Base.tail(x))...)
end
function _fl1(x::Tuple, x1::Tuple)
    remaining = length(x) - 1
    return ntuple(j -> (remaining, j), Val(length(x1)))
end
_fl1(x::Tuple, x1) = ((length(x) - 1, 0),)
@specialize
83
"""
    parent_type(::Type{T}) -> Type

Return the type of the data stored inside the wrapper type `T`. Types that are not
recognized as wrappers are returned unchanged, so `parent_type(T) === T` means "not a
wrapper".
"""
parent_type(x) = parent_type(typeof(x))
parent_type(T::Type) = T
# Wrappers that keep their payload in a field named `parent`.
parent_type(@nospecialize(T::Type{<:Adjoint})) = fieldtype(T, :parent)
parent_type(@nospecialize(T::Type{<:Transpose})) = fieldtype(T, :parent)
parent_type(@nospecialize(T::Type{<:PermutedDimsArray})) = fieldtype(T, :parent)
parent_type(@nospecialize(T::Type{<:SubArray})) = fieldtype(T, :parent)
parent_type(@nospecialize(T::Type{<:Base.ReinterpretArray})) = fieldtype(T, :parent)
parent_type(@nospecialize(T::Type{<:Base.ReshapedArray})) = fieldtype(T, :parent)
# Symmetric/Hermitian and triangular views keep the full matrix in `data`.
parent_type(@nospecialize(T::Type{<:Union{Symmetric, Hermitian}})) = fieldtype(T, :data)
parent_type(@nospecialize(T::Type{<:Union{UpTri, LoTri}})) = fieldtype(T, :data)
# Index wrappers keep the wrapped range in `indices`; `Diagonal` keeps its vector in `diag`.
function parent_type(@nospecialize(T::Type{<:Union{Base.Slice, Base.IdentityUnitRange}}))
    return fieldtype(T, :indices)
end
parent_type(@nospecialize(T::Type{<:Diagonal})) = fieldtype(T, :diag)
103
"""
    promote_eltype(::Type{<:AbstractArray{T,N}}, ::Type{T2})

Return the array type obtained when the element type of the given array type is widened
to `promote_type(T, T2)`. Only `Array` is handled here; there is no generic fallback, so
other array types need their own method.
"""
function promote_eltype end
function promote_eltype(::Type{Array{T, N}}, ::Type{S}) where {T, N, S}
    widened = promote_type(T, S)
    return Array{widened, N}
end
116
"""
    buffer(x)

Return the storage that `x` points to. The default is `parent(x)`. Unlike `parent`,
specialized methods reach through to raw storage that may not be the same kind of array as
`x`. Sparse arrays return their vector of stored values; `Base.Slice` and
`Base.IdentityUnitRange` return the range they wrap.
"""
buffer(x) = parent(x)
buffer(x::Union{SparseMatrixCSC, SparseVector}) = getfield(x, :nzval)
function buffer(@nospecialize(x::Union{Base.Slice, Base.IdentityUnitRange}))
    return getfield(x, :indices)
end
127
"""
    is_forwarding_wrapper(::Type{T}) -> Bool

Return `true` if `T` wraps another data type without altering any of its standard
interface. For an array wrapper this means its size, indices and elements all equal those
of the wrapped data. The default is `false`; `Base.Slice` is a forwarding wrapper.
"""
is_forwarding_wrapper(@nospecialize x) = is_forwarding_wrapper(typeof(x))
is_forwarding_wrapper(@nospecialize T::Type{<:Base.Slice}) = true
is_forwarding_wrapper(T::Type) = false
138
"""
    GetIndex(buffer) = GetIndex{true}(buffer)
    GetIndex{check}(buffer) -> g

Wrap an indexable `buffer` in a callable, so that `g(inds...)` is equivalent to
`buffer[inds...]`. With `check = true` (the default), every index is bounds-checked. With
`check = false`, all indexing arguments are assumed to be in bounds.

See also [`SetIndex!`](@ref)

!!! warning
    Passing `false` as `check` may result in incorrect results/crashes/corruption for
    out-of-bounds indices, similar to inappropriate use of `@inbounds`. The user is
    responsible for ensuring this is correctly used.

# Examples

```julia
julia> ArrayInterface.GetIndex(1:10)(3)
3

julia> ArrayInterface.GetIndex{false}(1:10)(11) # shouldn't be in-bounds
11

```

"""
struct GetIndex{CB, B} <: Function
    buffer::B

    function GetIndex{true, B}(b) where {B}
        return new{true, B}(b)
    end
    function GetIndex{false, B}(b) where {B}
        return new{false, B}(b)
    end
    # Infer the buffer type from the argument.
    GetIndex{check}(b::B) where {check, B} = GetIndex{check, B}(b)
    # Bounds checking is on unless explicitly disabled.
    GetIndex(b) = GetIndex{true}(b)
end
175
"""
    SetIndex!(buffer) = SetIndex!{true}(buffer)
    SetIndex!{check}(buffer) -> g

Wrap an indexable `buffer` in a callable that stores a value when called, so that
`g(val, inds...)` is equivalent to `setindex!(buffer, val, inds...)`. With `check = true`
(the default), every index is bounds-checked. With `check = false`, all indexing arguments
are assumed to be in bounds.

See also [`GetIndex`](@ref)

!!! warning
    Passing `false` as `check` may result in incorrect results/crashes/corruption for
    out-of-bounds indices, similar to inappropriate use of `@inbounds`. The user is
    responsible for ensuring this is correctly used.

# Examples

```julia

julia> x = [1, 2, 3, 4];

julia> ArrayInterface.SetIndex!(x)(10, 2);

julia> x[2]
10

```
"""
struct SetIndex!{CB, B} <: Function
    buffer::B

    function SetIndex!{true, B}(b) where {B}
        return new{true, B}(b)
    end
    function SetIndex!{false, B}(b) where {B}
        return new{false, B}(b)
    end
    # Infer the buffer type from the argument.
    SetIndex!{check}(b::B) where {check, B} = SetIndex!{check, B}(b)
    # Bounds checking is on unless explicitly disabled.
    SetIndex!(b) = SetIndex!{true}(b)
end
213
# Both callable wrappers keep their indexable object in the `buffer` field.
buffer(x::Union{GetIndex, SetIndex!}) = getfield(x, :buffer)

# check = true: plain indexing; a caller's `@inbounds` may still elide the bounds check.
Base.@propagate_inbounds @inline function (g::GetIndex{true})(inds...)
    return buffer(g)[inds...]
end
# check = false: indices are trusted, so bounds checking is always skipped.
@inline function (g::GetIndex{false})(inds...)
    return @inbounds buffer(g)[inds...]
end

Base.@propagate_inbounds @inline function (s::SetIndex!{true})(v, inds...)
    return setindex!(buffer(s), v, inds...)
end
@inline function (s::SetIndex!{false})(v, inds...)
    return @inbounds setindex!(buffer(s), v, inds...)
end
222
"""
    can_change_size(::Type{T}) -> Bool

Return `true` if the number of elements in an instance of `T` can change. In that case
operations such as `pop!` and `popfirst!` are available for collections of type `T`.
Forwarding wrappers (see `is_forwarding_wrapper`) defer to their parent type.
"""
can_change_size(x) = can_change_size(typeof(x))
function can_change_size(::Type{T}) where {T}
    is_forwarding_wrapper(T) || return false
    return can_change_size(parent_type(T))
end
can_change_size(::Type{<:Vector}) = true
can_change_size(::Type{<:AbstractDict}) = true
can_change_size(::Type{<:Base.ImmutableDict}) = false
236
function ismutable end

"""
    ismutable(::Type{T}) -> Bool

Query whether instances of type `T` can be mutated in place, see
https://github.com/JuliaDiffEq/RecursiveArrayTools.jl/issues/19.

- Wrapper types, whose `parent_type` differs from themselves, defer to the wrapped type.
- Non-wrapper arrays are treated as mutable; ranges are not.
- Dictionaries are mutable, except `Base.ImmutableDict`.
- `BigFloat` and `BigInt` are reported as immutable.
- Any other type uses `Base.ismutabletype`.
"""
ismutable(x) = ismutable(typeof(x))
function ismutable(::Type{T}) where {T <: AbstractArray}
    P = parent_type(T)
    return P <: T ? true : ismutable(P)
end
ismutable(::Type{<:AbstractRange}) = false
ismutable(::Type{<:AbstractDict}) = true
ismutable(::Type{<:Base.ImmutableDict}) = false
ismutable(::Type{BigFloat}) = false
ismutable(::Type{BigInt}) = false
function ismutable(::Type{T}) where {T}
    P = parent_type(T)
    return P <: T ? Base.ismutabletype(T) : ismutable(P)
end
265
"""
    can_setindex(::Type{T}) -> Bool

Query whether instances of type `T` support `setindex!`.

- Forwarding wrappers defer to their parent type.
- Ranges, tuples, named tuples and `Base.ImmutableDict` return `false`.
- Other dictionaries return `true`.
- A `Base.Iterators.Pairs` defers to the collection it wraps.
- Everything else defaults to `true`.
"""
can_setindex(x) = can_setindex(typeof(x))
can_setindex(T::Type) = is_forwarding_wrapper(T) ? can_setindex(parent_type(T)) : true
can_setindex(@nospecialize T::Type{<:AbstractRange}) = false
can_setindex(::Type{<:AbstractDict}) = true
can_setindex(::Type{<:Base.ImmutableDict}) = false
can_setindex(@nospecialize T::Type{<:Tuple}) = false
can_setindex(@nospecialize T::Type{<:NamedTuple}) = false
# `Pairs{K, V, I, A}` keeps the wrapped collection (type `A`) in `data` and its keys iterator
# (type `I`) in `itr`. Writes go to the wrapped collection, so the answer depends on `A`.
# The previous method dispatched on the third parameter `I`, i.e. the keys iterator.
function can_setindex(::Type{<:Base.Iterators.Pairs{<:Any, <:Any, <:Any, A}}) where {A}
    return can_setindex(A)
end
279
"""
    aos_to_soa(x)

Convert an array-of-structs layout into a struct-of-arrays layout. The generic method is
the identity and returns `x` itself; types with a struct-of-arrays form need their own
method.
"""
function aos_to_soa(x)
    return x
end
286
"""
    isstructured(::Type{T}) -> Bool

Query whether `T` (or the type of a given value) represents a structured matrix.
`Symmetric`, `Hermitian`, `UpperTriangular`, `LowerTriangular`, `Tridiagonal`,
`SymTridiagonal`, `Bidiagonal` and `Diagonal` qualify; everything else returns `false`.
"""
isstructured(x) = isstructured(typeof(x))
isstructured(::Type) = false
for W in (:Symmetric, :Hermitian, :UpperTriangular, :LowerTriangular,
          :Tridiagonal, :SymTridiagonal, :Bidiagonal, :Diagonal)
    @eval isstructured(::Type{<:$W}) = true
end
302
"""
    has_sparsestruct(x::AbstractArray) -> Bool

Determine whether `findstructralnz` accepts `x` (or values of type `x`). This is `true` for
`SparseMatrixCSC`, `Diagonal`, `Bidiagonal`, `Tridiagonal` and `SymTridiagonal`, and
`false` otherwise.
"""
has_sparsestruct(x) = has_sparsestruct(typeof(x))
has_sparsestruct(::Type) = false
has_sparsestruct(::Type{<:AbstractArray}) = false
function has_sparsestruct(
        ::Type{<:Union{SparseMatrixCSC, Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}})
    return true
end
316
"""
    issingular(A::AbstractMatrix) -> Bool

Determine whether the matrix `A` is (exactly) singular, using a check suited to its
structure. Matrices without a specialized method are converted to `Matrix` and tested with
a non-throwing LU factorization.
"""
issingular(A::AbstractMatrix) = issingular(Matrix(A))
issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false))
issingular(A::Matrix) = !issuccess(lu(A, check = false))
issingular(A::UniformScaling) = A.λ == 0
issingular(A::Diagonal) = any(iszero, A.diag)
# A bidiagonal matrix is triangular, so it is singular iff a diagonal entry is zero.
issingular(A::Bidiagonal) = any(iszero, A.dv)
# An unpivoted LDLᵀ breaks down on a zero pivot even for invertible matrices such as
# [0 1; 1 0], so convert to `Tridiagonal` and use its row-pivoted LU instead.
function issingular(A::SymTridiagonal)
    dv = collect(A.dv)
    # `ev` may hold one extra, unused trailing entry; keep only the off-diagonal part.
    ev = collect(view(A.ev, 1:(length(dv) - 1)))
    # For scalar entries the sub- and superdiagonal of a symmetric tridiagonal coincide.
    return issingular(Tridiagonal(ev, dv, copy(ev)))
end
issingular(A::Tridiagonal) = !issuccess(lu(A, check = false))
# The Bunch-Kaufman `D` factor can contain 2×2 pivot blocks with zero diagonal entries
# that are nevertheless invertible. `issuccess` reports whether `D` itself is singular.
issingular(A::Union{Hermitian, Symmetric}) = !issuccess(bunchkaufman(A, check = false))
issingular(A::Union{LowerTriangular, UpperTriangular}) = diaganyzero(A.data)
issingular(A::Union{UnitLowerTriangular, UnitUpperTriangular}) = false
issingular(A::Union{Adjoint, Transpose}) = issingular(parent(A))
# `true` if any entry on the main diagonal of `A` is zero.
diaganyzero(A) = any(iszero, view(A, diagind(A)))
335
"""
    findstructralnz(x::AbstractArray) -> (I, J)

Return the structural sparsity pattern of `x` as a pair of indexable collections: row
indices `I` and column indices `J`. These match the first two outputs of
`findnz(::SparseMatrixCSC)`.
"""
function findstructralnz(x::Diagonal)
    diagonal = 1:Base.size(x, 1)
    return (diagonal, diagonal)
end

function findstructralnz(x::Bidiagonal)
    n = Base.size(x, 1)
    upper = x.uplo == 'U'
    len = 2n - 1  # n diagonal entries + (n - 1) off-diagonal entries
    return (BidiagonalIndex(len, upper), BidiagonalIndex(len, !upper))
end

function findstructralnz(x::Union{Tridiagonal, SymTridiagonal})
    n = Base.size(x, 1)
    len = 3n - 2  # diagonal + superdiagonal + subdiagonal
    return (TridiagonalIndex(len, n, true), TridiagonalIndex(len, n, false))
end

function findstructralnz(x::SparseMatrixCSC)
    rows, cols, _ = findnz(x)
    return (rows, cols)
end
366
# Abstract supertype for matrix coloring algorithms (no subtypes are defined in this file).
abstract type ColoringAlgorithm end

"""
    fast_matrix_colors(A)

Query whether `A` (a matrix or a matrix type) has a fast algorithm for computing its
structural colors. This holds for `Diagonal`, `Bidiagonal`, `Tridiagonal` and
`SymTridiagonal`; everything else returns `false`.
"""
fast_matrix_colors(A) = false
fast_matrix_colors(A::AbstractArray) = fast_matrix_colors(typeof(A))
fast_matrix_colors(::Type{<:Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}}) = true
382
"""
    matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})

Return a structural column coloring of `A`.

- Dense and triangular matrices need one color per column, so the result is
  `Base.OneTo(size(A, 2))`, i.e. `1, 2, …, size(A, 2)`.
- `Diagonal` uses a single color.
- `Bidiagonal` cycles `1, 2, 1, 2, …`.
- `Tridiagonal`/`SymTridiagonal` cycle `1, 2, 3, 1, …`.
"""
function matrix_colors(A::Union{Array, UpperTriangular, LowerTriangular})
    return Base.OneTo(Base.size(A, 2))  # one color per column
end
matrix_colors(A::Diagonal) = fill(1, Base.size(A, 2))
matrix_colors(A::Bidiagonal) = _cycle(1:2, Base.size(A, 2))
matrix_colors(A::Union{Tridiagonal, SymTridiagonal}) = _cycle(1:3, Base.size(A, 2))
# The first `len` entries of `repetend` repeated cyclically, as a `Vector`.
_cycle(repetend, len) = collect(Iterators.take(Iterators.cycle(repetend), len))
396
"""
    bunchkaufman_instance(A) -> bunchkaufman_factorization_instance

Return an instance of the Bunch-Kaufman factorization object with the correct type,
cheaply. It factorizes a tiny matrix of the same kind (0×0 dense, 1×1 sparse) instead of
`A` itself. The factorization uses `check = false`.
"""
function bunchkaufman_instance(A::Matrix{T}) where {T}
    return bunchkaufman(similar(A, 0, 0), check = false)
end
function bunchkaufman_instance(A::SparseMatrixCSC)
    return bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end

"""
    bunchkaufman_instance(a::Number) -> a

Return the number itself; scalars need no factorization object.
"""
bunchkaufman_instance(a::Number) = a

"""
    bunchkaufman_instance(a::Any) -> bunchkaufman(a, check=false)

Slow fallback which gets the instance by factorizing `a` itself. Should get specialized
for new matrix types.
"""
bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false)
423
# Default pivoting strategy of `cholesky_instance` (also the default of `qr_instance`).
const DEFAULT_CHOLESKY_PIVOT = LinearAlgebra.NoPivot()

"""
    cholesky_instance(A, pivot = LinearAlgebra.NoPivot()) -> cholesky_factorization_instance

Return an instance of the Cholesky factorization object with the correct type, cheaply.
It factorizes an empty (dense) or 1×1 (sparse) matrix instead of `A` itself, with
`check = false`.
"""
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
    return cholesky(similar(A, 0, 0), pivot, check = false)
end

# `pivot` is accepted for interface compatibility but is not forwarded to the sparse
# factorization.
function cholesky_instance(A::Union{SparseMatrixCSC, Symmetric{<:Number, <:SparseMatrixCSC}},
        pivot = DEFAULT_CHOLESKY_PIVOT)
    return cholesky(sparse(similar(A, 1, 1)), check = false)
end

"""
    cholesky_instance(a::Number, pivot = LinearAlgebra.NoPivot()) -> a

Return the number itself; scalars need no factorization object.
"""
cholesky_instance(a::Number, pivot = DEFAULT_CHOLESKY_PIVOT) = a

# NOTE(profile): in the StatProfilerHTML run this file was captured from (21 Dec 2023),
# the fallback below accounted for 88 samples (31 %). All of them were spent inside
# `cholesky`, when called from `init_cacheval` (line 1013 of the caller). This presumably
# means the matrix type reaching it would benefit from a cheap specialized method.
"""
    cholesky_instance(a::Any, pivot = LinearAlgebra.NoPivot()) -> cholesky(a, pivot, check=false)

Slow fallback which gets the instance by factorizing `a` itself. Should get specialized
for new matrix types.
"""
cholesky_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = cholesky(a, pivot, check = false)
454
"""
    ldlt_instance(A) -> ldlt_factorization_instance

Return an instance of the LDLt factorization object with the correct type, cheaply. A dense
`Matrix{T}` maps to an (unfactorized) `LDLt` that wraps an empty `SymTridiagonal` with
`Vector{T}` storage.
"""
function ldlt_instance(A::Matrix{T}) where {T}
    # Build the empty SymTridiagonal from empty vectors. Converting a 0×0 matrix with
    # `SymTridiagonal(::AbstractMatrix)` throws: it asks for `diag(A, 1)`, which is out of
    # range for a 0×0 matrix.
    return ldlt_instance(SymTridiagonal(similar(A, 0), similar(A, 0)))
end

function ldlt_instance(A::SparseMatrixCSC)
    return ldlt(sparse(similar(A, 1, 1)), check = false)
end

# Wrap `A` in an `LDLt` without factorizing it; only the resulting type matters.
function ldlt_instance(A::SymTridiagonal{T, V}) where {T, V}
    return LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end

"""
    ldlt_instance(a::Number) -> a

Return the number itself; scalars need no factorization object.
"""
ldlt_instance(a::Number) = a

"""
    ldlt_instance(a::Any) -> ldlt(a)

Slow fallback which gets the instance by factorizing `a` itself. Should get specialized
for new matrix types. Unlike the other `*_instance` fallbacks, it calls `ldlt` without
`check = false`, so a failed factorization may throw.
"""
ldlt_instance(a::Any) = ldlt(a)
487
"""
    lu_instance(A) -> lu_factorization_instance

Return an instance of the LU factorization object with the correct type, cheaply. The
factors are an empty matrix of the same storage kind as `A`, with an empty pivot vector
and `info = 0`.
"""
function lu_instance(A::Matrix{T}) where {T}
    return _empty_lu(T, similar(A, 0, 0))
end
function lu_instance(jac_prototype::SparseMatrixCSC)
    return SuiteSparse.UMFPACK.UmfpackLU(similar(jac_prototype, 1, 1))
end

function lu_instance(A::Symmetric{T}) where {T}
    return _empty_lu(T, similar(A, 0, 0))
end

# The main-diagonal vector of each supported structured matrix, returned without copying.
noalloc_diag(A::Diagonal) = A.diag
noalloc_diag(A::Tridiagonal) = A.d
noalloc_diag(A::SymTridiagonal) = A.dv

function lu_instance(A::Union{Tridiagonal{T}, Diagonal{T}, SymTridiagonal{T}}) where {T}
    d = noalloc_diag(A)
    # Three distinct empty vectors, so the sub-, main and superdiagonal of the factors
    # never alias one another.
    factors = Tridiagonal(similar(d, 0), similar(d, 0), similar(d, 0))
    return _empty_lu(T, factors)
end

# Wrap `factors` in an `LU` with an empty pivot vector and `info = 0`. The element type is
# `LinearAlgebra.lutype` of `typeof(zero(T))`, so `factors` is converted if needed.
function _empty_lu(::Type{T}, factors) where {T}
    luT = LinearAlgebra.lutype(typeof(zero(T)))
    ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
    return LU{luT}(factors, ipiv, zero(LinearAlgebra.BlasInt))
end

"""
    lu_instance(a::Number) -> a

Return the number itself; scalars need no factorization object.
"""
lu_instance(a::Number) = a

"""
    lu_instance(a::Any) -> lu(a, check=false)

Slow fallback which gets the instance by factorizing `a` itself. Should get specialized
for new matrix types.
"""
lu_instance(a::Any) = lu(a, check = false)
541
"""
    qr_instance(A, pivot = NoPivot()) -> qr_factorization_instance

Return an instance of the QR factorization object with the correct type, cheaply. For a
dense matrix the instance holds empty arrays, whose element type is the one `qr` itself
computes in (e.g. `Float64` for an integer matrix). Any `pivot` other than the default
`NoPivot()` yields a `QRPivoted` instance.
"""
function qr_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
    # The element-type promotion `qr` applies before factorizing.
    S = typeof(zero(T) / sqrt(abs2(one(T))))
    if pivot !== DEFAULT_CHOLESKY_PIVOT
        return LinearAlgebra.QRPivoted(zeros(S, 0, 0), zeros(S, 0), zeros(Int, 0))
    elseif S <: LinearAlgebra.BlasFloat
        # LAPACK-supported element types factorize into the compact WY representation...
        return LinearAlgebra.QRCompactWY(zeros(S, 0, 0), zeros(S, 0, 0))
    else
        # ...while other element types (e.g. `BigFloat`, `Float16`) use the generic `QR`.
        return LinearAlgebra.QR(zeros(S, 0, 0), zeros(S, 0))
    end
end

# Could be optimized but this should work for any real case.
# NOTE(review): the 1×1 probe matrix is always `Float64`, whatever the element type of
# `jac_prototype`, and `pivot` is ignored.
function qr_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
    return qr(sparse(rand(1, 1)))
end

"""
    qr_instance(a::Number, pivot = NoPivot()) -> a

Return the number itself; scalars need no factorization object.
"""
qr_instance(a::Number, pivot = DEFAULT_CHOLESKY_PIVOT) = a

"""
    qr_instance(a::Any, pivot = NoPivot()) -> qr(a)

Slow fallback which gets the instance by factorizing `a` itself. Should get specialized
for new matrix types. `pivot` is not forwarded to `qr`.
"""
qr_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = qr(a)
579
"""
    svd_instance(A) -> svd_factorization_instance

Return an instance of the SVD factorization object with the correct type, cheaply, using
empty arrays. The element type is `float(T)`, so integer and rational matrices produce a
floating-point instance as `svd` does. Singular values use `real(float(T))`.
"""
function svd_instance(A::Matrix{T}) where {T}
    F = float(T)
    return LinearAlgebra.SVD(zeros(F, 0, 0), zeros(real(F), 0), zeros(F, 0, 0))
end

"""
    svd_instance(a::Number) -> a

Return the number itself; scalars need no factorization object.
"""
svd_instance(a::Number) = a

"""
    svd_instance(a::Any) -> svd(a)

Slow fallback which gets the instance by factorizing `a` itself. Should get specialized
for new matrix types.
"""
svd_instance(a::Any) = svd(a)
604
"""
    safevec(v)

A form of `vec` that is safe for all values in vector spaces. Numbers and
`AbstractVector`s are already "vectors" and are returned as-is; anything else goes
through `vec`.
"""
safevec(v::Number) = v
safevec(v::AbstractVector) = v
safevec(v) = vec(v)
615
"""
    zeromatrix(u::AbstractVector)

Create the zeroed square matrix version of `u`, i.e. an `n × n` matrix with
`n = length(u)` that matches the array type of `u`. Neither of the obvious alternatives
works:

- `similar(u, length(u), length(u))` returns a mutable type, which need not match the type
  of `u`.
- `fill(zero(eltype(u)), length(u), length(u))` always builds a CPU `Array`, e.g. even
  from a GPU array.

The generic fallback is `u .* u' .* false` (on `safevec(u)`). It works on a surprising
number of types, but can be broken by unusual (recursive) broadcast overloads. For
higher-order tensors this returns the matrix linear operator type that acts on the `vec`
of the array.
"""
function zeromatrix(u)
    v = safevec(u)
    return v .* v' .* false
end

# Plain `Array`s skip broadcasting (less compile time) and fill a `Matrix{T}` directly.
function zeromatrix(u::Array{T}) where {T}
    n = length(u)
    return fill!(Matrix{T}(undef, n, n), false)
end
637
"""
    undefmatrix(u::AbstractVector)

Create the square matrix version of `u` with possibly undefined entries: an `n × n`
matrix with `n = length(u)`. The generic method is `similar(u, length(u), length(u))`,
so the storage type is whatever `similar` chooses for `u`. Callers must write every entry
before reading it. For a number `u`, `zero(u)` is returned.
"""
function undefmatrix(u)
    n = length(u)
    return similar(u, n, n)
end
function undefmatrix(u::Number)
    return zero(u)
end
"""
    restructure(x, y)

Restructure the object `y` into the shape of `x`, keeping its values intact. For simple
objects like an `Array` this is just a reshape. More complex objects, such as an
`ArrayPartition`, carry structural information that their type alone does not describe,
so standard tools cannot rebuild them. `restructure` gives a way to convert, for example,
an `Array` into a matching `ArrayPartition`.
"""
function restructure(x, y)
    result = similar(x, eltype(y))
    dest = vec(result)  # aliases `result`, so writing `dest` fills `result`
    dest .= vec(y)
    return result
end

restructure(x::Array, y) = reshape(convert(Array, y), Base.size(x)...)
674
# Devices on which a collection's elements can live, as reported by `device`.
abstract type AbstractDevice end
abstract type AbstractCPU <: AbstractDevice end
struct GPU <: AbstractDevice end
# CPU access kinds reported by `device`: arrays supporting `pointer`, tuples, and other
# indexable collections.
struct CPUPointer <: AbstractCPU end
struct CPUTuple <: AbstractCPU end
struct CPUIndex <: AbstractCPU end
# Marker type; not an `AbstractDevice` (its use is not visible in this file).
struct CheckParent end
682
"""
    device(::Type{T}) -> AbstractDevice

Indicate the most efficient way to access elements from the collection in low-level code.

- `GPUArrays` return `ArrayInterface.GPU()`.
- An `AbstractArray` supporting a `pointer` method returns `ArrayInterface.CPUPointer()`.
- Other `AbstractArray`s return `ArrayInterface.CPUIndex()`; `Tuple`s return
  `ArrayInterface.CPUTuple()`.
- Anything else returns `nothing`.
"""
device(A) = device(typeof(A))
device(::Type) = nothing
device(::Type{<:Tuple}) = CPUTuple()
device(::Type{<:Array}) = CPUPointer()
device(::Type{T}) where {T <: AbstractArray} = _device(parent_type(T), T)

# Parent type == array type (not a wrapper): dense arrays expose a pointer; others are
# only indexable.
_device(::Type{T}, ::Type{T}) where {T <: DenseArray} = CPUPointer()
_device(::Type{T}, ::Type{T}) where {T} = CPUIndex()
# A wrapper `T` around parent `P` inherits the parent's device. Pointer access survives
# only if the wrapper itself defines strides.
function _device(::Type{P}, ::Type{T}) where {P, T}
    parent_device = device(P)
    return defines_strides(T) ? parent_device : _not_pointer(parent_device)
end
_not_pointer(::CPUPointer) = CPUIndex()
_not_pointer(x) = x
708
"""
    can_avx(f) -> Bool

Return `true` if the function `f` is guaranteed to be compatible with
`LoopVectorization.@avx` for supported element and array types. A return value of `false`
does not mean the function is unsupported. It lets a library conservatively apply `@avx`
only when that is known to be safe. The generic method returns `false`.

```julia
function mymap!(f, y, args...)
    if can_avx(f)
        @avx @. y = f(args...)
    else
        @. y = f(args...)
    end
end
```
"""
function can_avx(f)
    return false
end
728
"""
    fast_scalar_indexing(::Type{T}) -> Bool

Query whether an array type has fast scalar indexing. The default is `true`. Implicitly
stored orthogonal factors (`LinearAlgebra.AbstractQ`, including `LQPackedQ`) return
`false`.
"""
fast_scalar_indexing(x) = fast_scalar_indexing(typeof(x))
fast_scalar_indexing(::Type) = true
fast_scalar_indexing(::Type{<:LinearAlgebra.LQPackedQ}) = false
fast_scalar_indexing(::Type{<:LinearAlgebra.AbstractQ}) = false
738
"""
    allowed_getindex(x, i...)

A scalar `getindex` that is always allowed (plain `getindex(x, i...)` here).
"""
function allowed_getindex(x, i...)
    return getindex(x, i...)
end

"""
    allowed_setindex!(x, v, i...)

A scalar `setindex!` that is always allowed (plain `Base.setindex!(x, v, i...)` here).
"""
function allowed_setindex!(x, v, i...)
    return Base.setindex!(x, v, i...)
end
752
"""
    ArrayIndex{N}

Subtypes of `ArrayIndex` represent a series of transformations from a provided index into
some buffer, which is typically applied with square brackets (e.g.
`buffer[index[inds...]]`). The only behavior a subtype of `ArrayIndex` must provide is the
ability to transform individual index elements (i.e. not collections). This does not
guarantee bounds checking or the ability to iterate, although specific types may provide
additional functionality.
"""
abstract type ArrayIndex{N} end

const MatrixIndex = ArrayIndex{2}
const VectorIndex = ArrayIndex{1}

Base.ndims(::ArrayIndex{N}) where {N} = N
Base.ndims(::Type{<:ArrayIndex{N}}) where {N} = N

# Row or column indices of the `count` stored entries of a bidiagonal matrix.
struct BidiagonalIndex <: MatrixIndex
    count::Int
    isup::Bool
end

# Row (`isrow`) or column indices of the stored entries of an `nsize × nsize` tridiagonal
# matrix: the main diagonal first, then the superdiagonal, then the subdiagonal.
struct TridiagonalIndex <: MatrixIndex
    count::Int  # nsize + 2 * (nsize - 1)
    nsize::Int
    isrow::Bool
end

Base.firstindex(::Union{BidiagonalIndex, TridiagonalIndex}) = 1
Base.lastindex(ind::Union{BidiagonalIndex, TridiagonalIndex}) = ind.count
Base.length(ind::Union{BidiagonalIndex, TridiagonalIndex}) = lastindex(ind)

# Produces `1, 1, 2, 2, 3, …` when `isup` and `1, 2, 2, 3, 3, …` otherwise.
Base.@propagate_inbounds function Base.getindex(ind::BidiagonalIndex, i::Int)
    @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i))
    return fld(i + (ind.isup ? 1 : 2), 2)
end

Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int)
    @boundscheck 1 <= i <= ind.count || throw(BoundsError(ind, i))
    n = ind.nsize
    if 1 <= i <= n                 # main diagonal
        return i
    elseif n < i <= 2n - 1         # superdiagonal: columns are shifted by one
        return i - n + (ind.isrow ? 0 : 1)
    else                           # subdiagonal: rows are shifted by one
        return i - (2n - 1) + (ind.isrow ? 1 : 0)
    end
end
809
810
"""
    ndims_index(::Type{I}) -> Int

Return the number of dimensions that an instance of `I` indexes into. If this method is
not explicitly defined, `1` is returned.

See also [`ndims_shape`](@ref)

# Examples

```julia
julia> ArrayInterface.ndims_index(Int)
1

julia> ArrayInterface.ndims_index(CartesianIndex(1, 2, 3))
3

julia> ArrayInterface.ndims_index([CartesianIndex(1, 2), CartesianIndex(1, 3)])
2

```
"""
ndims_index(@nospecialize(i)) = ndims_index(typeof(i))
ndims_index(T::Type) = 1
ndims_index(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N
# `CartesianIndices{0}` still consumes one dimension.
ndims_index(::Type{CartesianIndices{0, Tuple{}}}) = 1
# A collection of indices indexes as many dimensions as each of its elements...
ndims_index(@nospecialize(T::Type{<:AbstractArray})) = ndims_index(eltype(T))
# ...except Boolean masks, which cover one dimension per mask dimension.
ndims_index(@nospecialize(T::Type{<:AbstractArray{Bool}})) = ndims(T)
ndims_index(@nospecialize(T::Type{<:Base.LogicalIndex})) = ndims(fieldtype(T, :mask))
841
"""
    ndims_shape(::Type{I}) -> Union{Int,Tuple{Vararg{Int}}}

Return the number of dimensions represented in the shape of the array returned when
indexing with an instance of `I`.

See also [`ndims_index`](@ref)

# Examples

```julia
julia> ArrayInterface.ndims_shape([CartesianIndex(1, 1), CartesianIndex(1, 2)])
1

julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2)]])
1

```
"""
ndims_shape(T::DataType) = ndims_index(T)
ndims_shape(::Type{Colon}) = 1
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T)
# Scalars (numbers and single Cartesian indices) drop their dimensions from the result.
ndims_shape(@nospecialize T::Type{<:Union{Number, Base.AbstractCartesianIndex}}) = 0
# A Boolean mask of any rank selects a flat list of elements.
ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
ndims_shape(x) = ndims_shape(typeof(x))
867
868
869
"""
    instances_do_not_alias(::Type{T}) -> Bool

Is it safe to `ivdep` arrays containing elements of type `T`?
That is, would it be safe to write to an array full of `T` in parallel?
This is not true for `mutable struct`s in general, where editing one index
could edit other indices.
That is, it is not safe when different instances may alias the same memory.
"""
instances_do_not_alias(::Type{T}) where {T} = Base.isbitstype(T)

"""
    indices_do_not_alias(::Type{T<:AbstractArray}) -> Bool

Is it safe to `ivdep` arrays of type `T`?
That is, would it be safe to write to an array of type `T` in parallel?
Examples where this is not true are `BitArray`s or `view(rand(6), [1,2,3,1,2,3])`.
That is, it is not safe whenever different indices may alias the same memory.
"""
indices_do_not_alias(::Type{T}) where {T} = _indices_do_not_alias(T)
function indices_do_not_alias(::Type{A}) where {T, A <: Base.StridedArray{T}}
    return instances_do_not_alias(T)
end
function indices_do_not_alias(::Type{Adjoint{T, A}}) where {T, A <: AbstractArray{T}}
    return indices_do_not_alias(A)
end
function indices_do_not_alias(::Type{Transpose{T, A}}) where {T, A <: AbstractArray{T}}
    return indices_do_not_alias(A)
end

_indices_do_not_alias(::Type) = false
# A view is alias-free when its parent is and every index selects pairwise-distinct
# positions. That covers scalars, any `AbstractUnitRange` (step 1, so no repeats,
# including the `Base.Slice` produced by `:`), reshaped unit ranges and Cartesian indices.
function _indices_do_not_alias(::Type{<:SubArray{<:Any, <:Any, A, I}}) where {A,
        I <: Tuple{Vararg{Union{Integer, AbstractUnitRange, Base.ReshapedUnitRange,
            Base.AbstractCartesianIndex}}}}
    return indices_do_not_alias(A)
end
909
"""
    defines_strides(::Type{T}) -> Bool

Is `strides(::T)` defined? Types returning `true` are assumed to also return a valid
pointer from `pointer(::T)`.

- Strided arrays and `BitArray`s define strides.
- A `SubArray` does if all of its index types preserve strides.
- Other wrappers defer to their parent type; non-wrappers default to `false`.
"""
defines_strides(x) = defines_strides(typeof(x))
defines_strides(::Type{T}) where {T} = _defines_strides(parent_type(T), T)
defines_strides(@nospecialize(T::Type{<:StridedArray})) = true
defines_strides(@nospecialize(T::Type{<:BitArray})) = true
@inline function defines_strides(@nospecialize(T::Type{<:SubArray}))
    return stride_preserving_index(fieldtype(T, :indices))
end
# `P === T`: not a wrapper and not known to be strided. Otherwise ask the parent.
_defines_strides(::Type{T}, ::Type{T}) where {T} = false
_defines_strides(::Type{P}, ::Type{T}) where {P, T} = defines_strides(P)

# stride_preserving_index(::Type{T}) -> Bool
#
# Return `true` if the strides between elements can still be derived after indexing with
# an instance of `T`: ranges and scalars do, a tuple does if all its element types do, and
# everything else is assumed not to.
stride_preserving_index(@nospecialize(T::Type)) = false
stride_preserving_index(@nospecialize(T::Type{<:AbstractRange})) = true
stride_preserving_index(@nospecialize(T::Type{<:Number})) = true
@inline function stride_preserving_index(@nospecialize(T::Type{<:Tuple}))
    return all(map_tuple_type(stride_preserving_index, T))
end
938
## Stubs
# Index type for banded matrices; only the field layout is declared here (presumably it
# is given behavior by the BandedMatrices extension loaded at the end of the module).
struct BandedMatrixIndex <: ArrayInterface.MatrixIndex
    count::Int
    rowsize::Int
    colsize::Int
    bandinds::Vector{Int}
    bandsizes::Vector{Int}
    isrow::Bool
end
948
"""
    ensures_all_unique(T::Type) -> Bool

Return `true` if all instances of type `T` are composed of a unique set of elements.
This does not require that `T` subtypes `AbstractSet` or implements the `AbstractSet`
interface.

# Examples

```julia
julia> ArrayInterface.ensures_all_unique(BitSet())
true

julia> ArrayInterface.ensures_all_unique([])
false

julia> ArrayInterface.ensures_all_unique(typeof(1:10))
true

julia> ArrayInterface.ensures_all_unique(LinRange(1, 1, 10))
false

julia> ArrayInterface.ensures_all_unique(StepRangeLen(1.0, 0.0, 3))
false
```
"""
ensures_all_unique(@nospecialize T::Type{<:Union{AbstractSet,AbstractDict}}) = true
# `LinRange` and `StepRangeLen` accept a zero step (e.g. `StepRangeLen(1.0, 0.0, 3)` is
# `[1.0, 1.0, 1.0]`), so their instances may contain repeated elements.
ensures_all_unique(@nospecialize T::Type{<:LinRange}) = false
ensures_all_unique(@nospecialize T::Type{<:StepRangeLen}) = false
ensures_all_unique(@nospecialize T::Type{<:AbstractRange}) = true
@inline function ensures_all_unique(T::Type)
    is_forwarding_wrapper(T) ? ensures_all_unique(parent_type(T)) : false
end
ensures_all_unique(@nospecialize(x)) = ensures_all_unique(typeof(x))
979
"""
    ensures_sorted(T::Type) -> Bool

Return `true` if all instances of `T` are sorted in ascending order (as `issorted` would
report). Only unit ranges are guaranteed to be ascending; ranges with another step may
run downwards (e.g. `10:-1:1`) and so return `false`.

# Examples

```julia
julia> ArrayInterface.ensures_sorted(BitSet())
true

julia> ArrayInterface.ensures_sorted([])
false

julia> ArrayInterface.ensures_sorted(1:10)
true

julia> ArrayInterface.ensures_sorted(10:-1:1)
false
```
"""
ensures_sorted(@nospecialize(T::Type{BitSet})) = true
ensures_sorted(@nospecialize(T::Type{<:AbstractUnitRange})) = true
# A general range may have a negative step, so sortedness cannot be guaranteed by type.
ensures_sorted(@nospecialize(T::Type{<:AbstractRange})) = false
ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false
ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x))
1002
1003 ## Extensions
1004
1005 import Requires
# Without package extensions (`Base.get_extension` is missing, presumably Julia < 1.9),
# load the glue code under `ext/` lazily through Requires.jl. Each file is included only
# once the package with the given name and UUID has been loaded.
@static if !isdefined(Base, :get_extension)
    # Module initializer: registers the Requires.jl hooks when ArrayInterface is loaded.
    function __init__()
        Requires.@require BandedMatrices = "aae01518-5342-5314-be14-df237901396f" begin include("../ext/ArrayInterfaceBandedMatricesExt.jl") end
        Requires.@require BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" begin include("../ext/ArrayInterfaceBlockBandedMatricesExt.jl") end
        Requires.@require GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" begin include("../ext/ArrayInterfaceGPUArraysCoreExt.jl") end
        Requires.@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysCoreExt.jl") end
        Requires.@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin include("../ext/ArrayInterfaceCUDAExt.jl") end
        Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/ArrayInterfaceTrackerExt.jl") end
    end
end
1016
1017 end # module