diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs
index 3433cc2..37f3350 100644
--- a/src/nn/transformer/mha.rs
+++ b/src/nn/transformer/mha.rs
@@ -95,7 +95,7 @@ where
     E: Dtype + Float,
     D: Device<E>,
     S1: Dim,
-    S2: Dim,
+    S2: Dim + AssertDimEq<S2>,
     T: Tape<E, D>,
 {
     type Output = Tensor<(S1, Const<M>), E, D, T>;
@@ -131,7 +131,7 @@ where
         let weights = weights.try_softmax::<Axis<2>>()?;

         // Get new tokens
-        let tokens = weights.try_dynamic_matmul(v)?;
+        let tokens = weights.try_matmul(v)?;
         let tokens = tokens.try_permute::<_, Axes3<1, 0, 2>>()?;
         let tokens = tokens.try_reshape_like(&(s1, Const::<V>)).unwrap()?;

@@ -150,7 +150,7 @@ where
     D: Device<E>,
     B: Dim,
     S1: Dim,
-    S2: Dim,
+    S2: Dim + AssertDimEq<S2>,
     T: Tape<E, D>,
 {
     type Output = Tensor<(B, S1, Const<M>), E, D, T>;
@@ -174,24 +174,27 @@ where
         let s2 = v.shape.1;

         let v = self.w_v.try_forward(v.retaped::<T>())?;
-        let v = v.try_reshape_like(&(b, s2, H, V / H)).unwrap()?;
-        let v = v.try_permute::<_, Axes4<0, 2, 1, 3>>()?;
+        let v = v.try_reshape_like(&(b, s2, Const::<H>, V / H)).unwrap()?;
+        let v: Tensor<(B, Const<H>, S2, usize), E, D, T> =
+            v.try_permute::<_, Axes4<0, 2, 1, 3>>()?;

         let k = self.w_k.try_forward(k.retaped::<T>())?;
-        let k = k.try_reshape_like(&(b, s2, H, K / H)).unwrap()?;
+        let k = k.try_reshape_like(&(b, s2, Const::<H>, K / H)).unwrap()?;
         let k = k.try_permute::<_, Axes4<0, 2, 3, 1>>()?;

         let q = self.w_q.try_forward(q)?;
-        let q = q.try_reshape_like(&(b, s1, H, K / H)).unwrap()?;
+        let q = q.try_reshape_like(&(b, s1, Const::<H>, K / H)).unwrap()?;
         let q = q.try_permute::<_, Axes4<0, 2, 1, 3>>()?;

         // Get weights
         let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
-        let weights = q.try_dynamic_matmul(k)?.try_mul(scalar)?;
-        let weights = weights.try_softmax::<Axis<3>>()?;
+        let weights = q.try_matmul(k)?.try_mul(scalar)?;
+        let weights: Tensor<(B, Const<H>, S1, S2), E, D, T> = weights.try_softmax::<Axis<3>>()?;

         // Get new tokens
-        let tokens = weights.try_dynamic_matmul(v)?;
+        // weights (B, Const<H>, S1, S2)
+        // v (B, Const<H>, S2, usize)
+        let tokens = weights.try_matmul(v)?;
         let tokens = tokens.try_permute::<_, Axes4<0, 2, 1, 3>>()?;
         let tokens = tokens.try_reshape_like(&(b, s1, Const::<V>)).unwrap()?;

diff --git a/src/tensor_ops/matmul/mod.rs b/src/tensor_ops/matmul/mod.rs
index 480ed0b..66e809d 100644
--- a/src/tensor_ops/matmul/mod.rs
+++ b/src/tensor_ops/matmul/mod.rs
@@ -123,44 +123,71 @@ fn try_binary_op<
     Ok(out.put_tape(tape))
 }

-pub trait MulStaticDimCheck<Rhs> {
+pub trait AssertDimEq<Rhs> {
     const TYPE_CHECK: ();
-    fn assert_dim_eq(&self);
+    fn assert_dim_eq(&self, rhs: &Rhs);
+}
+
+impl AssertDimEq<usize> for usize {
+    const TYPE_CHECK: () = ();
+    fn assert_dim_eq(&self, rhs: &usize) {
+        assert_eq!(self, rhs);
+    }
+}
+
+impl<const M: usize> AssertDimEq<Const<M>> for usize {
+    const TYPE_CHECK: () = ();
+    fn assert_dim_eq(&self, rhs: &Const<M>) {
+        assert_eq!(*self, M);
+    }
 }

-impl<const L: usize, const R: usize> MulStaticDimCheck<Rank1<R>> for Rank1<L> {
-    const TYPE_CHECK: () = assert!(
-        L == R,
-        "You are trying to multiply vectors whose dimensions don't match."
-    );
-    fn assert_dim_eq(&self) {
-        let _ = <Self as MulStaticDimCheck<Rank1<R>>>::TYPE_CHECK;
+impl<const M: usize> AssertDimEq<usize> for Const<M> {
+    const TYPE_CHECK: () = ();
+    fn assert_dim_eq(&self, rhs: &usize) {
+        assert_eq!(M, *rhs);
     }
 }

-impl<const L: usize, const R: usize, N: Dim> MulStaticDimCheck<(Const<R>, N)> for Const<L> {
-    const TYPE_CHECK: () = assert!(
-        L == R,
-        "You are trying to multiply a vector to a matrix whose row dimension does not match the dimension of the vector."
-    );
-    fn assert_dim_eq(&self) {
-        let _ = <Self as MulStaticDimCheck<(Const<R>, N)>>::TYPE_CHECK;
+impl<const M: usize, const N: usize> AssertDimEq<Const<N>> for Const<M> {
+    const TYPE_CHECK: () = assert!(M == N);
+    fn assert_dim_eq(&self, rhs: &Const<N>) {
+        let _ = <Self as AssertDimEq<Const<N>>>::TYPE_CHECK;
     }
 }

-impl<const L: usize, const R: usize, M: Dim, N: Dim> MulStaticDimCheck<(Const<R>, N)>
-    for (M, Const<L>)
+pub trait MulStaticDimCheck<Rhs> {
+    fn assert_shape_eq(&self, rhs: &Rhs);
+}
+
+impl<L: Dim, R: Dim> MulStaticDimCheck<(R,)> for (L,)
+where
+    L: AssertDimEq<R>,
 {
-    const TYPE_CHECK: () = assert!(
-        L == R,
-        "You are trying to multiply matrices where the column dimension of the first does not match the row dimension of the second."
-    );
-    fn assert_dim_eq(&self) {
-        let _ = <Self as MulStaticDimCheck<(Const<R>, N)>>::TYPE_CHECK;
+    fn assert_shape_eq(&self, rhs: &(R,)) {
+        self.0.assert_dim_eq(&rhs.0);
     }
 }

-// impl MulDimCheck<(Const, usize)> for (usize, Const) {
+impl<L: Dim, R: Dim, N: Dim> MulStaticDimCheck<(R, N)> for (L,)
+where
+    L: AssertDimEq<R>,
+{
+    fn assert_shape_eq(&self, rhs: &(R, N)) {
+        self.0.assert_dim_eq(&rhs.0);
+    }
+}
+
+impl<M: Dim, L: Dim, R: Dim, N: Dim> MulStaticDimCheck<(R, N)> for (M, L)
+where
+    L: AssertDimEq<R>,
+{
+    fn assert_shape_eq(&self, rhs: &(R, N)) {
+        self.1.assert_dim_eq(&rhs.0);
+    }
+}
+
+// impl MulDimCheck<(Const, usize)> for (usize, Const) {
 //     const TYPE_CHECK: () = assert!(
 //         L == R,
 //         "You are trying to multiply matrices where the column dimension of the first does not match the row dimension of the second."
@@ -170,31 +197,25 @@ impl MulStaticDimCheck<(Const
 //     }
 // }

-impl<const L: usize, const R: usize, B: Dim, M: Dim, N: Dim> MulStaticDimCheck<(Const<R>, N)>
-    for (B, M, Const<L>)
+impl<B: Dim, M: Dim, L: Dim, R: Dim, N: Dim> MulStaticDimCheck<(R, N)> for (B, M, L)
+where
+    L: AssertDimEq<R>,
 {
-    const TYPE_CHECK: () = assert!(
-        L == R,
-        "You are trying to multiply a tensor of rank 3 to a matrix where the last dimension of the first does not match the first dimension of the second."
-    );
-    fn assert_dim_eq(&self) {
-        let _ = <Self as MulStaticDimCheck<(Const<R>, N)>>::TYPE_CHECK;
+    fn assert_shape_eq(&self, rhs: &(R, N)) {
+        self.2.assert_dim_eq(&rhs.0);
     }
 }

-impl<const L: usize, const R: usize, B: Dim, M: Dim, N: Dim> MulStaticDimCheck<(B, Const<R>, N)>
-    for (B, M, Const<L>)
+impl<B: Dim, M: Dim, L: Dim, R: Dim, N: Dim> MulStaticDimCheck<(B, R, N)> for (B, M, L)
+where
+    L: AssertDimEq<R>,
 {
-    const TYPE_CHECK: () = assert!(
-        L == R,
-        "You are trying to multiply two tensors of rank 3 for Batch3Mul where the last dimension of the first does not match the second dimension of the second."
-    );
-    fn assert_dim_eq(&self) {
-        let _ = <Self as MulStaticDimCheck<(B, Const<R>, M)>>::TYPE_CHECK;
+    fn assert_shape_eq(&self, rhs: &(B, R, N)) {
+        self.2.assert_dim_eq(&rhs.1);
     }
 }

-// impl MulDimCheck<(usize, Const, usize)>
+// impl MulDimCheck<(usize, Const, usize)>
 //     for (usize, usize, Const)
 // {
 //     const TYPE_CHECK: () = assert!(
@@ -206,19 +227,17 @@ impl MulStaticDimCheck<(
 //     }
 // }

-impl<const L: usize, const R: usize, B: Dim, S: Dim, M: Dim, N: Dim>
-    MulStaticDimCheck<(B, S, Const<R>, N)> for (B, S, M, Const<L>)
+impl<B: Dim, S: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim> MulStaticDimCheck<(B, S, RightK, N)>
+    for (B, S, M, LeftK)
+where
+    LeftK: AssertDimEq<RightK>,
 {
-    const TYPE_CHECK: () = assert!(
-        L == R,
-        "You are trying to multiply two tensors of rank 4 for Batch4Mul where the last dimension of the first does not match the second to last dimension of the second."
-    );
-    fn assert_dim_eq(&self) {
-        let _ = <Self as MulStaticDimCheck<(B, S, Const<R>, M)>>::TYPE_CHECK;
+    fn assert_shape_eq(&self, rhs: &(B, S, RightK, N)) {
+        self.3.assert_dim_eq(&rhs.2);
     }
 }

-// impl MulDimCheck<(usize, usize, Const, usize)>
+// impl MulDimCheck<(usize, usize, Const, usize)>
 //     for (usize, usize, usize, Const)
 // {
 //     const TYPE_CHECK: () = assert!(
@@ -383,7 +402,7 @@ where
     /// ```
     fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
         // assert_eq!(self.shape.1.size(), rhs.shape.0.size());
-        self.shape.assert_dim_eq();
+        self.shape.assert_shape_eq(&rhs.shape);
         // println!(
         //     "Left {:?} Right {:?}",
         //     self.shape.1.size(),
@@ -434,7 +453,7 @@
         rhs: Tensor<(usize, N), E, D, R>,
     ) -> Result<Self::Output, Self::Err> {
         // assert_eq!(self.shape.1.size(), rhs.shape.0.size());
-        self.shape.assert_dim_eq(&rhs.shape);
+        self.shape.assert_shape_eq(&rhs.shape);
         // println!(
         //     "Left {:?} Right {:?}",
         //     self.shape.1.size(),
@@ -482,7 +501,7 @@ where
     /// ```
     fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
         // assert_eq!(self.shape.2, rhs.shape.0);
-        self.shape.assert_dim_eq();
+        self.shape.assert_shape_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
@@ -527,7 +546,7 @@ where
     fn try_matmul(self, rhs: Tensor<(B, RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
         // assert_eq!(self.shape.0, rhs.shape.0);
         // assert_eq!(self.shape.2, rhs.shape.1);
-        self.shape.assert_dim_eq();
+        self.shape.assert_shape_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
@@ -575,7 +594,7 @@ where
     ) -> Result<Self::Output, Self::Err> {
         // assert_eq!(self.shape.0, rhs.shape.0);
         // assert_eq!(self.shape.2, rhs.shape.1);
-        self.shape.assert_dim_eq(&rhs.shape);
+        self.shape.assert_shape_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
@@ -623,7 +642,7 @@ where
     ) -> Result<Self::Output, Self::Err> {
         // assert_eq!(self.shape.0, rhs.shape.0);
         // assert_eq!(self.shape.2, rhs.shape.1);
-        self.shape.assert_dim_eq(&rhs.shape);
+        self.shape.assert_shape_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
@@ -720,7 +739,7 @@ where
         // assert_eq!(self.shape.0, rhs.shape.0);
         // assert_eq!(self.shape.1, rhs.shape.1);
         // assert_eq!(self.shape.3, rhs.shape.2);
-        self.shape.assert_dim_eq();
+        self.shape.assert_shape_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
@@ -770,7 +789,7 @@ where
         // assert_eq!(self.shape.0, rhs.shape.0);
         // assert_eq!(self.shape.1, rhs.shape.1);
         // assert_eq!(self.shape.3, rhs.shape.2);
-        self.shape.assert_dim_eq(&rhs.shape);
+        self.shape.assert_shape_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs
index fbb606f..5b6e832 100644
--- a/src/tensor_ops/mod.rs
+++ b/src/tensor_ops/mod.rs
@@ -221,7 +221,7 @@ pub use huber_error::huber_error;
 pub use ln::ln;
 pub use log_softmax::log_softmax;
 pub use logsumexp_to::LogSumExpTo;
-pub use matmul::{matmul, TryDynamicMatMul, TryDynamicMatMul1, TryStaticMatMul};
+pub use matmul::{matmul, AssertDimEq, TryDynamicMatMul, TryDynamicMatMul1, TryStaticMatMul};
 pub use max_to::MaxTo;
 pub use maximum::maximum;
 pub use mean_to::MeanTo;
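
Not part of the patch: a minimal, self-contained sketch of the checking pattern the new traits rely on. The `Const` type and the trait below are re-declared locally for illustration and only stand in for the crate's real types. When both dimensions are `Const`, a mismatch surfaces as a compile-time (post-monomorphization) error through the `TYPE_CHECK` associated const; when a side is a runtime `usize`, the check degrades to a runtime `assert_eq!`, which is what `assert_shape_eq` delegates to dimension by dimension.

// Standalone sketch of the AssertDimEq pattern; names are re-declared here
// for illustration and are not the crate's actual definitions.
pub struct Const<const N: usize>;

pub trait AssertDimEq<Rhs> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self, rhs: &Rhs);
}

// usize vs usize: only a runtime check is possible.
impl AssertDimEq<usize> for usize {
    const TYPE_CHECK: () = ();
    fn assert_dim_eq(&self, rhs: &usize) {
        assert_eq!(self, rhs);
    }
}

// Const vs Const: referencing TYPE_CHECK forces the const-assert to be
// evaluated for the concrete M and N, so a mismatch fails the build.
impl<const M: usize, const N: usize> AssertDimEq<Const<N>> for Const<M> {
    const TYPE_CHECK: () = assert!(M == N, "inner dimensions do not match");
    fn assert_dim_eq(&self, _rhs: &Const<N>) {
        let _ = <Self as AssertDimEq<Const<N>>>::TYPE_CHECK;
    }
}

fn main() {
    // Runtime path: panics only if the sizes differ when this runs.
    3usize.assert_dim_eq(&3usize);

    // Compile-time path: the second call, if uncommented, fails to compile.
    Const::<4>.assert_dim_eq(&Const::<4>);
    // Const::<4>.assert_dim_eq(&Const::<5>); // error: inner dimensions do not match
}

This split also appears to be why the batched attention code above switches the reshape targets from a plain `H` to `Const::<H>`: keeping the head axis as a `Const` leaves it eligible for the static check instead of falling back to the runtime one.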