Skip to content

Commit

Permalink
test: add more assertions and test these
Browse files Browse the repository at this point in the history
  • Loading branch information
jolars committed Feb 4, 2025
1 parent dc4da90 commit 815e988
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
6 changes: 6 additions & 0 deletions src/slope/regularization_sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ lambdaSequence(const int p,
}
}
} else if (type == "oscar") {
if (theta1 <= 0) {
throw std::invalid_argument("theta1 must be non-negative");
}
if (theta2 <= 0) {
throw std::invalid_argument("theta2 must be non-negative");
}
lambda = theta1 + theta2 * (p - Eigen::ArrayXd::LinSpaced(p, 1, p));
} else if (type == "lasso") {
lambda.setOnes();
Expand Down
31 changes: 21 additions & 10 deletions tests/lambda_sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
TEST_CASE("Test that regularization sequence generation works",
"[regularization sequence]")
{
double tol = 1e-6;
int n = 10;
int p = 4;
double q = 0.3;

SECTION("BH sequence")
{
Eigen::ArrayXd lambda1 = slope::lambdaSequence(4, 0.1, "bh");
Eigen::ArrayXd lambda1 = slope::lambdaSequence(p, 0.1, "bh");

std::vector<double> lambda1_expected = {
2.24140272760495, 1.95996398454005, 1.78046434169203, 1.64485362695147
Expand All @@ -29,10 +34,6 @@ TEST_CASE("Test that regularization sequence generation works",

SECTION("Gaussian sequence")
{
int n = 10;
int p = 4;
double q = 0.3;
double tol = 1e-6;

Eigen::ArrayXd l1 = slope::lambdaSequence(p, q, "gaussian", n);
std::vector<double> l1_ref = { 1.780464, 1.700998, 1.657533, 1.628381 };
Expand All @@ -50,11 +51,6 @@ TEST_CASE("Test that regularization sequence generation works",

SECTION("OSCAR sequence")
{
int n = 10;
int p = 4;
double q = 0.3;
double tol = 1e-6;

Eigen::ArrayXd l3 = slope::lambdaSequence(p, q, "oscar", n, 2.0, 0.1);
std::vector<double> l3_ref = {
2.3,
Expand All @@ -65,4 +61,19 @@ TEST_CASE("Test that regularization sequence generation works",

REQUIRE_THAT(l3, VectorApproxEqual(l3_ref, tol));
}

SECTION("Lasso sequence")
{
Eigen::ArrayXd l3 = slope::lambdaSequence(p, q, "lasso", n);
REQUIRE_THAT(l3, VectorApproxEqual(std::vector<double>(4, 1.0), tol));
}

SECTION("Assertions")
{
REQUIRE_THROWS(slope::lambdaSequence(p, q, "gaussian", -5));
REQUIRE_THROWS(slope::lambdaSequence(p, 0.0, "bh"));
REQUIRE_THROWS(slope::lambdaSequence(p, 1.0, "bh"));
REQUIRE_THROWS(slope::lambdaSequence(p, 1.0, "oscar", 0, -1.0, 1.0));
REQUIRE_THROWS(slope::lambdaSequence(p, 1.0, "oscar", 0, 1.0, -1.0));
}
}

0 comments on commit 815e988

Please sign in to comment.