Skip to content

Commit

Permalink
full coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
gykovacs committed Dec 18, 2023
1 parent 8a8350f commit 767216e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 10 deletions.
8 changes: 4 additions & 4 deletions smote_variants/evaluation/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ def _pivot_best_scores(pdf):
for score in all_scores:
tmp = pdf[pdf[score + "_mean"] == pdf[score + "_mean"].max()]

results_dict_classifier[f"{score}_classifier_params"] = [None]
results_dict_classifier[f"{score}_oversampler_params"] = [None]
score_means[f"{score}_mean"] = [None]
score_stds[f"{score}_std"] = [None]
results_dict_classifier[f"{score}_classifier_params"] = pd.Series([None], dtype='str')
results_dict_classifier[f"{score}_oversampler_params"] = pd.Series([None], dtype='str')
score_means[f"{score}_mean"] = pd.Series([None], dtype='float')
score_stds[f"{score}_std"] = pd.Series([None], dtype='float')

if len(tmp) > 0:
results_dict_classifier[f"{score}_classifier_params"] = [
Expand Down
2 changes: 2 additions & 0 deletions smote_variants/evaluation/_parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
"ThreadTimeoutProcessPool",
"wait_for_lock",
"queue_get_default",
"FunctionWrapperJob",
"execute_job_object"
]


Expand Down
10 changes: 5 additions & 5 deletions smote_variants/oversampling/_smotewb.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,11 @@ def sampling_algorithm(
for j in range(C[i]):
synt_sample_list.append(X_min[i].copy())
elif fl_arr[i] == SMOTEWB.sample_good and C[i] > 0:
if noise_mask_min[i]:
nn = indices[i, 0 : k_arr[i]]
else:
# removing self index
nn = indices[i, 1 : k_arr[i] + 1]
nn = (
indices[i, 0 : k_arr[i]]
if noise_mask_min[i]
else indices[i, 1 : k_arr[i] + 1]
)

if len(nn) > 0:
k_ids = self.random_state.choice(nn, C[i])
Expand Down
34 changes: 34 additions & 0 deletions tests/evaluation/test_parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from smote_variants.evaluation import (
TimeoutJobBase,
ThreadTimeoutProcessPool,
FunctionWrapperJob,
wait_for_lock,
queue_get_default,
execute_job_object
)

sleeps = [1, 2, 6, 7]
Expand Down Expand Up @@ -66,6 +68,38 @@ def sleep_job(sleep):
return {"slept": sleep}


def test_sleeping():
"""
Testing the sleeping
"""
result = SleepJob(1).execute()
assert result['slept'] == 1

result = sleep_job(1)
assert result['slept'] == 1


def test_function_wrapper_job():
"""
Testing the function wrapper job
"""
fwj = FunctionWrapperJob(sleep_job, 1)
assert fwj.execute()['slept'] == 1


def test_execute_job_object():
"""
Testing the job object execution
"""
fwj = FunctionWrapperJob(sleep_job, 1)

queue = multiprocessing.Queue()
queue_lock = multiprocessing.Lock()

execute_job_object(fwj, queue, queue_lock)

assert True

def test_jobs_objects_timeout():
"""
Testing the job objects with timeout.
Expand Down
3 changes: 2 additions & 1 deletion tests/evaluation/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def test_sampling():
for fold in folding.fold():
sjob = SamplingJob(fold, oversampler, oversampler_params, cache_path=cache_path)

result = sjob.do_oversampling()
# note that here the execute is used
result = sjob.execute()
assert os.path.exists(result)

assert isinstance(sjob.timeout(), str)
Expand Down

0 comments on commit 767216e

Please sign in to comment.