Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT, API Change trunk-classification into three separate functions for generating trunk, trunk-mix, trunk-overlap and marron-wand #227

Merged
merged 8 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions sktree/datasets/hyppo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
from scipy.integrate import nquad
from scipy.stats import entropy, multivariate_normal
Expand Down Expand Up @@ -78,7 +80,7 @@ def make_trunk_classification(
rho: int = 0,
band_type: str = "ma",
return_params: bool = False,
mix: float = 0.5,
mix: Optional[float] = None,
seed=None,
):
"""Generate trunk and/or Marron-Wand datasets.
Expand Down Expand Up @@ -128,7 +130,7 @@ def make_trunk_classification(
Whether or not to return the distribution parameters of the classes normal distributions.
mix : float, optional
The probabilities associated with the mixture of Gaussians in the ``trunk-mix`` simulation.
By default 0.5.
By default None. Must be specified if ``simulation = trunk_mix``.
seed : int, optional
Random seed, by default None.

Expand Down Expand Up @@ -162,6 +164,12 @@ def make_trunk_classification(
f"Number of informative dimensions {n_informative} must be less than number "
f"of dimensions, {n_dim}"
)
if mix is not None and simulation != "trunk_mix":
raise ValueError(
f"Mix should not be specified when simulation is not 'trunk_mix'. Simulation is {simulation}."
)
if mix is None and simulation == "trunk_mix":
raise ValueError("Mix must be specified when simulation is 'trunk_mix'.")
rng = np.random.default_rng(seed=seed)

mu_1 = np.array([1 / np.sqrt(i) for i in range(1, n_informative + 1)])
Expand All @@ -177,7 +185,7 @@ def make_trunk_classification(
else:
cov = np.identity(n_informative)

if mix < 0 or mix > 1:
if mix is not None and mix < 0 or mix > 1: # type: ignore
raise ValueError("Mix must be between 0 and 1.")

# speed up computations for large multivariate normal matrix with SVD approximation
Expand Down Expand Up @@ -205,7 +213,7 @@ def make_trunk_classification(
)
)
elif simulation == "trunk_mix":
mixture_idx = rng.choice(2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix])
mixture_idx = rng.choice(2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix]) # type: ignore
norm_params = [[mu_0, cov * (2 / 3) ** 2], [mu_1, cov * (2 / 3) ** 2]]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sampan501 can you explain why we have a (2/3)**2 here and only in trunk-mix?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the variance is 1, trunk-mix does not look bimodal at low dimensions. I set it to (2/3)**2 since that is consistent with the Marron and Wand bimodal simulation.

X_mixture = np.fromiter(
(
Expand Down
13 changes: 13 additions & 0 deletions sktree/datasets/tests/test_hyppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@ def test_make_trunk_classification_invalid_simulation_name():
make_trunk_classification(n_samples=50, rho=0.5, simulation=None)


def test_make_trunk_classification_errors_trunk_mix():
    """Check that ``mix`` and ``simulation`` are validated together.

    ``make_trunk_classification`` is expected to raise ``ValueError`` when
    ``mix`` is passed for a simulation other than ``"trunk_mix"``, and when
    ``simulation="trunk_mix"`` is requested without ``mix``.
    """
    # ``match`` is interpreted as a regular expression, so the literal periods
    # are escaped; an unescaped "." would match any character.

    # test with mix but not trunk_mix
    with pytest.raises(
        ValueError,
        match=r"Mix should not be specified when simulation is not 'trunk_mix'\. Simulation is trunk\.",
    ):
        make_trunk_classification(n_samples=2, simulation="trunk", mix=0.5)

    # test without mix but trunk_mix
    with pytest.raises(
        ValueError, match=r"Mix must be specified when simulation is 'trunk_mix'\."
    ):
        make_trunk_classification(n_samples=2, simulation="trunk_mix")


@pytest.mark.parametrize(
"simulation", ["trunk", "trunk_overlap", "trunk_mix", *MARRON_WAND_SIMS.keys()]
)
Expand Down
Loading