Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Add rules for common sum
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 29, 2024
1 parent fa4a2f5 commit 68b226a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,32 @@ function CRC.rrule(::typeof(getproperty), op::UniformBlockDiagonalOperator, x::S
∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalOperator(Δ))
return op.data, ∇getproperty
end

# mapreduce fallback rules for UniformBlockDiagonalOperator
@inline _unsum(x, dy, dims) = broadcast(last tuple, x, dy)
@inline _unsum(x, dy, ::Colon) = broadcast(last tuple, x, Ref(dy))

function CRC.rrule(::typeof(sum), ::typeof(abs2), op::UniformBlockDiagonalOperator{T};
dims=:) where {T <: Union{Real, Complex}}
y = sum(abs2, op; dims)
∇sum_abs2 = @closure Δ -> begin
∂op = if dims isa Colon
UniformBlockDiagonalOperator(2 .* real.(Δ) .* getdata(op))
else
UniformBlockDiagonalOperator(2 .* real.(getdata(Δ)) .* getdata(op))
end
return NoTangent(), NoTangent(), ∂op
end
return y, ∇sum_abs2
end

function CRC.rrule(::typeof(sum), ::typeof(identity), op::UniformBlockDiagonalOperator{T};
dims=:) where {T <: Union{Real, Complex}}
y = sum(abs2, op; dims)
project = CRC.ProjectTo(getdata(op))
∇sum_abs2 = @closure Δ -> begin
∂op = project(_unsum(getdata(op), getdata(Δ), dims))
return NoTangent(), NoTangent(), UniformBlockDiagonalOperator(∂op)
end
return y, ∇sum_abs2
end
1 change: 1 addition & 0 deletions src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ SciMLOperators.isconvertible(::UniformBlockDiagonalOperator) = false

# BatchedRoutines API
getdata(op::UniformBlockDiagonalOperator) = op.data
getdata(x) = x
nbatches(op::UniformBlockDiagonalOperator) = size(op.data, 3)
batchview(op::UniformBlockDiagonalOperator) = batchview(op.data)
batchview(op::UniformBlockDiagonalOperator, i::Int) = batchview(op.data, i)
Expand Down

0 comments on commit 68b226a

Please sign in to comment.