diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index f1409515c..77c868abc 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -10,7 +10,7 @@ end function rrule(::typeof(inv), x::AbstractArray) Ω = inv(x) function inv_pullback(ΔΩ) - return NoTangent(), -Ω' * ΔΩ * Ω' + return NoTangent(), Ω' * -ΔΩ * Ω' end return Ω, inv_pullback end @@ -23,12 +23,7 @@ frule((_, ΔA, ΔB), ::typeof(*), A, B) = A * B, muladd(ΔA, B, A * ΔB) frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C) = A*B*C, ΔA*B*C + A*ΔB*C + A*B*ΔC - -function rrule( - ::typeof(*), - A::AbstractVecOrMat{<:CommutativeMulNumber}, - B::AbstractVecOrMat{<:CommutativeMulNumber}, -) +function rrule(::typeof(*), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) function times_pullback(ȳ) @@ -46,8 +41,8 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411 function rrule( ::typeof(*), - A::StridedMatrix{<:CommutativeMulNumber}, - B::StridedVecOrMat{<:CommutativeMulNumber}, + A::StridedMatrix{<:Number}, + B::StridedVecOrMat{<:Number}, ) function times_pullback(ȳ) Ȳ = unthunk(ȳ) @@ -64,13 +59,7 @@ function rrule( return A * B, times_pullback end - - -function rrule( - ::typeof(*), - A::AbstractVector{<:CommutativeMulNumber}, - B::AbstractMatrix{<:CommutativeMulNumber}, -) +function rrule(::typeof(*), A::AbstractVector{<:Number}, B::AbstractMatrix{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) function times_pullback(ȳ) @@ -97,7 +86,7 @@ end ##### function rrule( - ::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber} + ::typeof(*), A::Number, B::AbstractArray{<:Number} ) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -105,7 +94,16 @@ function rrule( Ȳ = unthunk(ȳ) return ( NoTangent(), - @thunk(project_A(dot(Ȳ, B)')), + Thunk() do + if eltype(B) isa CommutativeMulNumber + project_A(dot(Ȳ, B)') + elseif ndims(B) < 3 + # https://github.com/JuliaLang/julia/issues/44152 + project_A(dot(conj(Ȳ), conj(B))) + else + project_A(dot(conj(vec(Ȳ)), conj(vec(B)))) + end + end, InplaceableThunk( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), @thunk(project_B(A' * Ȳ)), @@ -116,7 +114,7 @@ function rrule( end function rrule( - ::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber + ::typeof(*), B::AbstractArray{<:Number}, A::Number ) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -128,7 +126,17 @@ function rrule( X̄ -> mul!(X̄, conj(A), Ȳ, true, true), @thunk(project_B(A' * Ȳ)), ), - @thunk(project_A(dot(Ȳ, B)')), + # @thunk(project_A(eltype(A) isa CommutativeMulNumber ? dot(Ȳ, B)' : dot(Ȳ', B'))), + Thunk() do + if eltype(B) isa CommutativeMulNumber + project_A(dot(Ȳ, B)') + elseif ndims(B) < 3 + # https://github.com/JuliaLang/julia/issues/44152 + project_A(dot(conj(Ȳ), conj(B))) + else + project_A(dot(conj(vec(Ȳ)), conj(vec(B)))) + end + end, ) end return A * B, times_pullback @@ -217,9 +225,9 @@ end function rrule( ::typeof(muladd), - A::AbstractMatrix{<:CommutativeMulNumber}, - B::AbstractVecOrMat{<:CommutativeMulNumber}, - z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}, + A::AbstractMatrix{<:Number}, + B::AbstractVecOrMat{<:Number}, + z::Union{Number, AbstractVecOrMat{<:Number}}, ) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -255,9 +263,9 @@ end function rrule( ::typeof(muladd), - ut::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber}, - v::AbstractVector{<:CommutativeMulNumber}, - z::CommutativeMulNumber, + ut::LinearAlgebra.AdjOrTransAbsVec{<:Number}, + v::AbstractVector{<:Number}, + z::Number, ) project_ut = ProjectTo(ut) project_v = ProjectTo(v) @@ -268,11 +276,11 @@ function rrule( dy = unthunk(ȳ) ut_thunk = InplaceableThunk( dut -> dut .+= v' .* dy, - @thunk(project_ut(v' .* dy)), + @thunk(project_ut((v * dy')')), ) v_thunk = InplaceableThunk( dv -> dv .+= ut' .* dy, - @thunk(project_v(ut' .* dy)), + @thunk(project_v(ut' * dy)), ) (NoTangent(), ut_thunk, v_thunk, z isa Bool ? NoTangent() : project_z(dy)) end @@ -281,9 +289,9 @@ end function rrule( ::typeof(muladd), - u::AbstractVector{<:CommutativeMulNumber}, - vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber}, - z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}, + u::AbstractVector{<:Number}, + vt::LinearAlgebra.AdjOrTransAbsVec{<:Number}, + z::Union{Number, AbstractVecOrMat{<:Number}}, ) project_u = ProjectTo(u) project_vt = ProjectTo(vt) @@ -293,8 +301,8 @@ function rrule( function muladd_pullback_3(ȳ) Ȳ = unthunk(ȳ) proj = ( - @thunk(project_u(vec(sum(Ȳ .* conj.(vt), dims=2)))), - @thunk(project_vt(vec(sum(u .* conj.(Ȳ), dims=1))')), + @thunk(project_u(Ȳ * vec(vt'))), + @thunk(project_vt((Ȳ' * u)')), ) addon = if z isa Bool NoTangent() @@ -315,14 +323,14 @@ end ##### `/` ##### -function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) +function rrule(::typeof(/), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) Aᵀ, dA_pb = rrule(adjoint, A) Bᵀ, dB_pb = rrule(adjoint, B) Cᵀ, dS_pb = rrule(\, Bᵀ, Aᵀ) C, dC_pb = rrule(adjoint, Cᵀ) - function slash_pullback(Ȳ) + function slash_pullback(Ȳ) # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want - _, dC = dC_pb(Ȳ) + _, dC = dC_pb(Ȳ) _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) ∂A = last(dA_pb(unthunk(dAᵀ))) @@ -337,7 +345,7 @@ end ##### `\` ##### -function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real}) +function rrule(::typeof(\), A::AbstractVecOrMat{<:Number}, B::AbstractVecOrMat{<:Number}) project_A = ProjectTo(A) project_B = ProjectTo(B) @@ -362,14 +370,17 @@ end ##### `\`, `/` matrix-scalar_rule ##### -function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) - return A/b, ΔA/b - A*(Δb/b^2) + +function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:Number}, b::Number) + Y = A / b + return Y, muladd(Y, -Δb, ΔA) / b end -function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}) - return B/a, ΔB/a - B*(Δa/a^2) +function frule((_, Δa, ΔB), ::typeof(\), a::Number, B::AbstractArray{<:Number}) + Y = a \ B + return Y, a \ muladd(-Δa, Y, ΔB) end -function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber) +function rrule(::typeof(/), A::AbstractArray{<:Number}, b::Number) Y = A/b function slash_pullback_scalar(ȳ) Ȳ = unthunk(ȳ) @@ -377,17 +388,28 @@ function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::Commuta dA -> dA .+= Ȳ ./ conj(b), @thunk(Ȳ / conj(b)), ) - bthunk = @thunk(-dot(A,Ȳ) / conj(b^2)) + bthunk = @thunk(-dot(Y,Ȳ) / conj(b)) return (NoTangent(), Athunk, bthunk) end return Y, slash_pullback_scalar end -function rrule(::typeof(\), b::CommutativeMulNumber, A::AbstractArray{<:CommutativeMulNumber}) - Y, back = rrule(/, A, b) - function backslash_pullback(dY) # just reverses the arguments! - d0, dA, db = back(dY) - return (d0, db, dA) +function rrule(::typeof(\), b::Number, A::AbstractArray{<:Number}) + Y = b \ A + function backslash_pullback(ȳ) + Ȳ = unthunk(ȳ) + Athunk = InplaceableThunk( + dA -> dA .+= conj(b) .\ Ȳ, + @thunk(conj(b) \ Ȳ), + ) + bthunk = if eltype(Y) isa CommutativeMulNumber + @thunk(-conj(b) \ dot(Y, Ȳ)) + else + # NOTE: dot(Ȳ', Y') currently incorrect for non-commutative numbers + # https://github.com/JuliaLang/julia/issues/44152 + @thunk(-conj(b) \ dot(conj(Ȳ), conj(Y))) + end + return (NoTangent(), bthunk, Athunk) end return Y, backslash_pullback end diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index e95ff6eff..668cf9e28 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -76,13 +76,13 @@ end @scalar_rule hypot(x::Real) sign(x) -function frule((_, Δz), ::typeof(hypot), z::Complex) +function frule((_, Δz), ::typeof(hypot), z::Number) Ω = hypot(z) ∂Ω = realdot(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω) return Ω, ∂Ω end -function rrule(::typeof(hypot), z::Complex) +function rrule(::typeof(hypot), z::Number) Ω = hypot(z) function hypot_pullback(ΔΩ) return (NoTangent(), (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z) @@ -90,8 +90,24 @@ function rrule(::typeof(hypot), z::Complex) return (Ω, hypot_pullback) end -@scalar_rule fma(x, y, z) (y, x, true) -@scalar_rule muladd(x, y, z) (y, x, true) +@scalar_rule fma(x, y::CommutativeMulNumber, z) (y, x, true) +function frule((_, Δx, Δy, Δz), ::typeof(fma), x::Number, y::Number, z::Number) + return fma(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz)) +end +function rrule(::typeof(fma), x::Number, y::Number, z::Number) + projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) + fma_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) + fma(x, y, z), fma_pullback +end +@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, true) +function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number) + return muladd(x, y, z), muladd(Δx, y, muladd(x, Δy, Δz)) +end +function rrule(::typeof(muladd), x::Number, y::Number, z::Number) + projectx, projecty, projectz = ProjectTo(x), ProjectTo(y), ProjectTo(z) + muladd_pullback(ΔΩ) = NoTangent(), projectx(ΔΩ * y'), projecty(x' * ΔΩ), projectz(ΔΩ) + muladd(x, y, z), muladd_pullback +end @scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) @scalar_rule( mod(x, y), @@ -105,50 +121,50 @@ end @scalar_rule(ldexp(x, y), (2^y, NoTangent())) # Can't multiply though sqrt in acosh because of negative complex case for x -@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1)) -@scalar_rule acoth(x) inv(1 - x ^ 2) -@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) +@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1)) +@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2) +@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) @scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2))) -@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2))) -@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1)) -@scalar_rule atanh(x) inv(1 - x ^ 2) +@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2))) +@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1)) +@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2) -@scalar_rule acosd(x) -inv(deg2rad(sqrt(1 - x ^ 2))) -@scalar_rule acotd(x) -inv(deg2rad(1 + x ^ 2)) -@scalar_rule acscd(x) -inv(deg2rad(x^2 * sqrt(1 - x ^ -2))) +@scalar_rule acosd(x::CommutativeMulNumber) -inv(deg2rad(sqrt(1 - x ^ 2))) +@scalar_rule acotd(x::CommutativeMulNumber) -inv(deg2rad(1 + x ^ 2)) +@scalar_rule acscd(x::CommutativeMulNumber) -inv(deg2rad(x^2 * sqrt(1 - x ^ -2))) @scalar_rule acscd(x::Real) -inv(deg2rad(abs(x) * sqrt(x ^ 2 - 1))) -@scalar_rule asecd(x) inv(deg2rad(x ^ 2 * sqrt(1 - x ^ -2))) +@scalar_rule asecd(x::CommutativeMulNumber) inv(deg2rad(x ^ 2 * sqrt(1 - x ^ -2))) @scalar_rule asecd(x::Real) inv(deg2rad(abs(x) * sqrt(x ^ 2 - 1))) -@scalar_rule asind(x) inv(deg2rad(sqrt(1 - x ^ 2))) -@scalar_rule atand(x) inv(deg2rad(1 + x ^ 2)) - -@scalar_rule cot(x) -((1 + Ω ^ 2)) -@scalar_rule coth(x) -(csch(x) ^ 2) -@scalar_rule cotd(x) -deg2rad(1 + Ω ^ 2) -@scalar_rule csc(x) -Ω * cot(x) -@scalar_rule cscd(x) -deg2rad(Ω * cotd(x)) -@scalar_rule csch(x) -(coth(x)) * Ω -@scalar_rule sec(x) Ω * tan(x) -@scalar_rule secd(x) deg2rad(Ω * tand(x)) -@scalar_rule sech(x) -(tanh(x)) * Ω - -@scalar_rule acot(x) -(inv(1 + x ^ 2)) -@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) +@scalar_rule asind(x::CommutativeMulNumber) inv(deg2rad(sqrt(1 - x ^ 2))) +@scalar_rule atand(x::CommutativeMulNumber) inv(deg2rad(1 + x ^ 2)) + +@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2)) +@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2) +@scalar_rule cotd(x::CommutativeMulNumber) -deg2rad(1 + Ω ^ 2) +@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x) +@scalar_rule cscd(x::CommutativeMulNumber) -deg2rad(Ω * cotd(x)) +@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω +@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x) +@scalar_rule secd(x::CommutativeMulNumber) deg2rad(Ω * tand(x)) +@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω + +@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2)) +@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) @scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1))) -@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2)) +@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2)) @scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1)) -@scalar_rule cosd(x) -deg2rad(sind(x)) -@scalar_rule cospi(x) -π * sinpi(x) -@scalar_rule sind(x) deg2rad(cosd(x)) -@scalar_rule sinpi(x) π * cospi(x) -@scalar_rule tand(x) deg2rad(1 + Ω ^ 2) +@scalar_rule cosd(x::CommutativeMulNumber) -deg2rad(sind(x)) +@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x) +@scalar_rule sind(x::CommutativeMulNumber) deg2rad(cosd(x)) +@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x) +@scalar_rule tand(x::CommutativeMulNumber) deg2rad(1 + Ω ^ 2) -@scalar_rule sinc(x) cosc(x) +@scalar_rule sinc(x::CommutativeMulNumber) cosc(x) # the position of the minus sign below warrants the correct type for π -@scalar_rule sincospi(x) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) +@scalar_rule sincospi(x::CommutativeMulNumber) @setup((sinpix, cospix) = Ω) (π * cospix) (π * (-sinpix)) @scalar_rule( clamp(x, low, high), @@ -158,7 +174,22 @@ end ), (!(islow | ishigh), islow, ishigh), ) -@scalar_rule x \ y (-(Ω / x), one(y) / x) + +@scalar_rule x::CommutativeMulNumber \ y::CommutativeMulNumber (-(x \ Ω), x \ one(y)) +function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number) + Ω = x \ y + return Ω, x \ muladd(-Δx, Ω, Δy) +end +function rrule(::typeof(\), x::Number, y::Number) + Ω = x \ y + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function backslash_pullback(ΔΩ) + ∂y = x' \ ΔΩ + return NoTangent(), project_x(-∂y * Ω'), project_y(∂y) + end + return Ω, backslash_pullback +end function frule((_, ẏ), ::typeof(identity), x) return (x, ẏ) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 552392cb7..beb01b709 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -9,69 +9,82 @@ let ## for the rules for `sin` and `cos` ## See issue: https://github.com/JuliaDiff/ChainRules.jl/issues/291 ## sin - function rrule(::typeof(sin), x::Number) + function rrule(::typeof(sin), x::CommutativeMulNumber) sinx, cosx = sincos(x) sin_pullback(Δy) = (NoTangent(), cosx' * Δy) return (sinx, sin_pullback) end - function frule((_, Δx), ::typeof(sin), x::Number) + function frule((_, Δx), ::typeof(sin), x::CommutativeMulNumber) sinx, cosx = sincos(x) return (sinx, cosx * Δx) end ## cos - function rrule(::typeof(cos), x::Number) + function rrule(::typeof(cos), x::CommutativeMulNumber) sinx, cosx = sincos(x) cos_pullback(Δy) = (NoTangent(), -sinx' * Δy) return (cosx, cos_pullback) end - function frule((_, Δx), ::typeof(cos), x::Number) + function frule((_, Δx), ::typeof(cos), x::CommutativeMulNumber) sinx, cosx = sincos(x) return (cosx, -sinx * Δx) end - @scalar_rule tan(x) 1 + Ω ^ 2 + @scalar_rule tan(x::CommutativeMulNumber) 1 + Ω ^ 2 # Trig-Hyperbolic - @scalar_rule cosh(x) sinh(x) - @scalar_rule sinh(x) cosh(x) - @scalar_rule tanh(x) 1 - Ω ^ 2 + @scalar_rule cosh(x::CommutativeMulNumber) sinh(x) + @scalar_rule sinh(x::CommutativeMulNumber) cosh(x) + @scalar_rule tanh(x::CommutativeMulNumber) 1 - Ω ^ 2 # Trig- Inverses - @scalar_rule acos(x) -(inv(sqrt(1 - x ^ 2))) - @scalar_rule asin(x) inv(sqrt(1 - x ^ 2)) - @scalar_rule atan(x) inv(1 + x ^ 2) + @scalar_rule acos(x::CommutativeMulNumber) -(inv(sqrt(1 - x ^ 2))) + @scalar_rule asin(x::CommutativeMulNumber) inv(sqrt(1 - x ^ 2)) + @scalar_rule atan(x::CommutativeMulNumber) inv(1 + x ^ 2) # Trig-Multivariate - @scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) - @scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx + @scalar_rule atan(y::Real, x::Real) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) + @scalar_rule sincos(x::CommutativeMulNumber) @setup((sinx, cosx) = Ω) cosx -sinx # exponents - @scalar_rule cbrt(x) inv(3 * Ω ^ 2) - @scalar_rule inv(x) -(Ω ^ 2) - @scalar_rule sqrt(x) inv(2Ω) # gradient +Inf at x==0 - @scalar_rule exp(x) Ω - @scalar_rule exp10(x) logten * Ω - @scalar_rule exp2(x) logtwo * Ω - @scalar_rule expm1(x) exp(x) - @scalar_rule log(x) inv(x) - @scalar_rule log10(x) inv(logten * x) - @scalar_rule log1p(x) inv(x + 1) - @scalar_rule log2(x) inv(logtwo * x) + @scalar_rule cbrt(x::CommutativeMulNumber) inv(3 * Ω ^ 2) + @scalar_rule inv(x::CommutativeMulNumber) -(Ω ^ 2) + function frule((_, Δx), ::typeof(inv), x::Number) + Ω = inv(x) + return Ω, Ω * -Δx * Ω + end + function rrule(::typeof(inv), x::Number) + Ω = inv(x) + project_x = ProjectTo(x) + function inv_pullback(ΔΩ) + Ω′ = conj(Ω) + return NoTangent(), project_x(Ω′ * -ΔΩ * Ω′) + end + return Ω, inv_pullback + end + @scalar_rule sqrt(x::CommutativeMulNumber) inv(2Ω) # gradient +Inf at x==0 + @scalar_rule exp(x::CommutativeMulNumber) Ω + @scalar_rule exp10(x::CommutativeMulNumber) logten * Ω + @scalar_rule exp2(x::CommutativeMulNumber) logtwo * Ω + @scalar_rule expm1(x::CommutativeMulNumber) exp(x) + @scalar_rule log(x::CommutativeMulNumber) inv(x) + @scalar_rule log10(x::CommutativeMulNumber) inv(logten * x) + @scalar_rule log1p(x::CommutativeMulNumber) inv(x + 1) + @scalar_rule log2(x::CommutativeMulNumber) inv(logtwo * x) # Unary complex functions ## abs - function frule((_, Δx), ::typeof(abs), x::Union{Real, Complex}) + function frule((_, Δx), ::typeof(abs), x::Number) Ω = abs(x) # `ifelse` is applied only to denominator to ensure type-stability. signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) return Ω, realdot(signx, Δx) end - function rrule(::typeof(abs), x::Union{Real, Complex}) + function rrule(::typeof(abs), x::Number) Ω = abs(x) function abs_pullback(ΔΩ) signx = x isa Real ? sign(x) : x / ifelse(iszero(x), one(Ω), Ω) @@ -81,23 +94,23 @@ let end ## abs2 - function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex}) + function frule((_, Δz), ::typeof(abs2), z::Number) return abs2(z), 2 * realdot(z, Δz) end - function rrule(::typeof(abs2), z::Union{Real, Complex}) + function rrule(::typeof(abs2), z::Number) function abs2_pullback(ΔΩ) Δu = real(ΔΩ) - return (NoTangent(), 2Δu*z) + return (NoTangent(), 2Δu * z) end return abs2(z), abs2_pullback end ## conj - function frule((_, Δz), ::typeof(conj), z::Union{Real, Complex}) + function frule((_, Δz), ::typeof(conj), z::Number) return conj(z), conj(Δz) end - function rrule(::typeof(conj), z::Union{Real, Complex}) + function rrule(::typeof(conj), z::Number) function conj_pullback(ΔΩ) return (NoTangent(), conj(ΔΩ)) end @@ -105,7 +118,7 @@ let end ## angle - function frule((_, Δx), ::typeof(angle), x) + function frule((_, Δx), ::typeof(angle), x::Union{Real,Complex}) Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. n = ifelse(iszero(x), one(real(x)), abs2(x)) @@ -143,14 +156,14 @@ let ::typeof(hypot), x::T, y::T, - ) where {T<:Union{Real,Complex}} + ) where {T<:Number} Ω = hypot(x, y) n = ifelse(iszero(Ω), one(Ω), Ω) ∂Ω = (realdot(x, Δx) + realdot(y, Δy)) / n return Ω, ∂Ω end - function rrule(::typeof(hypot), x::T, y::T) where {T<:Union{Real,Complex}} + function rrule(::typeof(hypot), x::T, y::T) where {T<:Number} Ω = hypot(x, y) function hypot_pullback(ΔΩ) c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω) @@ -161,11 +174,24 @@ let @scalar_rule x + y (true, true) @scalar_rule x - y (true, -1) - @scalar_rule x / y (one(x) / y, -(Ω / y)) - + @scalar_rule x / y::CommutativeMulNumber (one(x) / y, -(Ω / y)) + function frule((_, Δx, Δy), ::typeof(/), x::Number, y::Number) + Ω = x / y + return Ω, muladd(Δx, Ω, -Δy) / y + end + function rrule(::typeof(/), x::Number, y::Number) + Ω = x / y + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function slash_pullback(ΔΩ) + ∂x = ΔΩ / y' + return NoTangent(), project_x(∂x), project_y(Ω' * -∂x) + end + return Ω, slash_pullback + end ## power # literal_pow is in base.jl - function frule((_, Δx, Δp), ::typeof(^), x::Number, p::Number) + function frule((_, Δx, Δp), ::typeof(^), x::CommutativeMulNumber, p::CommutativeMulNumber) y = x ^ p _dx = _pow_grad_x(x, p, float(y)) if iszero(Δp) @@ -178,7 +204,7 @@ let end end - function rrule(::typeof(^), x::Number, p::Number) + function rrule(::typeof(^), x::CommutativeMulNumber, p::CommutativeMulNumber) y = x^p project_x = ProjectTo(x) project_p = ProjectTo(p) @@ -209,26 +235,27 @@ let @scalar_rule -x -1 ## `sign` - function frule((_, Δx), ::typeof(sign), x) + function frule((_, Δx), ::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n - ∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im + d = dot(Ω, Δx) + ∂Ω = Ω * ((d - real(d)) / n) return Ω, ∂Ω end - function rrule(::typeof(sign), x) + function rrule(::typeof(sign), x::Number) n = ifelse(iszero(x), one(real(x)), abs(x)) Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) - ∂x = Ω * (_imagconjtimes(Ω, ΔΩ) / n) * im + d = dot(Ω, ΔΩ) + ∂x = Ω * ((d - real(d)) / n) return (NoTangent(), ∂x) end return Ω, sign_pullback end - # product rule requires special care for arguments where `mul` is non-commutative function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) - # Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more + # Optimized version of `Δx * y + x * Δy`. Also, it is potentially more # accurate on machines with FMA instructions, since there are only two # rounding operations, one in `muladd/fma` and the other in `*`. ∂xy = muladd(Δx, y, x * Δy) @@ -250,9 +277,9 @@ let ΔΩ = unthunk(Ω̇) return ( NoTangent(), - ProjectTo(x)(ΔΩ * y' * z'), + ProjectTo(x)(ΔΩ * z' * y'), ProjectTo(y)(x' * ΔΩ * z'), - ProjectTo(z)(x' * y' * ΔΩ), + ProjectTo(z)(y' * x' * ΔΩ), ) end return x * y * z, times_pullback3 diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 13d77d302..2bdc56c9f 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -34,9 +34,9 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N z = adjoint(x) * Ay function dot_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - dx = @thunk project_x(conj(ΔΩ) .* Ay) - dA = @thunk project_A(ΔΩ .* x .* adjoint(y)) - dy = @thunk project_y(ΔΩ .* (adjoint(A) * x)) + dx = @thunk project_x(Ay .* conj(ΔΩ)) + dA = @thunk project_A(x .* ΔΩ .* adjoint(y)) + dy = @thunk project_y((adjoint(A) * x) .* ΔΩ) return (NoTangent(), dx, dA, dy) end dot_pullback(::ZeroTangent) = (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent()) @@ -99,15 +99,14 @@ function frule((_, Δa, Δb), ::typeof(cross), a::AbstractVector, b::AbstractVec return cross(a, b), cross(Δa, b) .+ cross(a, Δb) end -# TODO: support complex vectors -function rrule(::typeof(cross), a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) +function rrule(::typeof(cross), a::AbstractVector, b::AbstractVector) project_a = ProjectTo(a) project_b = ProjectTo(b) Ω = cross(a, b) function cross_pullback(Ω̄) ΔΩ = unthunk(Ω̄) - da = @thunk(project_a(cross(b, ΔΩ))) - db = @thunk(project_b(cross(ΔΩ, a))) + da = @thunk(project_a(eltype(b) <: Real ? cross(b, ΔΩ) : -cross(ΔΩ, vec(b')))) + db = @thunk(project_b(eltype(a) <: Real ? cross(ΔΩ, a) : -cross(vec(a'), ΔΩ))) return (NoTangent(), da, db) end return Ω, cross_pullback diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 1a5a4bcd0..b084ba46e 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -38,7 +38,7 @@ end _diagview(x::Diagonal) = x.diag _diagview(x::AbstractMatrix) = view(x, diagind(x)) _diagview(x::Tangent{<:Diagonal}) = x.diag -function ChainRulesCore.rrule(::typeof(sqrt), d::Diagonal) +function ChainRulesCore.rrule(::typeof(sqrt), d::Diagonal{<:CommutativeMulNumber}) y = sqrt(d) @assert y isa Diagonal function sqrt_pullback(Δ) @@ -99,13 +99,13 @@ function _diagm_back(p, ȳ) return Tangent{typeof(p)}(second = d) end -function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) +function rrule(::typeof(*), D::Diagonal{<:Number}, V::AbstractVector{<:Number}) project_D = ProjectTo(D) project_V = ProjectTo(V) function times_pullback(ȳ) Ȳ = unthunk(ȳ) - dD = @thunk(project_D(Diagonal(Ȳ .* V))) - dV = @thunk(project_V(D * Ȳ)) + dD = @thunk(project_D(Diagonal(Ȳ .* conj.(V)))) + dV = @thunk(project_V(conj.(D.diag) .* Ȳ)) return (NoTangent(), dD, dV) end return D * V, times_pullback @@ -250,7 +250,7 @@ end _diag_view(X) = view(X, diagind(X)) _diag_view(X::Diagonal) = parent(X) #Diagonal wraps a Vector of just Diagonal elements -function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) +function rrule(::typeof(det), X::Union{Diagonal{T}, AbstractTriangular{T}}) where {T<:CommutativeMulNumber} y = det(X) s = conj!(y ./ _diag_view(X)) function det_pullback(ȳ) @@ -259,7 +259,7 @@ function rrule(::typeof(det), X::Union{Diagonal, AbstractTriangular}) return y, det_pullback end -function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) +function rrule(::typeof(logdet), X::Union{Diagonal{T}, AbstractTriangular{T}}) where {T<:CommutativeMulNumber} y = logdet(X) s = conj!(one(eltype(X)) ./ _diag_view(X)) function logdet_pullback(ȳ) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 77dff1827..ccbfa5d37 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -73,10 +73,12 @@ end end - @testset "Unary complex functions" begin - for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) + @testset "Unary non-real functions" begin + for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im, Quaternion(0.0,0,0,0), Quaternion(0.6,1,-0.2,0.5)) test_scalar(real, x) - test_scalar(imag, x) + if !isa(x, Quaternion) + test_scalar(imag, x) + end test_scalar(hypot, x) test_scalar(adjoint, x) end @@ -92,9 +94,12 @@ @testset "*(x, y) (scalar)" begin # This is pretty important so testing it fairly heavily - test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im) + test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im, Quaternion(-0.48, -0.15, 0.63, 0.14)) @testset "($x) * ($y)" for x in test_points, y in test_points + if (x isa Complex && y isa Quaternion) || (x isa Quaternion && y isa Complex) + continue + end # ensure all complex if any complex for FiniteDifferences x, y = Base.promote(x, y) @@ -103,6 +108,10 @@ test_rrule(*, x, y) end @testset "*($x, $y, ...)" for x in test_points, y in test_points + if (x isa Complex && y isa Quaternion) || (x isa Quaternion && y isa Complex) + continue + end + # This promotion is only for FiniteDifferences, the rules allow mixtures: x, y = Base.promote(x, y) @@ -119,7 +128,7 @@ end end - @testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64) + @testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64, QuaternionF64) test_frule(*, randn(T), randn(T)) test_rrule(*, randn(T), randn(T)) end @@ -132,7 +141,7 @@ test_rrule(mod, (rand(0:10) + .6rand() + .2) * base, base) end - @testset "identity" for T in (Float64, ComplexF64) + @testset "identity" for T in (Float64, ComplexF64, QuaternionF64) test_frule(identity, randn(T)) test_frule(identity, randn(T, 4)) test_frule(identity, Tuple(randn(T, 3))) @@ -142,14 +151,14 @@ test_rrule(identity, Tuple(randn(T, 3))) end - @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im) + @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im, Quaternion(2.0,0,3,0)) test_scalar(one, x) test_scalar(zero, x) end - @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64) - test_frule(muladd, 10randn(), randn(), randn()) - test_rrule(muladd, 10randn(), randn(), randn()) + @testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64, QuaternionF64) + test_frule(muladd, 10randn(T), randn(T), randn(T)) + test_rrule(muladd, 10randn(T), randn(T), randn(T)) end @testset "fma" begin diff --git a/test/test_helpers.jl b/test/test_helpers.jl index 83de7b35d..0261165bf 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -75,6 +75,63 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x) return make_two_vec(x), make_two_vec_pullback end +# Minimal quaternion implementation for testing rules that accept numbers that don't commute +# under multiplication +# adapted from Base Julia +# https://github.com/JuliaLang/julia/blob/bb5b98e72a151c41471d8cc14cacb495d647fb7f/test/testhelpers/Quaternions.jl + +export Quaternion + +struct Quaternion{T<:Real} <: Number + s::T + v1::T + v2::T + v3::T +end +const QuaternionF64 = Quaternion{Float64} +Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...) +Quaternion{T}(s::Real) where {T} = Quaternion(T(s), zero(T), zero(T), zero(T)) +Base.convert(::Type{Quaternion{T}}, s::Real) where {T <: Real} = + Quaternion{T}(convert(T, s), zero(T), zero(T), zero(T)) +Base.promote_rule(::Type{Quaternion{S}}, ::Type{T}) where {S<:Real,T<:Real} = Quaternion{Base.promote_type(S,T)} +Base.abs2(q::Quaternion) = q.s*q.s + q.v1*q.v1 + q.v2*q.v2 + q.v3*q.v3 +Base.float(z::Quaternion{T}) where T = Quaternion(float(z.s), float(z.v1), float(z.v2), float(z.v3)) +Base.abs(q::Quaternion) = sqrt(abs2(q)) +Base.real(::Type{Quaternion{T}}) where {T} = T +Base.real(q::Quaternion) = q.s +Base.conj(q::Quaternion) = Quaternion(q.s, -q.v1, -q.v2, -q.v3) +Base.isfinite(q::Quaternion) = isfinite(q.s) & isfinite(q.v1) & isfinite(q.v2) & isfinite(q.v3) +Base.zero(::Type{Quaternion{T}}) where T = Quaternion{T}(zero(T), zero(T), zero(T), zero(T)) +Base.:(+)(ql::Quaternion, qr::Quaternion) = + Quaternion(ql.s + qr.s, ql.v1 + qr.v1, ql.v2 + qr.v2, ql.v3 + qr.v3) +Base.:(-)(ql::Quaternion, qr::Quaternion) = + Quaternion(ql.s - qr.s, ql.v1 - qr.v1, ql.v2 - qr.v2, ql.v3 - qr.v3) +Base.:(-)(q::Quaternion) = Quaternion(-q.s, -q.v1, -q.v2, -q.v3) +Base.:(*)(q::Quaternion, w::Quaternion) = Quaternion(q.s*w.s - q.v1*w.v1 - q.v2*w.v2 - q.v3*w.v3, + q.s*w.v1 + q.v1*w.s + q.v2*w.v3 - q.v3*w.v2, + q.s*w.v2 - q.v1*w.v3 + q.v2*w.s + q.v3*w.v1, + q.s*w.v3 + q.v1*w.v2 - q.v2*w.v1 + q.v3*w.s) +Base.:(*)(q::Quaternion, r::Real) = Quaternion(q.s*r, q.v1*r, q.v2*r, q.v3*r) +Base.:(*)(q::Quaternion, b::Bool) = b * q # remove method ambiguity +Base.:(/)(q::Quaternion, w::Quaternion) = q * conj(w) * (1.0 / abs2(w)) +Base.:(\)(q::Quaternion, w::Quaternion) = conj(q) * w * (1.0 / abs2(q)) +# rand implementations adapted from https://github.com/JuliaGeometry/Quaternions.jl/pull/42/files +function Base.rand(rng::AbstractRNG, ::Random.SamplerType{Quaternion{T}}) where {T<:Real} + return Quaternion{T}(rand(rng, T), rand(rng, T), rand(rng, T), rand(rng, T)) +end +function Base.randn(rng::AbstractRNG, ::Type{Quaternion{T}}) where {T<:AbstractFloat} + return Quaternion{T}( + randn(rng, T) / 2, + randn(rng, T) / 2, + randn(rng, T) / 2, + randn(rng, T) / 2, + ) +end +(project::ProjectTo{<:Real})(dx::Quaternion) = project(real(dx)) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, q::Quaternion{Float64}) + return Quaternion(rand(rng, -9:0.1:9), rand(rng, -9:0.1:9), rand(rng, -9:0.1:9), rand(rng, -9:0.1:9)) +end + @testset "test_helpers.jl" begin @testset "Multiplier" begin