diff --git a/src/ops.jl b/src/ops.jl index 7a759a8..b49391b 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -6,8 +6,11 @@ for (op,fn) in zip((:+, :-, :/, :*), (atg_add, atg_sub, atg_div, atg_matmul)) rank = Ref{Cint}(-1) $fn(ptr, t1.ptr, t2.ptr) - at_dim(rank, ptr[]) - Tensor{T,rank[]}(ptr[], on(t1)) + + # TODO: using `rank` here causes compiler to emit error + # make shape checking more robust + # at_dim(rank, ptr[]) + Tensor{T,max(N,K)}(ptr[], on(t1)) end end