From 721d89bf2d6316c4735b91494658c89e66c36944 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 22:18:11 -0800 Subject: [PATCH] Fix various type-instabilities (#329) * Resolve instability in ifelse * Resolve instability in broadcasting over structured matrices * Generate DNE returns type-stably * Move Zero check out of pullback * Get unionall type-stably * Revert "Move Zero check out of pullback" This reverts commit d41ef75b891ff88a84cb37633010825d636ad46a. * Remove misplaced extern * Rename and document _unionall_typeof * Increment patch version number --- Project.toml | 2 +- src/rulesets/Base/array.jl | 3 ++- src/rulesets/Base/fastmath_able.jl | 12 +++++++----- src/rulesets/Base/indexing.jl | 3 ++- src/rulesets/LinearAlgebra/blas.jl | 1 - src/rulesets/LinearAlgebra/norm.jl | 15 ++++++++++++--- src/rulesets/LinearAlgebra/structured.jl | 6 ++---- src/rulesets/LinearAlgebra/symmetric.jl | 2 +- src/rulesets/LinearAlgebra/utils.jl | 16 ++++++++++++++++ 9 files changed, 43 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 56a91477f..2bd13a06a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.39" +version = "0.7.40" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 4803438f7..3b3996373 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -14,7 +14,8 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Int...) A_dims = size(A) function reshape_pullback(Ȳ) ∂A = reshape(Ȳ, A_dims) - return (NO_FIELDS, ∂A, fill(DoesNotExist(), length(dims))...) + ∂dims = broadcast(_ -> DoesNotExist(), dims) + return (NO_FIELDS, ∂A, ∂dims...) end return reshape(A, dims...), reshape_pullback end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 4d63ab1c3..3bb0a79f4 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -66,8 +66,8 @@ let ## abs function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex}) Ω = abs(x) - signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) # `ifelse` is applied only to denominator to ensure type-stability. + signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) return Ω, _realconjtimes(signx, Δx) end @@ -108,7 +108,8 @@ let function frule((_, Δx), ::typeof(angle), x) Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. - ∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x)) + n = ifelse(iszero(x), one(real(x)), abs2(x)) + ∂Ω = _imagconjtimes(x, Δx) / n return Ω, ∂Ω end @@ -127,8 +128,9 @@ let function angle_pullback(ΔΩ) x, y = reim(z) Δu, Δv = reim(ΔΩ) - return (NO_FIELDS, (-y + im*x)*Δu/ifelse(iszero(z), one(z), abs2(z))) # `ifelse` is applied only to denominator to ensure type-stability. + n = ifelse(iszero(z), one(real(z)), abs2(z)) + return (NO_FIELDS, (-y + im*x)*Δu/n) end return angle(z), angle_pullback end @@ -185,14 +187,14 @@ let # `sign` function frule((_, Δx), ::typeof(sign), x) - n = ifelse(iszero(x), one(x), abs(x)) + n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n ∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im return Ω, ∂Ω end function rrule(::typeof(sign), x) - n = ifelse(iszero(x), one(x), abs(x)) + n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 8cf0067aa..78208434c 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -20,7 +20,8 @@ function rrule(::typeof(getindex), x::Array{<:Number}, inds...) @thunk(getindex_add!(zero(x))), getindex_add! ) - return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...) + īnds = broadcast(_ -> DoesNotExist(), inds) + return (NO_FIELDS, x̄, īnds...) end return y, getindex_pullback diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 1af4561ca..0f740c768 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -22,7 +22,6 @@ function rrule(::typeof(BLAS.dot), n, X, incx, Y, incy) ∂X = Zero() ∂Y = Zero() else - ΔΩ = extern(ΔΩ) ∂X = @thunk scal!(n, ΔΩ, blascopy!(n, Y, incy, _zeros(X), incx), incx) ∂Y = @thunk scal!(n, ΔΩ, blascopy!(n, X, incx, _zeros(Y), incy), incy) end diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index 53c876363..969702826 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -111,7 +111,8 @@ end function _normp_back_x(x, p, y, Δy) c = real(Δy) / y - ∂x = broadcast(x) do xi + ∂x = similar(x) + broadcast!(∂x, x) do xi a = norm(xi) ∂xi = xi * ((a / y)^(p - 2) * c) return ifelse(isfinite(∂xi), ∂xi, zero(∂xi)) @@ -181,7 +182,11 @@ function rrule( return y, norm1_pullback end -_norm1_back(x, y, Δy) = sign.(x) .* real(Δy) +function _norm1_back(x, y, Δy) + ∂x = similar(x) + ∂x .= sign.(x) .* real(Δy) + return ∂x +end ##### ##### `norm2` @@ -206,7 +211,11 @@ function _norm2_forward(x, Δx, y) ∂y = real(dot(x, Δx)) * pinv(y) return ∂y end -_norm2_back(x, y, Δy) = x .* (real(Δy) * pinv(y)) +function _norm2_back(x, y, Δy) + ∂x = similar(x) + ∂x .= x .* (real(Δy) * pinv(y)) + return ∂x +end ##### ##### `normalize` diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 753fae7ca..04f359acb 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -8,9 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T}, AbstractTriangular{T}} function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::T) where T<:SquareMatrix{<:Real} Y = A / B function slash_pullback(Ȳ) - S = T.name.wrapper ∂A = @thunk Ȳ / B' - ∂B = @thunk S(-Y' * (Ȳ / B')) + ∂B = @thunk _unionall_wrapper(T)(-Y' * (Ȳ / B')) return (NO_FIELDS, ∂A, ∂B) end return Y, slash_pullback @@ -19,8 +18,7 @@ end function rrule(::typeof(\), A::T, B::AbstractVecOrMat{<:Real}) where T<:SquareMatrix{<:Real} Y = A \ B function backslash_pullback(Ȳ) - S = T.name.wrapper - ∂A = @thunk S(-(A' \ Ȳ) * Y') + ∂A = @thunk _unionall_wrapper(T)(-(A' \ Ȳ) * Y') ∂B = @thunk A' \ Ȳ return NO_FIELDS, ∂A, ∂B end diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 44c1cb740..63fc907d6 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -9,7 +9,7 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist()) + return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, Ω.uplo), DoesNotExist()) end return Ω, HermOrSym_pullback end diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 9fa75dc6a..9aa7b5257 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -27,3 +27,19 @@ function _eyesubx!(X::AbstractMatrix) end _extract_imag(x) = complex(0, imag(x)) + +""" + _unionall_wrapper(T::Type) -> UnionAll + +Return the most general `UnionAll` type union associated with the concrete type `T`. + +# Example +```julia +julia> _unionall_wrapper(typeof(Diagonal(1:3))) +Diagonal + +julia> _unionall_wrapper(typeof(Symmetric(randn(3, 3)))) +Symmetric +```` +""" +_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper