Skip to content

Commit

Permalink
Updates docs
Browse files Browse the repository at this point in the history
  • Loading branch information
djgagne committed Mar 29, 2024
1 parent 96b50ae commit 521cceb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
14 changes: 14 additions & 0 deletions bridgescaler/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytdigest import TDigest
from functools import partial
from scipy.stats import logistic
from warnings import warn

CENTROID_DTYPE = np.dtype([('mean', np.float64), ('weight', np.float64)])

Expand Down Expand Up @@ -359,6 +360,8 @@ def transform_variable(td_obj, xv,
x_transformed[:] = np.maximum(x_transformed, min_val)
if distribution == "normal":
x_transformed[:] = norm.ppf(x_transformed, loc=0, scale=1)
elif distribution == "logistic":
x_transformed[:] = logistic.ppf(x_transformed)
return x_transformed


Expand All @@ -367,6 +370,8 @@ def inv_transform_variable(td_obj, xv,
x_transformed = np.zeros(xv.shape, dtype=xv.dtype)
if distribution == "normal":
x_transformed = norm.cdf(xv, loc=0, scale=1)
elif distribution == "logistic":
x_transformed = logistic.cdf(xv)
x_transformed[:] = td_obj.quantile(x_transformed)
return x_transformed

Expand All @@ -378,6 +383,12 @@ class DQuantileScaler(DBaseScaler):
in parallel using the multiprocessing library. Multidimensional arrays are stored in shared memory across
processes to minimize inter-process communication.
Attributes:
compression: Recommended number of centroids to use.
distribution: "uniform", "normal", or "logistic".
min_val: Minimum value for quantile to prevent -inf results when distribution is normal or logistic.
max_val: Maximum value for quantile to prevent inf results when distribution is normal or logistic.
channels_last: Whether to assume the last dim or second dim are the channel/variable dimension.
"""
def __init__(self, compression=250, distribution="uniform", min_val=0.0000001, max_val=0.9999999, channels_last=True):
self.compression = compression
Expand Down Expand Up @@ -551,6 +562,9 @@ class DQuantileTransformer(DBaseScaler):
or the second dimension (False).
"""
def __init__(self, max_merged_centroids=1000, distribution="uniform", channels_last=True):
warn(f'{self.__class__.__name__} will be deprecated in the next version.',
DeprecationWarning, stacklevel=2)

self.max_merged_centroids = max_merged_centroids
self.distribution = distribution
self.centroids_ = None
Expand Down
19 changes: 17 additions & 2 deletions doc/source/distributed.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and muliti-dimensional patch data in numpy, pandas DataFrame, and xarray DataArr
The distributed scalers allow you to calculate scaling
parameters on different subsets of a dataset and then combine the scaling factors
together to get representative scaling values for the full dataset. Distributed
Standard Scalers, MinMax Scalers, and Quantile Transformers have been implemented and work with both tabular
Standard Scalers, MinMax Scalers, and Quantile Scalers have been implemented and work with both tabular
and muliti-dimensional patch data in numpy, pandas DataFrame, and xarray DataArray formats.
By default, the scaler assumes your channel/variable dimension is the last
dimension, but if `channels_last=False` is set in the `__init__`, `transform`,
Expand Down Expand Up @@ -43,4 +43,19 @@ Example:
dss_2.fit(x_2)
dss_combined = np.sum([dss_1, dss_2])
dss_combined.transform(x_1, channels_last=False)
dss_combined.transform(x_1, channels_last=False)
Distributed scalers can be stored in individual json files or within
a pandas DataFrame for easy loading and combining later.

.. code-block:: python
import pandas as pd
from bridgescaler import print_scaler, read_scaler
scaler_list = [dss_1, dss_2]
df = pd.DataFrame({"scalers": [print_scaler(s) in scaler_list]}])
df.to_parquet("scalers.parquet")
df_new = df.read_parquet("scalers.parquet")
scaler_objs = df_new["scalers"].apply(read_scaler)
total_scaler = scaler_objs.sum()

0 comments on commit 521cceb

Please sign in to comment.