Skip to content

Commit

Permalink
Merge pull request #380 from astro-informatics/mm/stochastic_testing
Browse files Browse the repository at this point in the history
Stochastic test cases
  • Loading branch information
mmcleod89 authored Feb 10, 2025
2 parents 2e748a9 + a76ec4a commit 7bbd786
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 6 deletions.
18 changes: 18 additions & 0 deletions cpp/purify/setup_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
#include <sopt/l1_non_diff_function.h>
#include <sopt/l2_differentiable_func.h>
#include <sopt/non_differentiable_func.h>

#ifdef PURIFY_ONNXRT
#include <sopt/onnx_differentiable_func.h>
#endif

#include <sopt/power_method.h>
#include <sopt/real_indicator.h>

#ifdef PURIFY_ONNXRT
#include <sopt/tf_non_diff_function.h>
#endif

using namespace purify;

Expand Down Expand Up @@ -308,9 +315,14 @@ void setupCostFunctions(const YamlParser &params, std::unique_ptr<Differentiable
f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(sigma, Phi);
break;
case purify::diff_func_type::L2Norm_with_CRR:
#ifdef PURIFY_ONNXRT
f = std::make_unique<sopt::ONNXDifferentiableFunc<t_complex>>(
params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, params.CRR_mu(),
params.CRR_lambda(), Phi);
#else
throw std::runtime_error(
"To use the CRR you must compile with ONNX runtime turned on. (-Donnxrt=on)");
#endif
break;
}

Expand All @@ -319,8 +331,14 @@ void setupCostFunctions(const YamlParser &params, std::unique_ptr<Differentiable
g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();
break;
case purify::nondiff_func_type::Denoiser:
#ifdef PURIFY_ONNXRT
g = std::make_unique<sopt::algorithm::TFGProximal<t_complex>>(params.model_path());
break;
#else
throw std::runtime_error(
"To use the Denoiser you must compile with ONNX runtime turned on. (-Donnxrt=on)");
#endif

case purify::nondiff_func_type::RealIndicator:
g = std::make_unique<sopt::algorithm::RealIndicator<t_complex>>();
break;
Expand Down
117 changes: 116 additions & 1 deletion cpp/tests/algo_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include <sopt/onnx_differentiable_func.h>
#endif

#ifdef PURIFY_H5
#include "purify/h5reader.h"
#endif

#include <sopt/power_method.h>

#include "purify/test_data.h"
Expand Down Expand Up @@ -169,7 +173,7 @@ TEST_CASE("fb_factory") {

auto const diagnostic = (*fb)();
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// pfitsio::write2d(image.real(), result_path);
pfitsio::write2d(image.real(), result_path);
// pfitsio::write2d(residual_image.real(), expected_residual_path);

double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
Expand All @@ -182,6 +186,117 @@ TEST_CASE("fb_factory") {
CHECK(mse <= average_intensity * 1e-3);
}

#ifdef PURIFY_H5
// Serial stochastic forward-backward test: the solver is given an updater that
// hands it a random subset of the visibilities (plus a matching measurement
// operator) instead of the full data set, and the reconstruction is compared
// against the stored reference solution.
TEST_CASE("fb_factory_stochastic") {
  const std::string test_dir = "expected/fb/";
  const std::string input_data_path = data_filename(test_dir + "input_data.vis");
  const std::string expected_solution_path = data_filename(test_dir + "solution.fits");

  // Load the full visibility set once; the updater below only subsamples it.
  auto uv_data = utilities::read_visibility(input_data_path, false);
  uv_data.units = utilities::vis_units::radians;
  CAPTURE(uv_data.vis.head(5));
  REQUIRE(uv_data.size() == 13107);

  t_uint const imsizey = 128;
  t_uint const imsizex = 128;

  // Fixed seed so the sequence of random subsets is the same on every run.
  std::mt19937 rng(0);
  // Number of visibilities in each random subset.
  const size_t N = 1000;
  // The updater indexes the first N shuffled entries, so N must fit in the data.
  REQUIRE(N <= uv_data.size());

  // This functor would be defined in Purify.
  // Each call draws N visibilities at random (without replacement) from the
  // already-loaded data and builds a measurement operator for that subset.
  // uv_data and rng are captured by reference: both outlive the solver below.
  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
      [&uv_data, &rng, imsizex, imsizey, N]() {
        // Get random subset: shuffle all indices and keep the first N.
        std::vector<size_t> indices(uv_data.size());
        for (size_t k = 0; k < indices.size(); ++k) {
          indices[k] = k;
        }
        std::shuffle(indices.begin(), indices.end(), rng);

        Vector<t_real> u_fragment(N);
        Vector<t_real> v_fragment(N);
        Vector<t_real> w_fragment(N);
        Vector<t_complex> vis_fragment(N);
        Vector<t_complex> weights_fragment(N);
        for (size_t k = 0; k < N; ++k) {
          const size_t j = indices[k];
          u_fragment[k] = uv_data.u[j];
          v_fragment[k] = uv_data.v[j];
          w_fragment[k] = uv_data.w[j];
          vis_fragment[k] = uv_data.vis[j];
          weights_fragment[k] = uv_data.weights[j];
        }
        utilities::vis_params uv_data_fragment(u_fragment, v_fragment, w_fragment, vis_fragment,
                                               weights_fragment, uv_data.units, uv_data.ra,
                                               uv_data.dec, uv_data.average_frequency);

        auto phi = factory::measurement_operator_factory<Vector<t_complex>>(
            factory::distributed_measurement_operator::serial, uv_data_fragment, imsizey, imsizex,
            1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);

        return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data_fragment.vis,
                                                                         phi);
      };

  // Estimate the operator norm from one random draw of the measurement operator.
  Vector<t_complex> const init = Vector<t_complex>::Ones(imsizex * imsizey);
  auto IS = random_updater();
  auto Phi = IS->Phi();
  auto const power_method_stuff =
      sopt::algorithm::power_method<Vector<t_complex>>(Phi, 1000, 1e-5, init);
  const t_real op_norm = std::get<0>(power_method_stuff);

  const auto solution = pfitsio::read2d(expected_solution_path);

  // wavelets
  std::vector<std::tuple<std::string, t_uint>> const sara{
      std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
      std::make_tuple("DB3", 3u),   std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
      std::make_tuple("DB6", 3u),   std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
      factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);

  // algorithm
  t_real const sigma = 0.016820222945913496 * std::sqrt(2);  // see test_parameters file
  t_real const beta = sigma * sigma;
  t_real const gamma = 0.0001;

  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
  fb.itermax(1000)
      .step_size(beta * std::sqrt(2))
      .sigma(sigma * std::sqrt(2))
      .regulariser_strength(gamma)
      .relative_variation(1e-3)
      .residual_tolerance(0)
      .tight_frame(true)
      .sq_op_norm(op_norm * op_norm);

  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
  gp->l1_proximal_tolerance(1e-4)
      .l1_proximal_nu(1)
      .l1_proximal_itermax(50)
      .l1_proximal_positivity_constraint(true)
      .l1_proximal_real_constraint(true)
      .Psi(*wavelets);
  fb.g_function(gp);

  auto const diagnostic = fb();

  // Compare against the reference image: the mean squared error of the real
  // part must not exceed 1e-3 times the reference's average intensity.
  auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
  double average_intensity = soln_flat.real().sum() / soln_flat.size();
  SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
  double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
  SOPT_HIGH_LOG("MSE = {}", mse);
  CHECK(mse <= average_intensity * 1e-3);
}
#endif

#ifdef PURIFY_ONNXRT
TEST_CASE("tf_fb_factory") {
const std::string &test_dir = "expected/fb/";
Expand Down
159 changes: 159 additions & 0 deletions cpp/tests/mpi_algo_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
#include <sopt/power_method.h>
#include <sopt/wavelets.h>

#ifdef PURIFY_H5
#include "purify/h5reader.h"
#endif

#include "purify/algorithm_factory.h"
#include "purify/measurement_operator_factory.h"
#include "purify/wavelet_operator_factory.h"
Expand Down Expand Up @@ -311,6 +315,7 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {

const std::string &test_dir = "expected/fb/";
const std::string &input_data_path = data_filename(test_dir + "input_data.vis");
const std::string &result_path = data_filename(test_dir + "mpi_fb_result.fits");

auto uv_data = dirty_visibilities({input_data_path}, world);
uv_data.units = utilities::vis_units::radians;
Expand Down Expand Up @@ -344,6 +349,75 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);

auto const diagnostic = (*fb)();
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
if (world.is_root()) {
pfitsio::write2d(image.real(), result_path);
// pfitsio::write2d(residual_image.real(), expected_residual_path);
}

const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");

const auto solution = pfitsio::read2d(expected_solution_path);
const auto residual = pfitsio::read2d(expected_residual_path);

double average_intensity = diagnostic.x.real().sum() / diagnostic.x.size();
SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
double mse = (Vector<t_complex>::Map(solution.data(), solution.size()) - diagnostic.x)
.real()
.squaredNorm() /
solution.size();
SOPT_HIGH_LOG("MSE = {}", mse);
CHECK(mse <= average_intensity * 1e-3);
}

#ifdef PURIFY_H5
TEST_CASE("MPI_fb_factory_hdf5") {
auto const world = sopt::mpi::Communicator::World();
const size_t N = 13107;

const std::string &test_dir = "expected/fb/";
const std::string &input_data_path = data_filename(test_dir + "input_data.h5");
const std::string &result_path = data_filename(test_dir + "mpi_fb_result_hdf5.fits");
H5::H5Handler h5file(input_data_path, world);

auto uv_data = H5::stochread_visibility(h5file, 6000, false);
uv_data.units = utilities::vis_units::radians;
if (world.is_root()) {
CAPTURE(uv_data.vis.head(5));
}
// REQUIRE(world.all_sum_all(uv_data.size()) == 13107);

t_uint const imsizey = 128;
t_uint const imsizex = 128;

auto const measurements_transform = factory::measurement_operator_factory<Vector<t_complex>>(
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey, imsizex, 1,
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
*measurements_transform, 1000, 1e-5,
world.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
const t_real op_norm = std::get<0>(power_method_stuff);
std::vector<std::tuple<std::string, t_uint>> const sara{
std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
t_real const sigma =
world.broadcast(0.016820222945913496) * std::sqrt(2); // see test_parameters file
t_real const beta = sigma * sigma;
t_real const gamma = 0.0001;
auto const fb = factory::fb_factory<sopt::algorithm::ImagingForwardBackward<t_complex>>(
factory::algo_distribution::mpi_serial, measurements_transform, wavelets, uv_data, sigma,
beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm);

auto const diagnostic = (*fb)();
const Image<t_complex> image = Image<t_complex>::Map(diagnostic.x.data(), imsizey, imsizex);
// if (world.is_root())
//{
// pfitsio::write2d(image.real(), result_path);
//}

const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");
Expand All @@ -360,3 +434,88 @@ TEST_CASE("Serial vs. Serial with MPI Forward Backward") {
SOPT_HIGH_LOG("MSE = {}", mse);
CHECK(mse <= average_intensity * 1e-3);
}

// MPI stochastic forward-backward test: the solver is given an updater that
// reads a random subset of visibilities from an HDF5 file and builds a matching
// measurement operator; the reconstruction is compared against the stored
// reference solution.
TEST_CASE("fb_factory_stochastic") {
  const std::string test_dir = "expected/fb/";
  const std::string input_data_path = data_filename(test_dir + "input_data.h5");
  const std::string expected_solution_path = data_filename(test_dir + "solution.fits");

  t_uint const imsizey = 128;
  t_uint const imsizex = 128;

  // HDF5
  auto const comm = sopt::mpi::Communicator::World();
  // Number of visibilities requested from the file on each stochastic read.
  const size_t N = 2000;
  H5::H5Handler h5file(input_data_path, comm);  // length 13107
  using t_complexVec = Vector<t_complex>;

  // This functor would be defined in Purify.
  // Each call performs a stochastic read of N visibilities and wraps them,
  // with the corresponding measurement operator, in an IterationState.
  // h5file is captured by reference: it outlives the solver below.
  std::function<std::shared_ptr<sopt::IterationState<Vector<t_complex>>>()> random_updater =
      [&f = h5file, N, imsizex, imsizey]() {
        utilities::vis_params uv_data =
            H5::stochread_visibility(f, N, false);  // no w-term in this data-set
        uv_data.units = utilities::vis_units::radians;
        auto phi = factory::measurement_operator_factory<t_complexVec>(
            factory::distributed_measurement_operator::mpi_distribute_image, uv_data, imsizey,
            imsizex, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);

        return std::make_shared<sopt::IterationState<Vector<t_complex>>>(uv_data.vis, phi);
      };

  // Estimate the operator norm from one random draw of the measurement operator;
  // the starting vector is broadcast so every rank begins from the same input.
  auto IS = random_updater();
  auto Phi = IS->Phi();
  auto const power_method_stuff = sopt::algorithm::power_method<Vector<t_complex>>(
      Phi, 1000, 1e-5, comm.broadcast(Vector<t_complex>::Ones(imsizex * imsizey).eval()));
  const t_real op_norm = std::get<0>(power_method_stuff);

  const auto solution = pfitsio::read2d(expected_solution_path);

  // wavelets
  std::vector<std::tuple<std::string, t_uint>> const sara{
      std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
      std::make_tuple("DB3", 3u),   std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
      std::make_tuple("DB6", 3u),   std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
  // NOTE(review): the wavelet operator is built as `serial` even though this is
  // the MPI test — confirm this is intended.
  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
      factory::distributed_wavelet_operator::serial, sara, imsizey, imsizex);

  // algorithm
  t_real const sigma = 0.016820222945913496 * std::sqrt(2);  // see test_parameters file
  t_real const beta = sigma * sigma;
  t_real const gamma = 0.0001;

  sopt::algorithm::ImagingForwardBackward<t_complex> fb(random_updater);
  fb.itermax(1000)
      .step_size(beta * std::sqrt(2))
      .sigma(sigma * std::sqrt(2))
      .regulariser_strength(gamma)
      .relative_variation(1e-3)
      .residual_tolerance(0)
      .tight_frame(true)
      .sq_op_norm(op_norm * op_norm)
      .obj_comm(comm);

  auto gp = std::make_shared<sopt::algorithm::L1GProximal<t_complex>>(false);
  gp->l1_proximal_tolerance(1e-4)
      .l1_proximal_nu(1)
      .l1_proximal_itermax(50)
      .l1_proximal_positivity_constraint(true)
      .l1_proximal_real_constraint(true)
      .Psi(*wavelets);
  fb.g_function(gp);

  auto const diagnostic = fb();

  // Compare against the reference image: the mean squared error of the real
  // part must not exceed 1e-3 times the reference's average intensity.
  auto soln_flat = Vector<t_complex>::Map(solution.data(), solution.size());
  double average_intensity = soln_flat.real().sum() / soln_flat.size();
  SOPT_HIGH_LOG("Average intensity = {}", average_intensity);
  double mse = (soln_flat - diagnostic.x).real().squaredNorm() / solution.size();
  SOPT_HIGH_LOG("MSE = {}", mse);
  CHECK(mse <= average_intensity * 1e-3);
}
#endif
5 changes: 0 additions & 5 deletions cpp/uncertainty_quantification/uq_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
#include <sopt/l1_non_diff_function.h>
#include <sopt/l2_differentiable_func.h>
#include <sopt/real_indicator.h>
#include <sopt/tf_non_diff_function.h>

#ifdef PURIFY_ONNXRT
#include <sopt/onnx_differentiable_func.h>
#endif

using VectorC = sopt::Vector<std::complex<double>>;

Expand Down

0 comments on commit 7bbd786

Please sign in to comment.