Regarding the Differences in MatMul Matrix Multiplication Behavior #44

Open
owenliang opened this issue Dec 21, 2024 · 0 comments
In the implementation of MatMul, the dot call does not match the behavior of matrix multiplication in Torch; in particular, it is not suitable for inputs with more than two dimensions. It should be replaced with np.matmul() or cupy.matmul(). Moreover, the backward pass then has to follow matmul()'s convention of multiplying over the last two axes and summing out any batch axes introduced by broadcasting. Below is the approach I took in my reimplementation; you can find my code at: this link.

dezero code:

class MatMul(Function):
    def forward(self, x, W):
        y = x.dot(W)
        return y

    def backward(self, gy):
        x, W = self.inputs
        gx = matmul(gy, W.T)
        gW = matmul(x.T, gy)
        return gx, gW
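For reference, the difference between dot and matmul is easy to see in plain NumPy (a minimal sketch; the shapes below are illustrative, not from dezero):

import numpy as np

x = np.random.rand(4, 2, 3)   # batched input (N, A, B)
W = np.random.rand(3, 5)      # weight matrix (B, C)

# np.matmul broadcasts over the leading batch axis: (N, A, C)
print(np.matmul(x, W).shape)   # (4, 2, 5)

# np.dot sums over the last axis of x and the second-to-last axis of W.
# For 3-D x and 2-D W the result happens to coincide, but once W also has a
# batch axis the semantics diverge from torch.matmul:
W_batched = np.random.rand(4, 3, 5)
print(np.matmul(x, W_batched).shape)  # (4, 2, 5) -- batched matrix product
print(np.dot(x, W_batched).shape)     # (4, 2, 4, 5) -- outer product over batch axes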

my code:

# Matrix multiply
class MatMul(Function):
    def _forward(self, a, b):  # (N,A,B) @ (B,C) = (N,A,C)
        xp = get_array_module(a)  # CUDA compatibility (NumPy or CuPy)
        return xp.matmul(a, b)

    def _backward(self, grad):
        # Gradient w.r.t. a: swap the last two axes of b, then matmul.
        transpose_idx = list(range(len(self.inputs[1].shape)))
        transpose_idx[-1], transpose_idx[-2] = transpose_idx[-2], transpose_idx[-1]
        grad_a = MatMul()(grad, self.inputs[1].transpose(transpose_idx))  # (N,A,C) @ (C,B) = (N,A,B)
        if len(self.inputs[0].shape) != len(grad_a.shape):
            # Sum out the broadcast batch axes so grad_a matches a's shape.
            grad_a = Sum(axes=tuple(range(len(grad_a.shape) - len(self.inputs[0].shape))), keepdims=False)(grad_a)

        # Gradient w.r.t. b: swap the last two axes of a, then matmul.
        transpose_idx = list(range(len(self.inputs[0].shape)))
        transpose_idx[-1], transpose_idx[-2] = transpose_idx[-2], transpose_idx[-1]
        grad_b = MatMul()(self.inputs[0].transpose(transpose_idx), grad)  # (N,B,A) @ (N,A,C) = (N,B,C) -> Sum() -> (B,C)
        if len(self.inputs[1].shape) != len(grad_b.shape):
            # Sum out the broadcast batch axes so grad_b matches b's shape.
            grad_b = Sum(axes=tuple(range(len(grad_b.shape) - len(self.inputs[1].shape))), keepdims=False)(grad_b)

        return grad_a, grad_b
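As a framework-independent sanity check of the backward formulas, the same reduction over the batch axis can be verified in plain NumPy against a finite difference (a minimal sketch; the shapes and the loss = y.sum() choice are only illustrative):

import numpy as np

N, A, B, C = 4, 2, 3, 5
x = np.random.rand(N, A, B)
W = np.random.rand(B, C)
gy = np.ones((N, A, C))  # upstream gradient for loss = y.sum()

# forward: y = x @ W, broadcast over the batch axis
y = np.matmul(x, W)

# backward: swap only the last two axes, then sum out the broadcast batch axis
gx = np.matmul(gy, W.T)                               # (N, A, B)
gW = np.matmul(x.transpose(0, 2, 1), gy).sum(axis=0)  # (B, C)

# finite-difference check on one entry of W
eps = 1e-6
W2 = W.copy()
W2[0, 0] += eps
numeric = (np.matmul(x, W2).sum() - y.sum()) / eps
print(np.isclose(numeric, gW[0, 0], atol=1e-3))  # True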