Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PermutationMatrix instead of indices #475

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/albatross/src/cereal/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,29 @@ inline void load(Archive &archive,
v.indices() = indices;
}

template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
typename _StorageIndex>
inline void
save(Archive &archive,
const Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex> &v,
const std::uint32_t) {
archive(cereal::make_nvp("indices", v.indices()));
}

template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
typename _StorageIndex>
inline void
load(Archive &archive,
Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex> &v,
const std::uint32_t) {
typename Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex>::IndicesType indices;
archive(cereal::make_nvp("indices", indices));
v.indices() = indices;
}

template <typename Archive, typename _Scalar, int SizeAtCompileTime>
inline void serialize(Archive &archive,
Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime> &matrix,
Expand Down
16 changes: 8 additions & 8 deletions include/albatross/src/cereal/gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,
archive(cereal::make_nvp("information", fit.information));
archive(cereal::make_nvp("train_covariance", fit.train_covariance));
archive(cereal::make_nvp("train_features", fit.train_features));
archive(cereal::make_nvp("sigma_R", fit.sigma_R));
archive(cereal::make_nvp("permutation_indices", fit.permutation_indices));
archive(cereal::make_nvp("R", fit.R));
archive(cereal::make_nvp("P", fit.P));
if (version > 1) {
archive(cereal::make_nvp("numerical_rank", fit.numerical_rank));
} else {
Expand All @@ -53,19 +53,19 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,

template <typename Archive, typename CovFunc, typename MeanFunc,
typename ImplType>
void save(Archive &archive,
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t) {
inline void save(Archive &archive,
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t) {
archive(cereal::make_nvp("name", gp.get_name()));
archive(cereal::make_nvp("params", gp.get_params()));
archive(cereal::make_nvp("insights", gp.insights));
}

template <typename Archive, typename CovFunc, typename MeanFunc,
typename ImplType>
void load(Archive &archive,
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t version) {
inline void load(Archive &archive,
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t version) {
if (version > 0) {
std::string model_name;
archive(cereal::make_nvp("name", model_name));
Expand Down
7 changes: 7 additions & 0 deletions include/albatross/src/core/declarations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ template <typename... Ts> class variant;

using mapbox::util::variant;

/*
* Permutations
*/
namespace Eigen {
using PermutationMatrixX = PermutationMatrix<Dynamic, Dynamic, Index>;
}

namespace albatross {

/*
Expand Down
22 changes: 12 additions & 10 deletions include/albatross/src/linalg/qr_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,31 @@ get_R(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
.template triangularView<Eigen::Upper>();
}

inline Eigen::PermutationMatrixX
get_P(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
return Eigen::PermutationMatrixX(
qr.colsPermutation().indices().template cast<Eigen::Index>());
}

/*
* Computes R^-T P^T rhs given R and P from a QR decomposition.
*/
template <typename MatrixType, typename PermutationIndicesType>
template <typename MatrixType, typename PermutationScalar>
inline Eigen::MatrixXd
sqrt_solve(const Eigen::MatrixXd &R,
const PermutationIndicesType &permutation_indices,
const Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
PermutationScalar> &P,
const MatrixType &rhs) {

Eigen::MatrixXd sqrt(rhs.rows(), rhs.cols());
for (Eigen::Index i = 0; i < permutation_indices.size(); ++i) {
sqrt.row(i) = rhs.row(permutation_indices.coeff(i));
}
sqrt = R.template triangularView<Eigen::Upper>().transpose().solve(sqrt);
return sqrt;
return R.template triangularView<Eigen::Upper>().transpose().solve(
P.transpose() * rhs);
}

template <typename MatrixType>
inline Eigen::MatrixXd
sqrt_solve(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr,
const MatrixType &rhs) {
const Eigen::MatrixXd R = get_R(qr);
return sqrt_solve(R, qr.colsPermutation().indices(), rhs);
return sqrt_solve(R, qr.colsPermutation(), rhs);
}

} // namespace albatross
Expand Down
11 changes: 6 additions & 5 deletions include/albatross/src/linalg/spqr_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@ using SparseMatrix = Eigen::SparseMatrix<double>;

using SPQR = Eigen::SPQR<SparseMatrix>;

using SparsePermutationMatrix =
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
SPQR::StorageIndex>;

inline Eigen::MatrixXd get_R(const SPQR &qr) {
return qr.matrixR()
.topLeftCorner(qr.cols(), qr.cols())
.template triangularView<Eigen::Upper>();
}

inline Eigen::PermutationMatrixX get_P(const SPQR &qr) {
return Eigen::PermutationMatrixX(
qr.colsPermutation().indices().template cast<Eigen::Index>());
}

template <typename MatrixType>
inline Eigen::MatrixXd sqrt_solve(const SPQR &qr, const MatrixType &rhs) {
return sqrt_solve(get_R(qr), qr.colsPermutation().indices(), rhs);
return sqrt_solve(get_R(qr), get_P(qr), rhs);
}

// Matrices with any dimension smaller than this will use a special
Expand Down
59 changes: 22 additions & 37 deletions include/albatross/src/models/sparse_gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,19 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {

std::vector<FeatureType> train_features;
Eigen::SerializableLDLT train_covariance;
Eigen::MatrixXd sigma_R;
PermutationIndices permutation_indices;
Eigen::MatrixXd R;
Eigen::PermutationMatrixX P;
Eigen::VectorXd information;
Eigen::Index numerical_rank;

Fit(){};

Fit(const std::vector<FeatureType> &features_,
const Eigen::SerializableLDLT &train_covariance_,
const Eigen::MatrixXd &sigma_R_,
PermutationIndices &&permutation_indices_,
const Eigen::MatrixXd &R_, const Eigen::PermutationMatrixX &P_,
const Eigen::VectorXd &information_, Eigen::Index numerical_rank_)
: train_features(features_), train_covariance(train_covariance_),
sigma_R(sigma_R_), permutation_indices(std::move(permutation_indices_)),
information(information_), numerical_rank(numerical_rank_) {}
: train_features(features_), train_covariance(train_covariance_), R(R_),
P(P_), information(information_), numerical_rank(numerical_rank_) {}

void shift_mean(const Eigen::VectorXd &mean_shift) {
ALBATROSS_ASSERT(mean_shift.size() == information.size());
Expand All @@ -120,9 +118,8 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {

bool operator==(const Fit<SparseGPFit<FeatureType>> &other) const {
return (train_features == other.train_features &&
train_covariance == other.train_covariance &&
sigma_R == other.sigma_R &&
permutation_indices == other.permutation_indices &&
train_covariance == other.train_covariance && R == other.R &&
P.indices() == other.P.indices() &&
information == other.information &&
numerical_rank == other.numerical_rank);
}
Expand Down Expand Up @@ -325,20 +322,17 @@ class SparseGaussianProcessRegression
compute_internal_components(old_fit.train_features, features, targets,
&A_ldlt, &K_uu_ldlt, &K_fu, &y);

const Eigen::Index n_old = old_fit.sigma_R.rows();
const Eigen::Index n_old = old_fit.R.rows();
const Eigen::Index n_new = A_ldlt.rows();
const Eigen::Index k = old_fit.sigma_R.cols();
const Eigen::Index k = old_fit.R.cols();
Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n_old + n_new, k);

ALBATROSS_ASSERT(n_old == k);

// Form:
// B = |R_old P_old^T| = |Q_1| R P^T
// |A^{-1/2} K_fu| |Q_2|
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
const Eigen::Index &pi = old_fit.permutation_indices.coeff(i);
B.col(pi).topRows(i + 1) = old_fit.sigma_R.col(i).topRows(i + 1);
}
B.topRows(old_fit.P.rows()) = old_fit.R * old_fit.P.transpose();
B.bottomRows(n_new) = A_ldlt.sqrt_solve(K_fu);
const auto B_qr = QRImplementation::compute(B, Base::threads_.get());

Expand All @@ -347,13 +341,9 @@ class SparseGaussianProcessRegression
// |A^{-1/2} y |
ALBATROSS_ASSERT(old_fit.information.size() == n_old);
Eigen::VectorXd y_augmented(n_old + n_new);
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
y_augmented[i] =
old_fit.information[old_fit.permutation_indices.coeff(i)];
}
y_augmented.topRows(n_old) =
old_fit.sigma_R.template triangularView<Eigen::Upper>() *
y_augmented.topRows(n_old);
old_fit.R.template triangularView<Eigen::Upper>() *
(old_fit.P.transpose() * old_fit.information);

y_augmented.bottomRows(n_new) = A_ldlt.sqrt_solve(y, Base::threads_.get());
const Eigen::VectorXd v = B_qr->solve(y_augmented);
Expand All @@ -365,10 +355,9 @@ class SparseGaussianProcessRegression
Eigen::VectorXd::Constant(B_qr->cols(), details::cSparseRNugget);
}
using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
return FitType(
old_fit.train_features, old_fit.train_covariance, R,
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
B_qr->rank());

return FitType(old_fit.train_features, old_fit.train_covariance, R,
get_P(*B_qr), v, B_qr->rank());
}

// Here we create the QR decomposition of:
Expand Down Expand Up @@ -415,10 +404,7 @@ class SparseGaussianProcessRegression
using InducingPointFeatureType = typename std::decay<decltype(u[0])>::type;

using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
return FitType(
u, K_uu_ldlt, get_R(*B_qr),
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
B_qr->rank());
return FitType(u, K_uu_ldlt, get_R(*B_qr), get_P(*B_qr), v, B_qr->rank());
}

template <typename FeatureType>
Expand Down Expand Up @@ -471,9 +457,8 @@ class SparseGaussianProcessRegression
const Eigen::MatrixXd sigma_inv_sqrt = C_ldlt.sqrt_solve(K_zz);
const auto B_qr = QRImplementation::compute(sigma_inv_sqrt, nullptr);

new_fit.permutation_indices =
B_qr->colsPermutation().indices().template cast<Eigen::Index>();
new_fit.sigma_R = get_R(*B_qr);
new_fit.P = get_P(*B_qr);
new_fit.R = get_R(*B_qr);
new_fit.numerical_rank = B_qr->rank();

return output;
Expand Down Expand Up @@ -519,8 +504,8 @@ class SparseGaussianProcessRegression
Q_sqrt.cwiseProduct(Q_sqrt).array().colwise().sum();
marginal_variance -= Q_diag;

const Eigen::MatrixXd S_sqrt = sqrt_solve(
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
const Eigen::MatrixXd S_sqrt =
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);
const Eigen::VectorXd S_diag =
S_sqrt.cwiseProduct(S_sqrt).array().colwise().sum();
marginal_variance += S_diag;
Expand All @@ -537,8 +522,8 @@ class SparseGaussianProcessRegression
this->covariance_function_(sparse_gp_fit.train_features, features);
const Eigen::MatrixXd prior_cov = this->covariance_function_(features);

const Eigen::MatrixXd S_sqrt = sqrt_solve(
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
const Eigen::MatrixXd S_sqrt =
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);

const Eigen::MatrixXd Q_sqrt =
sparse_gp_fit.train_covariance.sqrt_solve(cross_cov);
Expand Down
8 changes: 4 additions & 4 deletions tests/test_sparse_gp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,10 @@ TYPED_TEST(SparseGaussianProcessTest, test_update) {
(updated_in_place_pred.covariance - full_pred.covariance).norm();

auto compute_sigma = [](const auto &fit_model) -> Eigen::MatrixXd {
const Eigen::Index n = fit_model.get_fit().sigma_R.cols();
Eigen::MatrixXd sigma = sqrt_solve(fit_model.get_fit().sigma_R,
fit_model.get_fit().permutation_indices,
Eigen::MatrixXd::Identity(n, n));
const Eigen::Index n = fit_model.get_fit().R.cols();
Eigen::MatrixXd sigma =
sqrt_solve(fit_model.get_fit().R, fit_model.get_fit().P,
Eigen::MatrixXd::Identity(n, n));
return sigma.transpose() * sigma;
};

Expand Down
Loading