Skip to content

Commit

Permalink
rewrite for reproduction purpose only
Browse files Browse the repository at this point in the history
donglihe-hub committed Nov 30, 2023
1 parent 8fa186e commit 173ee1b
Showing 2 changed files with 2 additions and 25 deletions.
9 changes: 0 additions & 9 deletions libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
@@ -186,9 +186,6 @@ def train_thresholding(
Returns:
A model which can be used in predict_values.
"""
if not is_multilabel:
raise ValueError("thresholding method doesn't support binary/multiclass datasets.")

x, options, bias = _prepare_options(x, options)

y = y.tocsc()
@@ -413,9 +410,6 @@ def train_cost_sensitive(
Returns:
A model which can be used in predict_values.
"""
if not is_multilabel:
raise ValueError("cost_sensitive method doesn't support binary/multiclass datasets.")

# Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
x, options, bias = _prepare_options(x, options)

@@ -520,9 +514,6 @@ def train_cost_sensitive_micro(
Returns:
A model which can be used in predict_values.
"""
if not is_multilabel:
raise ValueError("cost_sensitive_micro method doesn't support binary/multiclass datasets.")

# Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
x, options, bias = _prepare_options(x, options)

18 changes: 2 additions & 16 deletions linear_trainer.py
Original file line number Diff line number Diff line change
@@ -11,9 +11,7 @@

def linear_test(config, model, datasets, label_mapping):
metrics = linear.get_metrics(
config.monitor_metrics,
datasets["test"]["y"].shape[1],
multiclass=not model.is_multilabel,
config.monitor_metrics, datasets["test"]["y"].shape[1], multiclass=not model.is_multilabel
)
num_instance = datasets["test"]["x"].shape[0]
k = config.save_k_predictions
@@ -40,19 +38,7 @@ def linear_test(config, model, datasets, label_mapping):

def linear_train(datasets, config):
# detect task type
is_multilabel = config.get("is_multilabel", "auto")
if is_multilabel == "auto":
is_multilabel = not is_multiclass_dataset(datasets["train"], "y")
elif not isinstance(is_multilabel, bool):
raise ValueError(
f'"is_multilabel" is expected to be either "auto", "True", or "False". But got "{is_multilabel}" instead.'
)

task_type = "multilabel" if is_multilabel else "binary/multiclass"
logging.info(
f'is_multilabel is set to "{config.get("is_multilabel", "auto")}". '
f"Model will be trained in {task_type} mode."
)
is_multilabel = not is_multiclass_dataset(datasets["train"], "y")

# train
if config.linear_technique == "tree":

0 comments on commit 173ee1b

Please sign in to comment.