In the implementation of MatMul, the dot used in forward does not match the behavior of matrix multiplication in Torch; in particular, it does not handle inputs with more than two dimensions correctly. It should be replaced with np.matmul() or cupy.matmul(). The backward pass must also be adapted to the way matmul() multiplies over the last two axes and broadcasts the leading ones. Below is the approach I took in my own reimplementation; you can find my code at: this link.
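For reference, a minimal numpy-only snippet (my own illustration, with arbitrary shapes) showing how dot and matmul diverge once the inputs have more than two dimensions:

import numpy as np

x = np.random.randn(4, 3, 5)   # batch of 4 matrices, each (3, 5)
W = np.random.randn(4, 5, 2)   # batch of 4 matrices, each (5, 2)

print(np.matmul(x, W).shape)   # (4, 3, 2): batched matrix product, same as torch.matmul
print(np.dot(x, W).shape)      # (4, 3, 4, 2): dot contracts the last axis of x with the
                               # second-to-last axis of W and stacks all remaining axes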
dezero code:
class MatMul(Function):
    def forward(self, x, W):
        y = x.dot(W)
        return y

    def backward(self, gy):
        x, W = self.inputs
        gx = matmul(gy, W.T)
        gW = matmul(x.T, gy)
        return gx, gW
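One more detail worth noting (again just a plain-numpy illustration, outside DeZero): for arrays with more than two dimensions, .T reverses all axes, so the W.T and x.T in the backward above are not the per-matrix transposes that the matmul gradient needs; only the last two axes should be swapped.

import numpy as np

x = np.random.randn(4, 3, 5)
print(x.T.shape)                     # (5, 3, 4): .T reverses every axis
print(np.swapaxes(x, -1, -2).shape)  # (4, 5, 3): swaps only the last two axes,
                                     # which is what the batched matmul gradient needs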
my code:
# Matrix Multiply
class MatMul(Function):
    def _forward(self, a, b):  # (N,A,B)@(B,C)=(N,A,C)
        xp = get_array_module(a)  # CUDA compatibility: pick numpy or cupy
        return xp.matmul(a, b)

    def _backward(self, grad):
        # Gradient w.r.t. a: grad @ b^T, transposing only the last two axes of b
        transpose_idx = list(range(len(self.inputs[1].shape)))
        transpose_idx[-1], transpose_idx[-2] = transpose_idx[-2], transpose_idx[-1]
        grad_a = MatMul()(grad, self.inputs[1].transpose(transpose_idx))  # (N,A,C)@(C,B)=(N,A,B)
        if len(self.inputs[0].shape) != len(grad_a.shape):
            # sum out the broadcast batch axes so grad_a matches a's shape
            grad_a = Sum(axes=tuple(range(len(grad_a.shape) - len(self.inputs[0].shape))), keepdims=False)(grad_a)
        # Gradient w.r.t. b: a^T @ grad, again transposing only the last two axes of a
        transpose_idx = list(range(len(self.inputs[0].shape)))
        transpose_idx[-1], transpose_idx[-2] = transpose_idx[-2], transpose_idx[-1]
        grad_b = MatMul()(self.inputs[0].transpose(transpose_idx), grad)  # (N,B,A)@(N,A,C)=(N,B,C) -> Sum() -> (B,C)
        if len(self.inputs[1].shape) != len(grad_b.shape):
            # sum out the broadcast batch axes so grad_b matches b's shape
            grad_b = Sum(axes=tuple(range(len(grad_b.shape) - len(self.inputs[1].shape))), keepdims=False)(grad_b)
        return grad_a, grad_b
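To sanity-check the shapes produced by the reductions in _backward, here is a small standalone sketch in plain numpy (it mirrors the logic above but does not go through the Function/Sum machinery):

import numpy as np

x = np.random.randn(4, 3, 5)                 # (N, A, B)
W = np.random.randn(5, 2)                    # (B, C)
gy = np.ones((4, 3, 2))                      # upstream gradient, same shape as matmul(x, W)

gx = np.matmul(gy, np.swapaxes(W, -1, -2))   # (N, A, C) @ (C, B) = (N, A, B)
gW = np.matmul(np.swapaxes(x, -1, -2), gy)   # (N, B, A) @ (N, A, C) = (N, B, C)
gW = gW.sum(axis=0)                          # sum over the broadcast batch axis -> (B, C)

print(gx.shape, gW.shape)                    # (4, 3, 5) (5, 2)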