diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 56ab12018..c1aac5d6e 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -944,14 +944,16 @@ function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
     return CGConv(ch, dense_f, dense_s, residual)
 end
 
-function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
+function (l::CGConv)(g::AbstractGNNGraph, x,
                      e::Union{Nothing, AbstractMatrix} = nothing)
     check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
+    
     if e !== nothing
         check_num_edges(g, e)
     end
 
-    m = propagate(message, g, +, l, xi = x, xj = x, e = e)
+    m = propagate(message, g, +, l, xi = xi, xj = xj, e = e)
 
     if l.residual
         if size(x, 1) == size(m, 1)
@@ -964,6 +966,7 @@ function (l::CGConv)(g::GNNGraph, x::AbstractMatrix,
     return m
 end
 
+
 function message(l::CGConv, xi, xj, e)
     if e !== nothing
         z = vcat(xi, xj, e)
diff --git a/src/msgpass.jl b/src/msgpass.jl
index 232aff36c..2118cdac6 100644
--- a/src/msgpass.jl
+++ b/src/msgpass.jl
@@ -87,11 +87,11 @@ end
 # https://github.com/JuliaLang/julia/issues/15276
 ## and zygote issues
 # https://github.com/FluxML/Zygote.jl/issues/1317
-function propagate(f, g::GNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
+function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing,
                    e = nothing)
     propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
 end
-function propagate(f, g::GNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
+function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing)
     propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e)
 end
 
diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl
index ff1203497..6cbd8563f 100644
--- a/test/layers/heteroconv.jl
+++ b/test/layers/heteroconv.jl
@@ -91,4 +91,13 @@
         output = y2.A[:, [2]]
         @test expected ≈ output
     end
+
+    @testset "CGConv" begin
+        g = rand_bipartite_heterograph((2,3), 6)
+        x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
+        layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu),
+                                    (:B, :to, :A) => CGConv(4 => 2, relu));
+        y = layers(g, x); 
+        @test size(y.A) == (2,2) && size(y.B) == (2,3)
+    end
 end