From 2be04f321a9b1801daf1e731bba953f74aee6b51 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 24 Feb 2022 09:41:08 +0100 Subject: [PATCH] remove cast rrule not needed anymore --- src/utils.jl | 14 -------------- test/runtests.jl | 4 ++-- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 93314dfab..0c582eefa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -99,17 +99,3 @@ function broadcast_edges(g::GNNGraph, x) gi = graph_indicator(g, edges=true) return gather(x, gi) end - -# More generic version of -# https://github.com/JuliaDiff/ChainRules.jl/pull/586 -# This applies to all arrays -# Withouth this, gradient of T.(A) for A dense gpu matrix errors. -function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractArray) - proj = ProjectTo(x) - - function broadcasted_cast(Δ) - return NoTangent(), NoTangent(), proj(Δ) - end - - return T.(x), broadcasted_cast -end diff --git a/test/runtests.jl b/test/runtests.jl index 3ce648cc4..41c60fee8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,10 +41,10 @@ tests = [ !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") -@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:dense, :coo, :sparse) +@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) global GRAPH_T = graph_type global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse) - + for t in tests startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI include("$t.jl")