Skip to content

Commit

Permalink
fix her2k HH1 + remove console output
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Aug 16, 2024
1 parent 6468bc0 commit 9fc6fbf
Showing 1 changed file with 12 additions and 25 deletions.
37 changes: 12 additions & 25 deletions include/dlaf/eigensolver/reduction_to_band/ca-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
#include <dlaf/sender/transform.h>
#include <dlaf/sender/transform_mpi.h>

//
#include <dlaf/matrix/print_numpy.h>

namespace dlaf::eigensolver::internal {

namespace ca_red2band {
Expand Down Expand Up @@ -209,15 +206,19 @@ void her2kUpdateTrailingMatrix(comm::Index2D rank_qr, const matrix::SubMatrixVie
her2kDiag<B>(priority, V.read(ij_lc), W3.read(ij_lc), getSubA());
}
else {
// TODO check and document why transposed operand can be accessed with same index locally

const GlobalTileIndex ijT = transposed(ij);
const LocalTileIndex ijT_lc = dist.local_tile_index(ijT);
// TODO fix doc
// Note:
// - We are updating from both L and R.
// - We are computing all combinations of W3 and V (and vice versa), and putting results in A
// - By looping on position of A that will contain the result
// - We use the same row for the first operand
// - We use the col as the row for the second operand
const SizeType iT_lc = dist.template local_tile_from_global_tile<Coord::Row>(ij.col());

// A -= W3 . V*
her2kOffDiag<B>(priority, W3.read(ij_lc), V.read(ijT_lc), getSubA());
her2kOffDiag<B>(priority, W3.read(ij_lc), V.read({iT_lc, 0}), getSubA());
// A -= V . W3*
her2kOffDiag<B>(priority, V.read(ij_lc), W3.read(ijT_lc), getSubA());
her2kOffDiag<B>(priority, V.read(ij_lc), W3.read({iT_lc, 0}), getSubA());
}
}
}
Expand Down Expand Up @@ -404,9 +405,6 @@ void her2k_2nd(const SizeType i_end, const SizeType j_end, const matrix::SubMatr
const auto priority =
(j_lc == at_offset_lc.col()) ? thread_priority::high : thread_priority::normal;

std::cout << "HER2K-EXTRA" << dist.global_tile_index(ij_lc) << std::endl;
print_sync("tile_w1", W1.read(ij_lc));
print_sync("tile_v", VT.read(ij_lc));
// A -= X . V*
her2kOffDiag<B>(priority, W1.read(ij_lc), VT.read(ij_lc), getSubA());
}
Expand Down Expand Up @@ -698,9 +696,8 @@ CARed2BandResult<T, D> CAReductionToBand<B, D, T>::call(comm::CommunicatorGrid&
auto&& panel_heads = panels_heads.nextResource();
panel_heads.setRangeEnd({n_qr_heads, 0});

const matrix::Distribution dist_heads_current(LocalElementSize(n_qr_heads * dist.block_size().rows(),
dist.block_size().cols()),
dist.block_size());
const matrix::Distribution dist_heads_current(
LocalElementSize(n_qr_heads * dist.block_size().rows(), band_size), dist.block_size());
const matrix::SubPanelView panel_heads_view(dist_heads_current, {0, 0}, band_size);

const bool rank_has_head_row = !panel_view.iteratorLocal().empty();
Expand Down Expand Up @@ -823,17 +820,9 @@ CARed2BandResult<T, D> CAReductionToBand<B, D, T>::call(comm::CommunicatorGrid&
}));
}

for (const auto& i_lc : ws_V.iteratorLocal()) {
std::ostringstream ss;
ss << "V2nd(" << dist.global_tile_index(i_lc) << ")";
print_sync(ss.str(), ws_V.read(i_lc));
}

using factorization::internal::computeTFactor;
const GlobalTileIndex j_tau(j, 0);
computeTFactor<B>(panel_heads, mat_taus_2nd.read(j_tau), ws_T.readwrite(zero_lc));

print_sync("T2nd", ws_T.read(zero_lc));
}

auto& ws_VT = panels_vt.nextResource();
Expand Down Expand Up @@ -884,8 +873,6 @@ CARed2BandResult<T, D> CAReductionToBand<B, D, T>::call(comm::CommunicatorGrid&
ws_W2.readwrite(zero_lc)));
}

print_sync("W2", ws_W2.read(zero_lc));

// W1 = W1 - 0.5 V W2
red2band::local::gemmUpdateX<B, D>(ws_W1, ws_W2, ws_V);
}
Expand Down

0 comments on commit 9fc6fbf

Please sign in to comment.