From ffbaa5fecca8da39f20aeb66cc6e4edf3e0c3f11 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 24 Oct 2021 12:13:03 +0200 Subject: [PATCH] Use `RealDot.realdot` (#542) * Use `RealDot.realdot` * More `realdot` --- Project.toml | 4 +++- src/ChainRules.jl | 1 + src/rulesets/Base/base.jl | 2 +- src/rulesets/Base/fastmath_able.jl | 6 +++--- src/rulesets/Base/mapreduce.jl | 6 +++--- src/rulesets/Base/utils.jl | 7 ------- src/rulesets/LinearAlgebra/blas.jl | 4 ++-- src/rulesets/LinearAlgebra/factorization.jl | 2 +- src/rulesets/LinearAlgebra/norm.jl | 6 +++--- src/rulesets/LinearAlgebra/symmetric.jl | 2 +- 10 files changed, 18 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 06bf43758..54ad35131 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.11.6" +version = "1.12.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] @@ -15,6 +16,7 @@ ChainRulesTestUtils = "1" Compat = "3.35" FiniteDifferences = "0.12.8" JuliaInterpreter = "0.8" +RealDot = "0.1" StaticArrays = "1.2" julia = "1" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index b951a0e85..f2806fd3d 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -6,6 +6,7 @@ using Compat using LinearAlgebra using LinearAlgebra.BLAS using Random +using RealDot: realdot using Statistics # Basically everything this package does is overloading these, so we make an exception diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 5b9237bcf..4118aa37e 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -78,7 +78,7 @@ end function frule((_, Δz), ::typeof(hypot), z::Complex) Ω = hypot(z) - ∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) + ∂Ω = realdot(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) return Ω, ∂Ω end diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index d1839518a..2ef7a1f9c 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -68,7 +68,7 @@ let Ω = abs(x) # `ifelse` is applied only to denominator to ensure type-stability. signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) - return Ω, _realconjtimes(signx, Δx) + return Ω, realdot(signx, Δx) end function rrule(::typeof(abs), x::Union{Real, Complex}) @@ -82,7 +82,7 @@ let ## abs2 function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex}) - return abs2(z), 2 * _realconjtimes(z, Δz) + return abs2(z), 2 * realdot(z, Δz) end function rrule(::typeof(abs2), z::Union{Real, Complex}) @@ -146,7 +146,7 @@ let ) where {T<:Union{Real,Complex}} Ω = hypot(x, y) n = ifelse(iszero(Ω), one(Ω), Ω) - ∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n + ∂Ω = (realdot(x, Δx) + realdot(y, Δy)) / n return Ω, ∂Ω end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 41cbca20c..0a658454c 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -111,13 +111,13 @@ function frule( ẋ = unthunk(Δx) y = sum(abs2, x; dims=dims) ∂y = if dims isa Colon - 2 * real(dot(x, ẋ)) + 2 * realdot(x, ẋ) elseif VERSION ≥ v"1.2" # multi-iterator mapreduce introduced in v1.2 mapreduce(+, x, ẋ; dims=dims) do xi, dxi - 2 * _realconjtimes(xi, dxi) + 2 * realdot(xi, dxi) end else - 2 * sum(_realconjtimes.(x, ẋ); dims=dims) + 2 * sum(realdot.(x, ẋ); dims=dims) end return y, ∂y end diff --git a/src/rulesets/Base/utils.jl b/src/rulesets/Base/utils.jl index 30bbdd47e..9dd991742 100644 --- a/src/rulesets/Base/utils.jl +++ b/src/rulesets/Base/utils.jl @@ -1,10 +1,3 @@ -# real(conj(x) * y) avoiding computing the imaginary part if possible -@inline _realconjtimes(x, y) = real(conj(x) * y) -@inline _realconjtimes(x::Complex, y::Complex) = muladd(real(x), real(y), imag(x) * imag(y)) -@inline _realconjtimes(x::Real, y::Complex) = x * real(y) -@inline _realconjtimes(x::Complex, y::Real) = real(x) * y -@inline _realconjtimes(x::Real, y::Real) = x * y - # imag(conj(x) * y) avoiding computing the real part if possible @inline _imagconjtimes(x, y) = imag(conj(x) * y) @inline function _imagconjtimes(x::Complex, y::Complex) diff --git a/src/rulesets/LinearAlgebra/blas.jl b/src/rulesets/LinearAlgebra/blas.jl index 509e15847..c0f90a31d 100644 --- a/src/rulesets/LinearAlgebra/blas.jl +++ b/src/rulesets/LinearAlgebra/blas.jl @@ -41,7 +41,7 @@ function frule((_, Δx), ::typeof(BLAS.nrm2), x) ∂Ω = if x isa Real BLAS.dot(x, Δx) / s else - sum(y -> _realconjtimes(y...), zip(x, Δx)) / s + sum(y -> realdot(y...), zip(x, Δx)) / s end return Ω, ∂Ω end @@ -72,7 +72,7 @@ end function frule((_, Δx), ::typeof(BLAS.asum), x) ∂Ω = sum(zip(x, Δx)) do (xi, Δxi) - return _realconjtimes(_signcomp(xi), Δxi) + return realdot(_signcomp(xi), Δxi) end return BLAS.asum(x), ∂Ω end diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index f5342bfad..d7530a16e 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -347,7 +347,7 @@ function _eigen_norm_phase_fwd!(∂V, A, V) @inbounds for i in axes(V, 2) v, ∂v = @views V[:, i], ∂V[:, i] # account for unit normalization - ∂c_norm = -real(dot(v, ∂v)) + ∂c_norm = -realdot(v, ∂v) if eltype(V) <: Real ∂c = ∂c_norm else diff --git a/src/rulesets/LinearAlgebra/norm.jl b/src/rulesets/LinearAlgebra/norm.jl index c2114c867..2cc69355a 100644 --- a/src/rulesets/LinearAlgebra/norm.jl +++ b/src/rulesets/LinearAlgebra/norm.jl @@ -14,7 +14,7 @@ function frule((_, ẋ), ::typeof(norm), x::Number, p::Real) zero(real(x)) * zero(real(Δx)) else signx = x isa Real ? sign(x) : x * pinv(y) - _realconjtimes(signx, Δx) + realdot(signx, Δx) end return y, ∂y end @@ -235,7 +235,7 @@ function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number}) end function _norm2_forward(x, Δx, y) - ∂y = real(dot(x, Δx)) * pinv(y) + ∂y = realdot(x, Δx) * pinv(y) return ∂y end function _norm2_back(x, y, Δy) @@ -280,7 +280,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}) LinearAlgebra.__normalize!(y, nrm) function normalize_pullback(ȳ) Δy = unthunk(ȳ) - ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm) + ∂x = (Δy .- realdot(y, Δy) .* y) .* pinv(nrm) return (NoTangent(), ∂x) end normalize_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent()) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 573d4fcf4..67693575e 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -218,7 +218,7 @@ function frule( # diag(U' * tmp) without computing matrix product ∂λ = similar(λ) @inbounds for i in eachindex(λ) - ∂λ[i] = @views real(dot(U[:, i], tmp[:, i])) + ∂λ[i] = @views realdot(U[:, i], tmp[:, i]) end return λ, ∂λ end