From 7149c0e8bb91cf674b33d313edcbcb93d48c286d Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Tue, 12 May 2020 00:53:01 -0400 Subject: [PATCH] Report correct rank for matmul and aggregrations --- src/ops.jl | 4 +++- src/statistics.jl | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ops.jl b/src/ops.jl index cd24052..7a759a8 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -3,9 +3,11 @@ import Base: +, -, *, / for (op,fn) in zip((:+, :-, :/, :*), (atg_add, atg_sub, atg_div, atg_matmul)) @eval function $op(t1::Tensor{T,N}, t2::Tensor{T,K}) where {T,N,K} ptr = Ref(Ptr{Cvoid}()) + rank = Ref{Cint}(-1) $fn(ptr, t1.ptr, t2.ptr) - Tensor{T,N}(ptr[], on(t1)) + at_dim(rank, ptr[]) + Tensor{T,rank[]}(ptr[], on(t1)) end end diff --git a/src/statistics.jl b/src/statistics.jl index 313b6ac..80b653a 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -5,11 +5,11 @@ function Statistics.mean(t::Tensor{T,N}; dims = :) where {T,N} if dims isa Colon atg_mean(ptr, t.ptr, options[T]) + Tensor{T,0}(ptr[], on(t)) else atg_mean1(ptr, t.ptr, dims, length(dims), dims[1], options[T]) + Tensor{T,N-length(dims)}(ptr[], on(t)) end - - Tensor{T,N}(ptr[], on(t)) end function Statistics.sum(t::Tensor{T,N}; dims = :) where {T,N} @@ -17,9 +17,9 @@ function Statistics.sum(t::Tensor{T,N}; dims = :) where {T,N} if dims isa Colon atg_sum(ptr, t.ptr, options[T]) + Tensor{T,0}(ptr[], on(t)) else atg_sum1(ptr, t.ptr, dims, length(dims), dims[1], options[T]) + Tensor{T,N-length(dims)}(ptr[], on(t)) end - - Tensor{T,N}(ptr[], on(t)) end