From 16a604d6d9591b7e1eec4effcc3ce8c36123b524 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Feb 2024 21:48:38 +0800 Subject: [PATCH 1/3] merge cumprod rules --- src/rulesets/Base/mapreduce.jl | 43 +++++++++++----------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index f56ffa607..3972c3d21 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -325,52 +325,37 @@ end ##### `cumprod` ##### -function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) - y = cumprod(x; dims=dims) # does nothing unless dims == 1 - project_x = ProjectTo(x) - function cumprod_pullback_1(dy_raw) - dy = unthunk(dy_raw) - dx_thunk = InplaceableThunk( - dx -> if dims == 1 - ∇cumprod!(dx, x, dy, y) - else - dx .+= dy - end - , - @thunk project_x(if dims == 1 - ∇cumprod(x, dy, y) - else - dy - end) - ) - return (NoTangent(), dx_thunk) - end - return y, cumprod_pullback_1 -end - function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) y = cumprod(x; dims=dims) project_x = ProjectTo(x) - function cumprod_pullback_2(dy_raw) + function cumprod_pullback(dy_raw) dy = unthunk(dy_raw) dx_thunk = InplaceableThunk( dx -> if dims <= ndims(x) - vald = Val(Int(dims)) - ∇cumprod_dim!(dx, vald, x, dy, y) + if ndims(x) == 1 + ∇cumprod!(dx, x, dy, y) + else + vald = Val(Int(dims)) + ∇cumprod_dim!(dx, vald, x, dy, y) + end else dx .+= dy end , @thunk project_x(if dims <= ndims(x) - vald = Val(Int(dims)) - ∇cumprod_dim(vald, x, dy, y) + if ndims(x) == 1 + ∇cumprod(x, dy, y) + else + vald = Val(Int(dims)) + ∇cumprod_dim(vald, x, dy, y) + end else dy end) ) return (NoTangent(), dx_thunk) end - return y, cumprod_pullback_2 + return y, cumprod_pullback end function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) where {dim} From a1a70cce52179661297ed69590537ec81ac93ddd Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Feb 2024 23:03:13 +0800 Subject: [PATCH 2/3] Make default dims=1 --- src/rulesets/Base/mapreduce.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 3972c3d21..b1c69838f 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -325,7 +325,7 @@ end ##### `cumprod` ##### -function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) +function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer=1) y = cumprod(x; dims=dims) project_x = ProjectTo(x) function cumprod_pullback(dy_raw) From b1d7b46b87a24ed2d9aaabeda747aedcd5227807 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Feb 2024 23:34:58 +0800 Subject: [PATCH 3/3] make sure we still error if the primal code errored --- src/rulesets/Base/mapreduce.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index b1c69838f..dd2306f59 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -325,7 +325,13 @@ end ##### `cumprod` ##### -function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer=1) +function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer=nothing) + if isnothing(dims) && ndims(x)==1 + dims=1 + else + throw(UndefKeywordError(:dims) + end + y = cumprod(x; dims=dims) project_x = ProjectTo(x) function cumprod_pullback(dy_raw)