Skip to content

Commit

Permalink
Merge pull request #19 from ToucheSir/bc/fix-rank
Browse files Browse the repository at this point in the history
Report correct rank for matmul and aggregrations
  • Loading branch information
DhairyaLGandhi authored May 12, 2020
2 parents eb8f25c + 7149c0e commit 72c0dde
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ import Base: +, -, *, /
for (op,fn) in zip((:+, :-, :/, :*), (atg_add, atg_sub, atg_div, atg_matmul))
@eval function $op(t1::Tensor{T,N}, t2::Tensor{T,K}) where {T,N,K}
ptr = Ref(Ptr{Cvoid}())
rank = Ref{Cint}(-1)

$fn(ptr, t1.ptr, t2.ptr)
Tensor{T,N}(ptr[], on(t1))
at_dim(rank, ptr[])
Tensor{T,rank[]}(ptr[], on(t1))
end
end

Expand Down
8 changes: 4 additions & 4 deletions src/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@ function Statistics.mean(t::Tensor{T,N}; dims = :) where {T,N}

if dims isa Colon
atg_mean(ptr, t.ptr, options[T])
Tensor{T,0}(ptr[], on(t))
else
atg_mean1(ptr, t.ptr, dims, length(dims), dims[1], options[T])
Tensor{T,N-length(dims)}(ptr[], on(t))
end

Tensor{T,N}(ptr[], on(t))
end

function Statistics.sum(t::Tensor{T,N}; dims = :) where {T,N}
ptr = Ref(Ptr{Cvoid}())

if dims isa Colon
atg_sum(ptr, t.ptr, options[T])
Tensor{T,0}(ptr[], on(t))
else
atg_sum1(ptr, t.ptr, dims, length(dims), dims[1], options[T])
Tensor{T,N-length(dims)}(ptr[], on(t))
end

Tensor{T,N}(ptr[], on(t))
end

0 comments on commit 72c0dde

Please sign in to comment.