diff --git a/docker/pytorch-aarch64/patches/blas_to_mkl_acl.patch b/docker/pytorch-aarch64/patches/blas_to_mkl_acl.patch
index 26a9e36d..a36698ea 100644
--- a/docker/pytorch-aarch64/patches/blas_to_mkl_acl.patch
+++ b/docker/pytorch-aarch64/patches/blas_to_mkl_acl.patch
@@ -16,10 +16,10 @@
 # *******************************************************************************
 diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
-index c658d4427c..5c792f0e73 100644
+index a0531c50c96..55102c9d2f5 100644
 --- a/aten/src/ATen/native/LinearAlgebra.cpp
 +++ b/aten/src/ATen/native/LinearAlgebra.cpp
-@@ -1308,6 +1308,16 @@ static void addmm_impl_cpu_(
+@@ -1420,6 +1420,20 @@ static void addmm_impl_cpu_(
  
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, result.scalar_type(), "addmm_impl_cpu_", [&]{
@@ -29,21 +29,25 @@ index c658d4427c..5c792f0e73 100644
 +  // that will call then into ACL GEMM kernel and also additionaly have support
 +  // for running kernel with BF16 instructions
 +  if(transpose_a && !transpose_b) {
++    if (transpose_c) {
 +      mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
-+      return;
++    } else {
++      mkldnn_matmul(a, b, c, beta.to<float>(), alpha.to<float>());
++    }
++    return;
 +  }
 +#endif
         using opmath_t = at::opmath_type<scalar_t>;
         at::native::cpublas::gemm(
             transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
 diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
-index d41ebac635..e2cc13fe00 100644
+index 383d2965923..b15056d7161 100644
 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp
 +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
-@@ -128,23 +128,25 @@ void mkldnn_matmul(
+@@ -130,23 +130,25 @@ void mkldnn_matmul(
        (mat1.dim() == 1 && mat2.dim() == 1), // aten::dot
        "mkldnn_matmul: unsupported dims for mat and mat2");
 -
 +#if defined(__aarch64__)
 +  // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
 +  // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
@@ -58,7 +62,7 @@ index d41ebac635..e2cc13fe00 100644
 +#else
    TORCH_CHECK(mkldnn_bf16_device_check(),
      "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq, or AWS Graviton3");
 -
 -#if defined(__aarch64__)
 -  if (mkldnn_bf16_device_check_arm()) {
 -     //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
 -     //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
 -     TORCH_CHECK(mat1.scalar_type() == at::kFloat ||
@@ -76,6 +80,6 @@ index d41ebac635..e2cc13fe00 100644
 -                 mat2.scalar_type() == at::kBFloat16 &&
 -                 result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
 -  }
- 
+ 
    auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
    auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;