Skip to content

Commit

Permalink
fix: fixed docstring issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jrfiedler authored and Optimox committed Oct 21, 2020
1 parent da77060 commit d216fbf
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 226 deletions.
179 changes: 93 additions & 86 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,39 +90,40 @@ def fit(
Parameters
----------
X_train: np.ndarray
Train set
y_train : np.array
Train targets
eval_set: list of tuple
List of eval tuple set (X, y).
The last one is used for early stopping
eval_name: list of str
List of eval set names.
eval_metric : list of str
List of evaluation metrics.
The last metric is used for early stopping.
weights : bool or dictionnary
0 for no balancing
1 for automated balancing
dict for custom weights per class
max_epochs : int
Maximum number of epochs during training
patience : int
Number of consecutive non improving epoch before early stopping
batch_size : int
Training batch size
virtual_batch_size : int
Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size)
num_workers : int
Number of workers used in torch.utils.data.DataLoader
drop_last : bool
Whether to drop last batch during training
callbacks : list of callback function
List of custom callbacks
pin_memory: bool
Whether to set pin_memory to True or False during training
X_train : np.ndarray
Train set
y_train : np.array
Train targets
eval_set : list of tuple
List of eval tuple set (X, y).
The last one is used for early stopping
eval_name : list of str
List of eval set names.
eval_metric : list of str
List of evaluation metrics.
The last metric is used for early stopping.
loss_fn : callable or None
a PyTorch loss function
weights : bool or dictionnary
0 for no balancing
1 for automated balancing
dict for custom weights per class
max_epochs : int
Maximum number of epochs during training
patience : int
Number of consecutive non improving epoch before early stopping
batch_size : int
Training batch size
virtual_batch_size : int
Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size)
num_workers : int
Number of workers used in torch.utils.data.DataLoader
drop_last : bool
Whether to drop last batch during training
callbacks : list of callback function
List of custom callbacks
pin_memory: bool
Whether to set pin_memory to True or False during training
"""
# update model name

Expand Down Expand Up @@ -196,13 +197,13 @@ def predict(self, X):
Parameters
----------
X: a :tensor: `torch.Tensor`
Input data
X : a :tensor: `torch.Tensor`
Input data
Returns
-------
predictions: np.array
Predictions of the regression problem
predictions : np.array
Predictions of the regression problem
"""
self.network.eval()
dataloader = DataLoader(
Expand All @@ -226,15 +227,15 @@ def explain(self, X):
Parameters
----------
X: tensor: `torch.Tensor`
Input data
X : tensor: `torch.Tensor`
Input data
Returns
-------
M_explain: matrix
Importance per sample, per columns.
masks: matrix
Sparse matrix showing attention masks used by network.
M_explain : matrix
Importance per sample, per columns.
masks : matrix
Sparse matrix showing attention masks used by network.
"""
self.network.eval()

Expand Down Expand Up @@ -274,8 +275,14 @@ def save_model(self, path):
Parameters
----------
filepath : str
Path of the model.
path : str
Path of the model.
Returns
-------
str
input filepath with ".zip" appended
"""
saved_params = {}
for key, val in self.get_params().items():
Expand Down Expand Up @@ -304,8 +311,8 @@ def load_model(self, filepath):
Parameters
----------
filepath : str
Path of the model.
filepath : str
Path of the model.
"""
try:
with zipfile.ZipFile(filepath) as z:
Expand Down Expand Up @@ -338,8 +345,8 @@ def _train_epoch(self, train_loader):
Parameters
----------
train_loader: a :class: `torch.utils.data.Dataloader`
DataLoader with train set
train_loader : a :class: `torch.utils.data.Dataloader`
DataLoader with train set
"""
self.network.train()

Expand All @@ -361,17 +368,17 @@ def _train_batch(self, X, y):
Parameters
----------
X: torch.tensor
Train matrix
y: torch.tensor
Target matrix
X : torch.Tensor
Train matrix
y : torch.Tensor
Target matrix
Returns
-------
batch_outs : dict
Dictionnary with "y": target and "score": prediction scores.
batch_logs : dict
Dictionnary with "batch_size" and "loss".
batch_outs : dict
Dictionnary with "y": target and "score": prediction scores.
batch_logs : dict
Dictionnary with "batch_size" and "loss".
"""
batch_logs = {"batch_size": X.shape[0]}

Expand Down Expand Up @@ -403,10 +410,10 @@ def _predict_epoch(self, name, loader):
Parameters
----------
name: str
Name of the validation set
loader: torch.utils.data.Dataloader
DataLoader with validation set
name : str
Name of the validation set
loader : torch.utils.data.Dataloader
DataLoader with validation set
"""
# Setting network on evaluation mode (no dropout etc...)
self.network.eval()
Expand All @@ -433,13 +440,13 @@ def _predict_batch(self, X):
Parameters
----------
x: torch.tensor
Owned products
X : torch.Tensor
Owned products
Returns
-------
np.array
model scores
np.array
model scores
"""
X = X.to(self.device).float()

Expand Down Expand Up @@ -518,7 +525,7 @@ def _set_callbacks(self, custom_callbacks):
Parameters
----------
callbacks : list of func
custom_callbacks : list of func
List of callback functions.
"""
Expand Down Expand Up @@ -569,7 +576,7 @@ def _construct_loaders(self, X_train, y_train, eval_set):
Train set.
y_train : np.array
Train targets.
eval_set: list of tuple
eval_set : list of tuple
List of eval tuple set (X, y).
Returns
Expand Down Expand Up @@ -626,15 +633,15 @@ def update_fit_params(self, X_train, y_train, eval_set, weights):
Parameters
----------
X_train: np.ndarray
Train set
y_train : np.array
Train targets
eval_set: list of tuple
List of eval tuple set (X, y).
weights : bool or dictionnary
0 for no balancing
1 for automated balancing
X_train : np.ndarray
Train set
y_train : np.array
Train targets
eval_set : list of tuple
List of eval tuple set (X, y).
weights : bool or dictionnary
0 for no balancing
1 for automated balancing
"""
raise NotImplementedError(
"users must define update_fit_params to use this base class"
Expand All @@ -647,15 +654,15 @@ def compute_loss(self, y_score, y_true):
Parameters
----------
y_score: a :tensor: `torch.Tensor`
Score matrix
y_true: a :tensor: `torch.Tensor`
Target matrix
y_score : a :tensor: `torch.Tensor`
Score matrix
y_true : a :tensor: `torch.Tensor`
Target matrix
Returns
-------
float
Loss value
float
Loss value
"""
raise NotImplementedError(
"users must define compute_loss to use this base class"
Expand All @@ -668,13 +675,13 @@ def prepare_target(self, y):
Parameters
----------
y: a :tensor: `torch.Tensor`
Target matrix.
y : a :tensor: `torch.Tensor`
Target matrix.
Returns
-------
`torch.Tensor`
Converted target matrix.
`torch.Tensor`
Converted target matrix.
"""
raise NotImplementedError(
"users must define prepare_target to use this base class"
Expand Down
Loading

0 comments on commit d216fbf

Please sign in to comment.