diff --git a/Project.toml b/Project.toml index f5273df1d..b46577795 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.2" +version = "0.10.3" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index efaebb0c7..0f54d8d23 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -13,6 +13,10 @@ end return element, (val, new_state) end +Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b) +Base.:(==)(a::AbstractThunk, b) = unthunk(a) == b +Base.:(==)(a, b::AbstractThunk) = a == unthunk(b) + """ @thunk expr diff --git a/test/differentials/thunks.jl b/test/differentials/thunks.jl index 2fb532fb0..8cd83a92f 100644 --- a/test/differentials/thunks.jl +++ b/test/differentials/thunks.jl @@ -1,6 +1,12 @@ @testset "Thunk" begin @test @thunk(3) isa Thunk + @testset "==" begin + @test @thunk(3.2) == InplaceableThunk(@thunk(3.2), x -> x + 3.2) + @test @thunk(3.2) == 3.2 + @test 3.2 == InplaceableThunk(@thunk(3.2), x -> x + 3.2) + end + @testset "show" begin rep = repr(Thunk(rand)) @test occursin(r"Thunk\(.*rand.*\)", rep)