-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove ambig in nondiff cumprod(::Vector{AbstractBool}) on julia 1.6 #781
Conversation
it fixed it., Yay. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me - the only question I have is whether you considered merging the rrule
s in
ChainRules.jl/src/rulesets/Base/mapreduce.jl
Lines 328 to 374 in b3f9f9b
function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1) | |
y = cumprod(x; dims=dims) # does nothing unless dims == 1 | |
project_x = ProjectTo(x) | |
function cumprod_pullback_1(dy_raw) | |
dy = unthunk(dy_raw) | |
dx_thunk = InplaceableThunk( | |
dx -> if dims == 1 | |
∇cumprod!(dx, x, dy, y) | |
else | |
dx .+= dy | |
end | |
, | |
@thunk project_x(if dims == 1 | |
∇cumprod(x, dy, y) | |
else | |
dy | |
end) | |
) | |
return (NoTangent(), dx_thunk) | |
end | |
return y, cumprod_pullback_1 | |
end | |
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) | |
y = cumprod(x; dims=dims) | |
project_x = ProjectTo(x) | |
function cumprod_pullback_2(dy_raw) | |
dy = unthunk(dy_raw) | |
dx_thunk = InplaceableThunk( | |
dx -> if dims <= ndims(x) | |
vald = Val(Int(dims)) | |
∇cumprod_dim!(dx, vald, x, dy, y) | |
else | |
dx .+= dy | |
end | |
, | |
@thunk project_x(if dims <= ndims(x) | |
vald = Val(Int(dims)) | |
∇cumprod_dim(vald, x, dy, y) | |
else | |
dy | |
end) | |
) | |
return (NoTangent(), dx_thunk) | |
end | |
return y, cumprod_pullback_2 | |
end |
Merging this over #784 as it is passing CI already. |
This doesn't reproduce locally for me.
And doesn't show up on never julia versions.
Possibly dispatch changed subtly to prevent this.
But lets see if it fixes it on CI