Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
oxinabox and github-actions[bot] authored May 21, 2024
1 parent a05ebb2 commit d40e282
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
8 changes: 2 additions & 6 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,24 +292,20 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage
return y, task_local_storage_pullback
end


####
#### merge
####

function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2}
function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1,F2}
y = merge(nt1, nt2)
function merge_pullback(dy)
dnt1 = Tangent{typeof(nt1)}(;
(f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)...
)
dnt2 = Tangent{typeof(nt2)}(;
(f2 => getproperty(dy, f2) for f2 in F2)...
)
dnt2 = Tangent{typeof(nt2)}(; (f2 => getproperty(dy, f2) for f2 in F2)...)
return (NoTangent(), dnt1, dnt2)
end
merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy))
merge_pullback(x::AbstractZero) = (NoTangent(), x, x)

return y, merge_pullback
end
4 changes: 2 additions & 2 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ end
end

@testset "merge NamedTuple" begin
test_rrule(merge, (;a=1.0), (;b=2.0), check_inferred=false)
test_rrule(merge, (;a=1.0), (;a=2.0), check_inferred=false)
test_rrule(merge, (; a=1.0), (; b=2.0); check_inferred=false)
test_rrule(merge, (; a=1.0), (; a=2.0); check_inferred=false)
end
end

0 comments on commit d40e282

Please sign in to comment.