Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ts plots #48

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions atom/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pandas as pd
from beartype import beartype
from joblib.memory import Memory
from pandas._typing import DtypeObj

Check notice on line 26 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _typing of a class
from scipy import stats
from sklearn.pipeline import Pipeline as SkPipeline
from sklearn.utils.metaestimators import available_if
Expand Down Expand Up @@ -56,9 +56,9 @@
FloatZeroToOneInc, Index, IndexSelector, Int, IntLargerEqualZero,
IntLargerTwo, IntLargerZero, MetricConstructor, ModelsConstructor, NItems,
NJobs, NormalizerStrats, NumericalStrats, Operators, Pandas, Predictor,
PrunerStrats, RowSelector, Scalar, ScalerStrats, Seasonality, Sequence,
Series, TargetSelector, Transformer, VectorizerStarts, Verbose, Warnings,
XSelector, YSelector, sequence_t,
PrunerStrats, RowSelector, Scalar, ScalerStrats, Seasonality,
SeasonalityMode, Sequence, Series, TargetSelector, Transformer,
VectorizerStarts, Verbose, Warnings, XSelector, YSelector, sequence_t,
)
from atom.utils.utils import (
ClassMap, DataConfig, DataContainer, Goal, adjust_verbosity, bk,
Expand Down Expand Up @@ -288,7 +288,7 @@
@ignore.setter
def ignore(self, value: ColumnSelector | None):
if value is not None:
self._config.ignore = tuple(self.branch._get_columns(value, include_target=False))

Check notice on line 291 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class
else:
self._config.ignore = ()

Expand Down Expand Up @@ -481,7 +481,7 @@
- **p_value:** Corresponding p-value.

"""
columns_c = self.branch._get_columns(columns, only_numerical=True)

Check notice on line 484 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class

df = pd.DataFrame(
index=pd.MultiIndex.from_product(
Expand All @@ -493,7 +493,7 @@

for col in columns_c:
# Drop missing values from the column before testing
X = replace_missing(self[col], self.missing).dropna().to_numpy(dtype=float)

Check notice on line 496 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

for test in ("adf", "kpss", "lb"):
if test == "adf":
Expand All @@ -505,7 +505,7 @@
stat = l_jung.loc[l_jung["lb_pvalue"].idxmin()]

# Add as column to the dataframe
df.loc[(test, "score"), col] = round(stat[0], 4)

Check warning on line 508 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'stat' might be referenced before assignment
df.loc[(test, "p_value"), col] = round(stat[1], 4)

return df
Expand Down Expand Up @@ -567,7 +567,7 @@
else:
distributions_c = lst(distributions)

columns_c = self.branch._get_columns(columns, only_numerical=True)

Check notice on line 570 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class

df = pd.DataFrame(
index=pd.MultiIndex.from_product(
Expand All @@ -579,7 +579,7 @@

for col in columns_c:
# Drop missing values from the column before testing
X = replace_missing(self[col], self.missing).dropna().to_numpy(dtype=float)

Check notice on line 582 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

for dist in distributions_c:
# Get KS-statistic with fitted distribution parameters
Expand Down Expand Up @@ -634,22 +634,22 @@
self._log("Creating EDA report...", 1)

if isinstance(rows, str):
rows_c = [(self.branch._get_rows(rows), rows)]

Check notice on line 637 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class
elif isinstance(rows, sequence_t):
rows_c = [(self.branch._get_rows(r), r) for r in rows]

Check notice on line 639 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class
elif isinstance(rows, dict):
rows_c = [(self.branch._get_rows(v), k) for k, v in rows.items()]

Check notice on line 641 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class

if len(rows_c) == 1:

Check warning on line 643 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'rows_c' might be referenced before assignment
self.report = sv.analyze(
source=rows_c[0],
target_feat=self.branch._get_target(target, only_columns=True),

Check notice on line 646 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_target of a class
)
elif len(rows_c) == 2:
self.report = sv.compare(

Check notice on line 649 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute report defined outside __init__
source=rows_c[0],
compare=rows_c[1],
target_feat=self.branch._get_target(target, only_columns=True),

Check notice on line 652 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_target of a class
)
else:
raise ValueError(
Expand All @@ -661,12 +661,12 @@
if (path := Path(filename)).suffix != ".html":
path = path.with_suffix(".html")

self.report.show_notebook(filepath=path if filename else None)

Check warning on line 664 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'path' might be referenced before assignment

@composed(crash, method_to_log)
def inverse_transform(
self,
X: XSelector | None = None,

Check notice on line 669 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = None,
*,
verbose: Verbose | None = None,
Expand Down Expand Up @@ -709,9 +709,9 @@
Original target column. Only returned if provided.

"""
X, y = self._check_input(X, y, columns=self.branch.features, name=self.branch.target)

Check notice on line 712 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

with adjust_verbosity(self.pipeline, verbose) as pipeline:

Check warning on line 714 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Incorrect call arguments

Unexpected argument
return pipeline.inverse_transform(X, y)

@classmethod
Expand Down Expand Up @@ -786,14 +786,14 @@

if data is not None:
# Prepare the provided data
container, holdout = atom._get_data(data)

Check notice on line 789 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_data of a class

# Assign the data to the original branch
if atom._branches._og is not None:

Check notice on line 792 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _og of a class

Check notice on line 792 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _branches of a class
atom._branches._og._container = container

Check notice on line 793 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _og of a class

Check notice on line 793 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _branches of a class

# Apply transformations per branch
for branch in atom._branches:

Check notice on line 796 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _branches of a class
if branch._container is None:
branch._container = deepcopy(container)
branch._holdout = holdout
Expand All @@ -803,22 +803,22 @@
f"already contains data in branch {branch.name}."
)

if len(atom._branches) > 2 and branch.pipeline:

Check notice on line 806 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _branches of a class
atom._log(f"Transforming data for branch {branch.name}:", 1)

Check notice on line 807 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _log of a class

X_train, y_train = branch.pipeline.transform(

Check notice on line 809 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
X=branch.X_train,
y=branch.y_train,
filter_train_only=False,
)
X_test, y_test = branch.pipeline.transform(branch.X_test, branch.y_test)

Check notice on line 814 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Update complete dataset
branch._container.data = bk.concat(
[merge(X_train, y_train), merge(X_test, y_test)]
)

if atom._config.index is False:

Check notice on line 821 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _config of a class
branch._container = DataContainer(
data=(dataset := branch._container.data.reset_index(drop=True)),
train_idx=dataset.index[:len(branch._container.train_idx)],
Expand All @@ -830,7 +830,7 @@
if branch is not atom.branch:
branch.store()

atom._log(f"{atom.__class__.__name__} successfully loaded.", 1)

Check notice on line 833 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _log of a class

return atom

Expand Down Expand Up @@ -883,7 +883,7 @@
else:
path = path.with_name(f"{self.__class__.__name__}.csv")

self.branch._get_rows(rows).to_csv(path, **kwargs)

Check notice on line 886 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_rows of a class
self._log("Data set successfully saved.", 1)

@composed(crash, method_to_log)
Expand Down Expand Up @@ -922,7 +922,7 @@
Whether to convert all features to sparse format. The value
that is compressed is the most frequent value in the column.

columns: int, str, segment, sequence or None, default=None
columns: int, str, segment, sequence, dataframe or None, default=None
[Selection of columns][row-and-column-selection] to shrink. If
None, transform all columns.

Expand Down Expand Up @@ -972,7 +972,7 @@
"float": [(x.name, np.finfo(x.type).min, np.finfo(x.type).max) for x in t3],
}

data = self.branch.dataset[self.branch._get_columns(columns)]

Check notice on line 975 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class

# Convert back since convert_dtypes doesn't work properly for pyarrow dtypes
data = data.astype({n: to_pyarrow(c, inverse=True) for n, c in data.items()})
Expand All @@ -986,20 +986,20 @@

if old_t.name.startswith("string"):
if str2cat and column.nunique() <= int(len(column) * 0.3):
self.branch._data.data[name] = get_data(pd.CategoricalDtype())

Check notice on line 989 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
continue

try:
# Get the types to look at
t = next(v for k, v in types.items() if old_t.name.lower().startswith(k))
except StopIteration:
self.branch._data.data[name] = get_data(column.dtype)

Check notice on line 996 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
continue

# Use bool if values are in (0, 1)
if int2bool and (t == types["int"] or t == types["uint"]):
if column.isin([0, 1]).all() or column.isin([-1, 1]).all():
self.branch._data.data[name] = get_data(pd.BooleanDtype())

Check notice on line 1002 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
continue

# Use uint if values are strictly positive
Expand All @@ -1007,7 +1007,7 @@
t = types["uint"]

# Find the smallest type that fits
self.branch._data.data[name] = next(

Check notice on line 1010 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
get_data(r[0]) for r in t if r[1] <= column.min() and r[2] >= column.max()
)

Expand Down Expand Up @@ -1098,7 +1098,7 @@
@composed(crash, method_to_log)
def transform(
self,
X: XSelector | None = None,

Check notice on line 1101 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = None,
*,
verbose: Verbose | None = None,
Expand Down Expand Up @@ -1141,9 +1141,9 @@
Transformed target column. Only returned if provided.

"""
X, y = self._check_input(X, y, columns=self.og.features, name=self.og.target)

Check notice on line 1144 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

with adjust_verbosity(self.pipeline, verbose) as pipeline:

Check warning on line 1146 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Incorrect call arguments

Unexpected argument
return pipeline.transform(X, y)

# Base transformers ============================================ >>
Expand Down Expand Up @@ -1201,7 +1201,7 @@
has the `n_jobs` and/or `random_state` parameters, it
adopts atom's values.

columns: int, str, segment, sequence or None, default=None
columns: int, str, segment, sequence, dataframe or None, default=None
Columns in the dataset to transform. If None, transform
all features.

Expand Down Expand Up @@ -1234,7 +1234,7 @@
transformer_c._train_only = train_only

if columns is not None:
cols = self.branch._get_columns(columns)

Check notice on line 1237 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class
else:
cols = list(self.branch.features)

Expand Down Expand Up @@ -1271,7 +1271,7 @@
# Check if the fitted estimator is retrieved from cache to inform
# the user, else user might notice the lack of printed messages
if self.memory.location is not None:
if fit._is_in_cache_and_valid([*fit._get_output_identifiers(**kwargs)]):

Check notice on line 1274 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _is_in_cache_and_valid of a class

Check notice on line 1274 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_output_identifiers of a class
self._log(
f"Retrieving cached results for {transformer_c.__class__.__name__}...", 1
)
Expand All @@ -1283,33 +1283,33 @@
self._branches.add("og")

if transformer_c._train_only:
X, y = self.pipeline._mem_transform(transformer_c, self.X_train, self.y_train)

Check notice on line 1286 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Check notice on line 1286 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _mem_transform of a class
self.train = merge(
self.X_train if X is None else X,
self.y_train if y is None else y,
)
else:
X, y = self.pipeline._mem_transform(transformer_c, self.X, self.y)

Check notice on line 1292 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

Check notice on line 1292 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _mem_transform of a class
data = merge(self.X if X is None else X, self.y if y is None else y)

# y can change the number of columns or remove rows -> reassign index
self.branch._container = DataContainer(
data=data,
train_idx=self.branch._data.train_idx.intersection(data.index),

Check notice on line 1298 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
test_idx=self.branch._data.test_idx.intersection(data.index),

Check notice on line 1299 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
n_cols=self.branch._data.n_cols if y is None else len(get_cols(y)),

Check notice on line 1300 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
)

if self._config.index is False:
self.branch._container = DataContainer(
data=(data := self.dataset.reset_index(drop=True)),
train_idx=data.index[: len(self.branch._data.train_idx)],

Check notice on line 1306 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
test_idx=data.index[-len(self.branch._data.test_idx):],

Check notice on line 1307 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
n_cols=self.branch._data.n_cols,

Check notice on line 1308 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _data of a class
)
if self.branch._holdout is not None:

Check notice on line 1310 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _holdout of a class
self.branch._holdout.index = range(

Check notice on line 1311 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _holdout of a class
len(data), len(data) + len(self.branch._holdout)

Check notice on line 1312 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _holdout of a class
)
elif self.dataset.index.duplicated().any():
raise ValueError(
Expand Down Expand Up @@ -1388,7 +1388,7 @@
instance), and it has the `n_jobs` and/or `random_state`
parameters, it adopts atom's values.

columns: int, str, segment, sequence or None, default=None
columns: int, str, segment, sequence, dataframe or None, default=None
[Selection of columns][row-and-column-selection] to
transform. Only select features or the target column, not
both at the same time (if that happens, the target column
Expand Down Expand Up @@ -1463,7 +1463,7 @@
Additional keyword arguments for the inverse function.

"""
FunctionTransformer = self._get_est_class("FunctionTransformer", "preprocessing")

Check notice on line 1466 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

columns = kwargs.pop("columns", None)
transformer = FunctionTransformer(
Expand Down Expand Up @@ -1557,14 +1557,14 @@
cleaner.missing_ = self.missing

cleaner = self._add_transformer(cleaner, columns=columns)
self.branch._mapping.update(cleaner.mapping_)

Check notice on line 1560 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _mapping of a class

@composed(crash, method_to_log)
def decompose(
self,
*,
model: str | Predictor | None = None,
mode: Literal["additive", "multiplicative"] = "additive",
mode: SeasonalityMode = "additive",
**kwargs,
):
"""Detrend and deseasonalize the time series.
Expand All @@ -1584,9 +1584,7 @@
* Use the `columns` parameter to only decompose the target
column, e.g., `atom.decompose(columns=atom.target)`.
* Use the [plot_decomposition][] method to visualize the
trend, seasonality and residuals of the time series. This
can help to determine if the data follows an additive or
multiplicative trend.
trend, seasonality and residuals of the time series.

"""
columns = kwargs.pop("columns", None)
Expand Down Expand Up @@ -1679,7 +1677,7 @@
)

encoder = self._add_transformer(encoder, columns=columns)
self.branch._mapping.update(encoder.mapping_)

Check notice on line 1680 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _mapping of a class

@composed(crash, method_to_log)
def impute(
Expand Down Expand Up @@ -2053,7 +2051,7 @@
columns = kwargs.pop("columns", None)
feature_grouper = FeatureGrouper(
groups={
name: self.branch._get_columns(fxs, include_target=False)

Check notice on line 2054 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_columns of a class
for name, fxs in groups.items()
},
operators=operators,
Expand Down Expand Up @@ -2182,7 +2180,7 @@
trainer.run()

# Overwrite models with the same name as new ones
for model in trainer._models:

Check notice on line 2183 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _models of a class
if model.name in self._models:
self._delete_models(model.name)
self._log(
Expand All @@ -2190,7 +2188,7 @@
"The former model has been overwritten.", 3,
)

self._models.extend(trainer._models)

Check notice on line 2191 in atom/atom.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _models of a class
self._metric = trainer._metric

@composed(crash, method_to_log)
Expand Down
7 changes: 4 additions & 3 deletions atom/data_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TomekLinks,
)
from scipy.stats import zscore
from sklearn.base import BaseEstimator, _clone_parametrized

Check notice on line 37 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _clone_parametrized of a class
from sklearn.compose import ColumnTransformer
from sklearn.experimental import enable_iterative_imputer # noqa: F401
from sklearn.impute import IterativeImputer, KNNImputer
Expand All @@ -47,8 +47,9 @@
Bins, Bool, CategoricalStrats, DataFrame, DiscretizerStrats, Engine,
Estimator, FloatLargerZero, IntLargerEqualZero, IntLargerTwo,
IntLargerZero, NJobs, NormalizerStrats, NumericalStrats, Pandas, Predictor,
PrunerStrats, Scalar, ScalerStrats, Sequence, Series, Transformer, Verbose,
XSelector, YSelector, dataframe_t, sequence_t, series_t,
PrunerStrats, Scalar, ScalerStrats, SeasonalityMode, Sequence, Series,
Transformer, Verbose, XSelector, YSelector, dataframe_t, sequence_t,
series_t,
)
from atom.utils.utils import (
Goal, bk, composed, crash, get_col_order, get_cols, it, lst, merge,
Expand Down Expand Up @@ -90,7 +91,7 @@

def fit(
self,
X: DataFrame | None = None,

Check notice on line 94 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas | None = None,
**fit_params,
) -> Self:
Expand Down Expand Up @@ -133,7 +134,7 @@
@composed(crash, method_to_log)
def fit_transform(
self,
X: XSelector | None = None,

Check notice on line 137 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: YSelector | None = None,
**fit_params,
) -> Pandas | tuple[DataFrame, Pandas]:
Expand Down Expand Up @@ -175,7 +176,7 @@
@composed(crash, method_to_log)
def inverse_transform(
self,
X: DataFrame | None = None,

Check notice on line 179 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas | None = None,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Do nothing.
Expand Down Expand Up @@ -346,7 +347,7 @@
self.kwargs = kwargs

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas = -1) -> Self:

Check notice on line 350 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand All @@ -373,7 +374,7 @@

"""
if isinstance(y, series_t):
self.target_names_in_ = np.array([y.name])

Check notice on line 377 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute target_names_in_ defined outside __init__
else:
raise ValueError("The Balancer class does not support multioutput tasks.")

Expand Down Expand Up @@ -421,13 +422,13 @@

# Create dict of class counts in y
if not hasattr(self, "mapping_"):
self.mapping_ = {str(v): v for v in y.sort_values().unique()}

Check notice on line 425 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute mapping_ defined outside __init__

self._counts = {}

Check notice on line 427 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _counts defined outside __init__
for key, value in self.mapping_.items():
self._counts[key] = np.sum(y == value)

self._estimator = estimator.fit(X, y)

Check notice on line 431 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimator defined outside __init__

# Add the estimator as attribute to the instance
setattr(self, f"{estimator.__class__.__name__.lower()}_", self._estimator)
Expand All @@ -435,7 +436,7 @@
return self

@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas = -1) -> tuple[DataFrame, Series]:

Check notice on line 439 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Balance the data.

Parameters
Expand All @@ -460,10 +461,10 @@

"""

def log_changes(y):

Check notice on line 464 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'y' from outer scope
"""Print the changes per target class."""
for key, value in self.mapping_.items():

Check notice on line 466 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'value' from outer scope

Check notice on line 466 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'key' from outer scope
diff = self._counts[key] - np.sum(y == value)

Check notice on line 467 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'diff' from outer scope
if diff > 0:
self._log(f" --> Removing {diff} samples from class {key}.", 2)
elif diff < 0:
Expand All @@ -473,7 +474,7 @@
self._log(f"Oversampling with {self._estimator.__class__.__name__}...", 1)

index = X.index # Save indices for later reassignment
X, y = self._estimator.fit_resample(X, y)

Check notice on line 477 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Create indices for the new samples
n_idx: list[int | str]
Expand All @@ -498,7 +499,7 @@

# Select chosen rows (imblearn doesn't return them in order)
samples = sorted(self._estimator.sample_indices_)
X, y = X.iloc[samples], y.iloc[samples] # type: ignore[call-overload]

Check notice on line 502 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

log_changes(y)

Expand All @@ -506,7 +507,7 @@
self._log(f"Balancing with {self._estimator.__class__.__name__}...", 1)

index = X.index
X_new, y_new = self._estimator.fit_resample(X, y)

Check notice on line 510 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Select rows kept by the undersampler
if self._estimator.__class__.__name__ == "SMOTEENN":
Expand All @@ -515,8 +516,8 @@
samples = sorted(self._estimator.tomek_.sample_indices_)

# Select the remaining samples from the old dataframe
o_samples = [s for s in samples if s < len(X)]

Check warning on line 519 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'samples' might be referenced before assignment
X, y = X.iloc[o_samples], y.iloc[o_samples] # type: ignore[call-overload]

Check notice on line 520 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Create indices for the new samples
if index.dtype.kind in "ifu":
Expand All @@ -528,7 +529,7 @@
]

# Select the new samples and assign the new indices
X_new = X_new.iloc[-len(X_new) + len(o_samples):]

Check notice on line 532 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
X_new.index = n_idx
y_new = y_new.iloc[-len(y_new) + len(o_samples):]
y_new.index = n_idx
Expand All @@ -544,7 +545,7 @@
self._log(f" --> Removing {diff} samples from class: {key}.", 2)

# Add the new samples to the old dataframe
X, y = bk.concat([X, X_new]), bk.concat([y, y_new])

Check notice on line 548 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return X, y

Expand Down Expand Up @@ -720,7 +721,7 @@
self.encode_target = encode_target

@composed(crash, method_to_log)
def fit(self, X: DataFrame | None = None, y: Pandas | None = None) -> Self:

Check notice on line 724 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand All @@ -747,11 +748,11 @@
Estimator instance.

"""
self.mapping_: dict[str, Any] = {}

Check notice on line 751 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute mapping_ defined outside __init__
self._estimators = {}

Check notice on line 752 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimators defined outside __init__

if not hasattr(self, "missing_"):
self.missing_ = DEFAULT_MISSING

Check notice on line 755 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute missing_ defined outside __init__

self._log("Fitting Cleaner...", 1)

Expand All @@ -759,7 +760,7 @@
if isinstance(y, series_t):
self.target_names_in_ = np.array([y.name])
else:
self.target_names_in_ = y.columns.to_numpy()

Check notice on line 763 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute target_names_in_ defined outside __init__

if self.drop_chars:
if isinstance(y, series_t):
Expand All @@ -773,13 +774,13 @@
if self.encode_target:
for col in get_cols(y):
if isinstance(col.iloc[0], sequence_t): # Multilabel
MultiLabelBinarizer = self._get_est_class(

Check notice on line 777 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
name="MultiLabelBinarizer",
module="preprocessing",
)
self._estimators[col.name] = MultiLabelBinarizer().fit(col)
elif list(uq := np.unique(col)) != list(range(col.nunique())):
LabelEncoder = self._get_est_class("LabelEncoder", "preprocessing")

Check notice on line 783 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
self._estimators[col.name] = LabelEncoder().fit(col)
self.mapping_.update({col.name: {str(it(v)): i for i, v in enumerate(uq)}})

Expand All @@ -788,7 +789,7 @@
@composed(crash, method_to_log)
def transform(
self,
X: DataFrame | None = None,

Check notice on line 792 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas | None = None,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Apply the data cleaning steps to the data.
Expand Down Expand Up @@ -824,7 +825,7 @@

if X is not None:
# Unify all missing values
X = replace_missing(X, self.missing_)

Check notice on line 828 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

for name, column in X.items():
dtype = column.dtype.name
Expand All @@ -835,7 +836,7 @@
f" --> Dropping feature {name} for having a prohibited type: {dtype}.",
2,
)
X = X.drop(columns=name)

Check notice on line 839 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
continue

elif dtype in CAT_TYPES:
Expand All @@ -847,14 +848,14 @@

# Drop prohibited chars from column names
if self.drop_chars:
X = X.rename(columns=lambda x: re.sub(self.drop_chars, "", str(x)))

Check notice on line 851 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Drop duplicate samples
if self.drop_duplicates:
X = X.drop_duplicates(ignore_index=True)

Check notice on line 855 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if self.convert_dtypes:
X = X.convert_dtypes()

Check notice on line 858 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if y is not None:
if self.drop_chars:
Expand All @@ -869,7 +870,7 @@
y = replace_missing(y, self.missing_).dropna()

if X is not None:
X = X[X.index.isin(y.index)] # Select only indices that remain

Check notice on line 873 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if (d := length - len(y)) > 0:
self._log(f" --> Dropping {d} rows with missing values in target.", 2)
Expand Down Expand Up @@ -909,7 +910,7 @@
@composed(crash, method_to_log)
def inverse_transform(
self,
X: DataFrame | None = None,

Check notice on line 913 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas | None = None,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Inversely transform the label encoding.
Expand Down Expand Up @@ -962,7 +963,7 @@

# Replace encoded columns with target column
if isinstance(y, series_t):
yt = to_series(out, y.index, col)

Check warning on line 966 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'out' might be referenced before assignment
else:
yt = merge(yt, to_series(out, y.index, col))

Expand Down Expand Up @@ -1083,7 +1084,7 @@
*,
model: str | Predictor | None = None,
sp: IntLargerZero | None = None,
mode: Literal["additive", "multiplicative"] = "additive",
mode: SeasonalityMode = "additive",
n_jobs: NJobs = 1,
verbose: Verbose = 0,
logger: str | Path | Logger | None = None,
Expand All @@ -1100,7 +1101,7 @@
self.mode = mode

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

Check notice on line 1104 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand All @@ -1126,7 +1127,7 @@
**{x: getattr(self, x) for x in BaseTransformer.attrs if hasattr(self, x)},
)
model.task = Goal.forecast.infer_task(y)
forecaster = model._get_est({})

Check notice on line 1130 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _get_est of a class
else:
raise ValueError(
"Invalid value for the model parameter. Unknown "
Expand All @@ -1135,7 +1136,7 @@
[
f" --> {m.__name__} ({m.acronym})"
for m in MODELS
if "forecast" in m._estimators

Check notice on line 1139 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Accessing a protected member of a class or a module

Access to a protected member _estimators of a class
]
)
)
Expand All @@ -1146,7 +1147,7 @@

self._log("Fitting Decomposer...", 1)

self._estimators: dict[Hashable, tuple[Transformer, Transformer]] = {}

Check notice on line 1150 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimators defined outside __init__
for name, column in X.select_dtypes(include="number").items():
trend = Detrender(
forecaster=forecaster,
Expand All @@ -1163,7 +1164,7 @@
return self

@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 1167 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase

Check notice on line 1167 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused local symbols

Parameter 'y' value is not used
"""Decompose the data.

Parameters
Expand All @@ -1188,7 +1189,7 @@
return X

@composed(crash, method_to_log)
def inverse_transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 1192 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Inversely transform the data.

Parameters
Expand Down Expand Up @@ -1389,7 +1390,7 @@
self.labels = labels

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

Check notice on line 1393 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand All @@ -1407,7 +1408,7 @@

"""

def get_labels(col: str, bins: Sequence[Scalar]) -> tuple[str, ...]:

Check notice on line 1411 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'col' from outer scope
"""Get labels for the specified bins.

Parameters
Expand All @@ -1426,7 +1427,7 @@
"""
default = [
f"({np.round(bins[i], 2)}, {np.round(bins[i + 1], 1)}]"
for i in range(len(bins[:-1]))

Check notice on line 1430 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Shadowing names from outer scopes

Shadows name 'i' from outer scope
]

if self.labels is None:
Expand All @@ -1445,12 +1446,12 @@

return labels

Xt, yt = self._check_input(X, y)

Check notice on line 1449 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
self._check_feature_names(Xt, reset=True)
self._check_n_features(Xt, reset=True)

self._estimators: dict[str, Estimator] = {}

Check notice on line 1453 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimators defined outside __init__
self._labels: dict[str, Sequence[str]] = {}

Check notice on line 1454 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _labels defined outside __init__

self._log("Fitting Discretizer...", 1)

Expand All @@ -1467,7 +1468,7 @@
if self.strategy != "custom":
if isinstance(bins_c, sequence_t):
try:
bins_x = bins_c[i] # Fetch the i-th bin for the i-th column

Check warning on line 1471 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Incorrect type

Expected type 'str' (matched generic type '_KT'), got 'int' instead
except IndexError:
raise ValueError(
"Invalid value for the bins parameter. The length of the "
Expand All @@ -1477,7 +1478,7 @@
else:
bins_x = bins_c

KBinsDiscretizer = self._get_est_class("KBinsDiscretizer", "preprocessing")

Check notice on line 1481 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# cuML implementation has no subsample and random_state
kwargs: dict[str, Any] = {}
Expand Down Expand Up @@ -1507,7 +1508,7 @@
else:
bins_c = [-np.inf, *bins_c, np.inf]

FunctionTransformer = self._get_est_class(

Check notice on line 1511 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
name="FunctionTransformer",
module="preprocessing",
)
Expand All @@ -1521,7 +1522,7 @@
return self

@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 1525 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase

Check notice on line 1525 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused local symbols

Parameter 'y' value is not used
"""Bin the data into intervals.

Parameters
Expand Down Expand Up @@ -1716,7 +1717,7 @@
self.kwargs = kwargs

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

Check notice on line 1720 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Note that leaving y=None can lead to errors if the `strategy`
Expand Down Expand Up @@ -1746,9 +1747,9 @@
Estimator instance.

"""
self.mapping_ = {}

Check notice on line 1750 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute mapping_ defined outside __init__
self._to_value = {}

Check notice on line 1751 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _to_value defined outside __init__
self._categories = {}

Check notice on line 1752 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _categories defined outside __init__

strategies = {
"backwarddifference": BackwardDifferenceEncoder,
Expand Down Expand Up @@ -1801,7 +1802,7 @@
# Replace infrequent classes with the string in `value`
if self.infrequent_to_value:
values = column.value_counts()
self._to_value[name] = values[values <= infrequent_to_value].index.tolist()

Check warning on line 1805 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'infrequent_to_value' might be referenced before assignment
X[name] = column.replace(self._to_value[name], self.value)

# Get the unique categories before fitting
Expand Down Expand Up @@ -1846,14 +1847,14 @@
handle_unknown="value",
)

rest_enc = estimator(

Check warning on line 1850 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Attempt to call a non-callable object

'str' object is not callable
cols=encoders["rest"],
handle_missing="return_nan",
handle_unknown="value",
**self.kwargs,
)

self._estimator = ColumnTransformer(

Check notice on line 1857 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimator defined outside __init__
transformers=[
("ordinal", ordinal_enc, encoders["ordinal"]),
("onehot", onehot_enc, encoders["onehot"]),
Expand All @@ -1867,7 +1868,7 @@
return self

@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 1871 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase

Check notice on line 1871 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused local symbols

Parameter 'y' value is not used
"""Encode the data.

Parameters
Expand All @@ -1887,7 +1888,7 @@
self._log("Encoding categorical columns...", 1)

# Convert infrequent classes to value
X = X.replace(self._to_value, self.value)

Check notice on line 1891 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

for name, categories in self._categories.items():
if name in self._estimator.transformers_[0][2]:
Expand All @@ -1911,10 +1912,10 @@
if uc := len(X[name].dropna()[~X[name].isin(categories)]):
self._log(f" --> Handling {uc} unknown classes.", 2)

Xt = self._estimator.transform(X)

Check notice on line 1915 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Drop _nan columns (since missing values are propagated)
Xt = Xt.loc[:, ~Xt.columns.str.endswith("_nan")]

Check notice on line 1918 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return Xt[get_col_order(Xt, X.columns.tolist(), self._estimator.feature_names_in_)]

Expand Down Expand Up @@ -2101,7 +2102,7 @@
self.max_nan_cols = max_nan_cols

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

Check notice on line 2105 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand All @@ -2119,20 +2120,20 @@

"""
if not hasattr(self, "missing_"):
self.missing_ = DEFAULT_MISSING

Check notice on line 2123 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute missing_ defined outside __init__

self._log("Fitting Imputer...", 1)

# Unify all values to impute
X = replace_missing(X, self.missing_)

Check notice on line 2128 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

if self.max_nan_rows is not None:
if self.max_nan_rows <= 1:
self._max_nan_rows = int(X.shape[1] * self.max_nan_rows)
else:
self._max_nan_rows = int(self.max_nan_rows)

Check notice on line 2134 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _max_nan_rows defined outside __init__

X = X.dropna(axis=0, thresh=X.shape[1] - self._max_nan_rows)

Check notice on line 2136 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
if X.empty:
raise ValueError(
"Invalid value for the max_nan_rows parameter, got "
Expand All @@ -2147,10 +2148,10 @@
else:
max_nan_cols = int(self.max_nan_cols)

X = X.drop(columns=X.columns[X.isna().sum() > max_nan_cols])

Check notice on line 2151 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Load the imputer class from sklearn or cuml (note the different modules)
SimpleImputer = self._get_est_class(

Check notice on line 2154 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
name="SimpleImputer",
module="preprocessing" if self.engine.get("estimator") == "cuml" else "impute",
)
Expand Down Expand Up @@ -2185,11 +2186,11 @@
fill_value=self.strat_cat,
)

ColumnTransformer = self._get_est_class("ColumnTransformer", "compose")

Check notice on line 2189 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

self._estimator = ColumnTransformer(

Check notice on line 2191 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimator defined outside __init__
transformers=[
("num_imputer", num_imputer, list(X.select_dtypes(include="number"))),

Check warning on line 2193 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'num_imputer' might be referenced before assignment
("cat_imputer", cat_imputer, list(X.select_dtypes(include=CAT_TYPES))),
],
remainder="passthrough",
Expand All @@ -2202,7 +2203,7 @@
@composed(crash, method_to_log)
def transform(
self,
X: DataFrame,

Check notice on line 2206 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas | None = None,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Impute the missing values.
Expand Down Expand Up @@ -2240,17 +2241,17 @@
num_imputer = self._estimator.named_transformers_["num_imputer"]
cat_imputer = self._estimator.named_transformers_["cat_imputer"]

get_stat = lambda est, n: est.statistics_[est.feature_names_in_.tolist().index(n)]

Check notice on line 2244 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 coding style violation

PEP 8: E731 do not assign a lambda expression, use a def

self._log("Imputing missing values...", 1)

# Unify all values to impute
X = replace_missing(X, self.missing_)

Check notice on line 2249 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Drop rows with too many missing values
if self.max_nan_rows is not None:
length = len(X)
X = X.dropna(axis=0, thresh=X.shape[1] - self._max_nan_rows)

Check notice on line 2254 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
if diff := length - len(X):
self._log(
f" --> Dropping {diff} samples for containing more "
Expand All @@ -2260,7 +2261,7 @@

if self.strat_num == "drop":
length = len(X)
X = X.dropna(subset=self._estimator.transformers_[0][2])

Check notice on line 2264 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
if diff := length - len(X):
self._log(
f" --> Dropping {diff} samples for containing "
Expand All @@ -2270,7 +2271,7 @@

if self.strat_cat == "drop":
length = len(X)
X = X.dropna(subset=self._estimator.transformers_[1][2])

Check notice on line 2274 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
if diff := length - len(X):
self._log(
f" --> Dropping {diff} samples for containing "
Expand All @@ -2288,7 +2289,7 @@
f"({nans * 100 // len(X)}%) missing values.",
2,
)
X = X.drop(columns=name)

Check notice on line 2292 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
continue

if self.strat_num != "drop" and name in num_imputer.feature_names_in_:
Expand Down Expand Up @@ -2325,14 +2326,14 @@
2,
)

X = self._estimator.transform(X)

Check notice on line 2329 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

# Make y consistent with X
if y is not None:
y = y[y.index.isin(X.index)]

# Reorder columns to original order
X = X[[col for col in self.feature_names_in_ if col in X.columns]]

Check notice on line 2336 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

return variable_return(X, y)

Expand Down Expand Up @@ -2487,7 +2488,7 @@
self.kwargs = kwargs

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

Check notice on line 2491 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand Down Expand Up @@ -2519,7 +2520,7 @@
elif self.strategy == "quantile":
kwargs = self.kwargs.copy()
estimator = self._get_est_class(strategies[self.strategy], "preprocessing")
self._estimator = estimator(

Check notice on line 2523 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimator defined outside __init__
output_distribution=kwargs.pop("output_distribution", "normal"),
random_state=kwargs.pop("random_state", self.random_state),
**kwargs,
Expand Down Expand Up @@ -2547,7 +2548,7 @@
return self

@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 2551 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase

Check notice on line 2551 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused local symbols

Parameter 'y' value is not used
"""Apply the transformations to the data.

Parameters
Expand All @@ -2565,14 +2566,14 @@

"""
self._log("Normalizing features...", 1)
Xt = self._estimator.transform(X[self._estimator.feature_names_in_])

Check notice on line 2569 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

X.update(Xt)

return X[self.feature_names_in_]

@composed(crash, method_to_log)
def inverse_transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 2576 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Apply the inverse transformation to the data.

Parameters
Expand All @@ -2590,8 +2591,8 @@

"""
self._log("Inversely normalizing features...", 1)
Xt = self._estimator.inverse_transform(X[self._estimator.feature_names_in_])

Check notice on line 2594 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
Xt = to_df(Xt, index=X.index, columns=self._estimator.feature_names_in_)

Check notice on line 2595 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

X.update(Xt)

Expand Down Expand Up @@ -2769,7 +2770,7 @@
@composed(crash, method_to_log)
def transform(
self,
X: DataFrame,

Check notice on line 2773 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
y: Pandas | None = None,
) -> Pandas | tuple[DataFrame, Pandas]:
"""Apply the outlier strategy on the data.
Expand Down Expand Up @@ -2908,7 +2909,7 @@
self._log(f" --> Dropping {len(mask) - sum(mask)} outliers.", 2)

# Keep only the non-outliers from the data
X = X[mask]

Check notice on line 2912 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
if y is not None:
y = y[mask]

Expand Down Expand Up @@ -3054,7 +3055,7 @@
self.kwargs = kwargs

@composed(crash, method_to_log)
def fit(self, X: DataFrame, y: Pandas | None = None) -> Self:

Check notice on line 3058 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Fit to data.

Parameters
Expand Down Expand Up @@ -3091,7 +3092,7 @@
)

estimator = self._get_est_class(strategies[self.strategy], "preprocessing")
self._estimator = estimator(**self.kwargs)

Check notice on line 3095 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

An instance attribute is defined outside `__init__`

Instance attribute _estimator defined outside __init__

self._log("Fitting Scaler...", 1)
self._estimator.fit(num_cols)
Expand All @@ -3102,7 +3103,7 @@
return self

@composed(crash, method_to_log)
def transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 3106 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase

Check notice on line 3106 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused local symbols

Parameter 'y' value is not used
"""Perform standardization by centering and scaling.

Parameters
Expand All @@ -3120,14 +3121,14 @@

"""
self._log("Scaling features...", 1)
Xt = self._estimator.transform(X[self._estimator.feature_names_in_])

Check notice on line 3124 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

X.update(Xt)

return X

@composed(crash, method_to_log)
def inverse_transform(self, X: DataFrame, y: Pandas | None = None) -> DataFrame:

Check notice on line 3131 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Argument name should be lowercase
"""Apply the inverse transformation to the data.

Parameters
Expand All @@ -3145,8 +3146,8 @@

"""
self._log("Inversely scaling features...", 1)
Xt = self._estimator.inverse_transform(X[self._estimator.feature_names_in_])

Check notice on line 3149 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
Xt = to_df(Xt, index=X.index, columns=self._estimator.feature_names_in_)

Check notice on line 3150 in atom/data_cleaning.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase

X.update(Xt)

Expand Down
2 changes: 1 addition & 1 deletion atom/plots/basefigure.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def get_elem(
else:
return self.style[element].setdefault(name, next(getattr(self, element)))

def showlegend(self, name: str, legend: Legend | dict | None) -> bool:
def showlegend(self, name: str, legend: Legend | dict[str, Any] | None) -> bool:
"""Get whether the trace should be showed in the legend.

If there's already a trace with the same name, it's not
Expand Down
18 changes: 7 additions & 11 deletions atom/plots/baseplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@
elif isinstance(rows, dict):
rows_c = rows

yield from rows_c.items()

Check warning on line 226 in atom/plots/baseplot.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unbound local variables

Local variable 'rows_c' might be referenced before assignment

def _get_metric(self, metric: MetricSelector, *, max_one: Bool = False) -> list[str]:
"""Check and return the provided metric index.
Expand Down Expand Up @@ -392,8 +392,8 @@
child: str | None = None,
legend: Legend | dict[str, Any] | None = None,
**kwargs,
) -> go.Scatter:
"""Draw a line.
):
"""Draw a line on the current figure.

Unify the style to draw a line, where parent and child
(e.g., model - data set or column - distribution) keep the
Expand All @@ -408,19 +408,16 @@
child: str or None, default=None
Name of the secondary attribute.

legend: str, dict or None
legend: str, dict or None, default=None
Legend argument provided by the user.

**kwargs
Additional keyword arguments for the trace.

Returns
-------
go.Scatter
New trace to add to figure.

"""
return go.Scatter(
BasePlot._fig.figure.add_scatter(
name=kwargs.pop("name", child or parent),
mode=kwargs.pop("mode", "lines"),
line=kwargs.pop(
"line", {
"width": self.line_width,
Expand All @@ -440,15 +437,14 @@
"hovertemplate",
f"(%{{x}}, %{{y}})<extra>{parent}{f' - {child}' if child else ''}</extra>",
),
name=kwargs.pop("name", child or parent),
legendgroup=kwargs.pop("legendgroup", parent),
legendgrouptitle=kwargs.pop(
"legendgrouptitle",
{"text": parent, "font_size": self.label_fontsize} if child else None,
),
showlegend=kwargs.pop(
"showlegend",
BasePlot._fig.showlegend(f"{parent}-{child}", legend)
BasePlot._fig.showlegend(f"{parent}-{child}" if child else parent, legend)
),
**kwargs,
)
Expand Down
Loading
Loading