forked from demianmnave/CML
-
Notifications
You must be signed in to change notification settings - Fork 0
/
matrix_product.tpp
93 lines (74 loc) · 3.19 KB
/
matrix_product.tpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/* -*- C++ -*- ------------------------------------------------------------
@@COPYRIGHT@@
*-----------------------------------------------------------------------*/
/** @file
*/
#ifndef __CML_MATRIX_MATRIX_PRODUCT_TPP
#error "matrix/matrix_product.tpp not included correctly"
#endif
#include <cml/matrix/detail/resize.h>
#include <cml/matrix/types.h>
#include <xmmintrin.h>
namespace cml {
// SFINAE helper: names the type void (the default std::enable_if result) when
// cml::unqualified_type_t<Test> is exactly Reference; otherwise it has no
// type, which removes the template using it from overload resolution.
// Used in this file as a defaulted pointer template parameter
// ("is_same_t<T, R>* = nullptr").
// NOTE(review): unqualified_type_t presumably strips references and
// cv-qualifiers from Test - confirm against its definition.
template<class Test, class Reference>
using is_same_t = typename std::enable_if<std::is_same<cml::unqualified_type_t<Test>, Reference>::value>::type;
// SFINAE helper: names the type void (the default std::enable_if result) when
// cml::unqualified_type_t<Test> is NOT Reference; otherwise it has no type,
// which removes the template using it from overload resolution.
// NOTE(review): unqualified_type_t presumably strips references and
// cv-qualifiers from Test - confirm against its definition.
template<class Test, class Reference>
using is_different_t = typename std::enable_if<!std::is_same<cml::unqualified_type_t<Test>, Reference>::value>::type;
// Result type of multiplying LeftMatrix by RightMatrix: the promotion computed
// by matrix_inner_product_promote_t from the actual operand types of both
// sides (as reported by actual_operand_type_of_t).
template<class LeftMatrix, class RightMatrix>
using matrix_product_t = matrix_inner_product_promote_t<actual_operand_type_of_t<LeftMatrix>, actual_operand_type_of_t<RightMatrix>>;
// General-purpose matrix product, used for every operand pair except the one
// handled by the SSE overload (both operands matrix44f_r).
//
// The exclusion is a single "not both matrix44f_r" condition. Excluding each
// operand independently would leave mixed pairs (exactly one matrix44f_r
// operand, e.g. matrix44f_r times a matrix of another type or a matrix
// expression) without any viable overload, because the SSE overload requires
// both operands to be matrix44f_r.
//
// left, right: matrix operands; cml::check_same_inner_size() is applied to
//   them before any element is read.
// Returns a matrix_product_t resized to
//   array_rows_of(left) x array_cols_of(right).
template<class LeftMatrix, class RightMatrix,
  enable_if_matrix_t<LeftMatrix>* = nullptr,
  enable_if_matrix_t<RightMatrix>* = nullptr,
  typename std::enable_if<
    !(std::is_same<cml::unqualified_type_t<LeftMatrix>, matrix44f_r>::value
      && std::is_same<cml::unqualified_type_t<RightMatrix>, matrix44f_r>::value)
    >::type* = nullptr>
inline auto matrix_product(LeftMatrix&& left, RightMatrix&& right)
-> matrix_product_t<decltype(left), decltype(right)>
{
  cml::check_same_inner_size(left, right);

  matrix_product_t<decltype(left), decltype(right)> M;
  detail::resize(M, array_rows_of(left), array_cols_of(right));

  for(int i = 0; i < M.rows(); ++ i) {
    for(int j = 0; j < M.cols(); ++ j) {
      // Seed the accumulator with the k = 0 term so its type is deduced from
      // the element product; this reads left(i,0) and right(0,j), so the
      // inner dimension is assumed to be at least 1.
      auto m = left(i,0) * right(0,j);
      for(int k = 1; k < left.cols(); ++ k) m += left(i,k) * right(k,j);
      M(i,j) = m;
    }
  }
  return M;
}
// SSE-optimized product, selected only when both operands are matrix44f_r
// (float fixed matrices with row major alignment and row basis).
//
// Each output line is built as the sum over k of left(line, k) times the k-th
// 4-float line of the right operand, accumulated in increasing k starting
// from zero. Unaligned loads and stores are used, so no alignment of the
// matrix storage is required.
template<class LeftMatrix, class RightMatrix,
  is_same_t<LeftMatrix, matrix44f_r>* = nullptr,
  is_same_t<RightMatrix, matrix44f_r>* = nullptr>
inline matrix44f_r matrix_product(LeftMatrix&& left, RightMatrix&& right)
{
  // Load the four 4-float lines of the right operand once, up front.
  // NOTE(review): the stride is right.rows(); for the (presumably square 4x4)
  // matrix44f_r this equals right.cols() - confirm if the type ever changes.
  __m128 rhs_lines[4];
  for (int k = 0; k < right.cols(); ++k) {
    rhs_lines[k] = _mm_loadu_ps(right.data() + k * right.rows());
  }

  matrix44f_r result;
  float const* lhs_line = left.data();
  float* out_line = result.data();

  for (int line = 0; line < left.rows(); ++line) {
    __m128 acc = _mm_setzero_ps();
    for (int k = 0; k < left.cols(); ++k) {
      // Broadcast one scalar of the left operand across all four lanes and
      // scale the matching line of the right operand by it.
      __m128 const scaled = _mm_mul_ps(_mm_set1_ps(lhs_line[k]), rhs_lines[k]);
      acc = _mm_add_ps(acc, scaled);
    }
    _mm_storeu_ps(out_line, acc);

    // Step both cursors to the start of the next line.
    lhs_line += left.cols();
    out_line += result.cols();
  }
  return result;
}
// Matrix product operator: the public entry point, which delegates to
// whichever matrix_product() overload overload resolution selects for the
// operand types.
//
// left, right: matrix operands (both constrained by enable_if_matrix_t).
// Returns the product as matrix_product_t of the two operand types.
//
// NOTE(review): the SFINAE template parameters carry no default arguments
// here; presumably the defaults are supplied by a declaration in the header
// that includes this file - confirm.
template<class LeftMatrix, class RightMatrix,
enable_if_matrix_t<LeftMatrix>*,
enable_if_matrix_t<RightMatrix>*>
inline auto operator*(LeftMatrix&& left, RightMatrix&& right)
-> matrix_product_t<decltype(left), decltype(right)>
{
// Operands are passed on as lvalues (no std::forward).
return matrix_product(left, right);
}
} // namespace cml
// -------------------------------------------------------------------------
// vim:ft=cpp:sw=2