Skip to content

Commit

Permalink
Fix unknown train step (#133)
Browse files Browse the repository at this point in the history
* Fix issue (#129)

* training steps added (#129)

* Update datagenerator.py

* Update trainer.py

Co-authored-by: Aniket Maurya <[email protected]>
  • Loading branch information
Adk2001tech and aniketmaurya authored Aug 23, 2021
1 parent 156e975 commit cddc9fc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion chitra/datagenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
self.num_files = len(self.filenames)
self.image_size = image_size
self.img_sz_list = ImageSizeList(self.image_size)

self.step_size = None
self.labels = kwargs.get("labels", self.get_labels())

def __len__(self):
Expand Down
17 changes: 10 additions & 7 deletions chitra/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,19 @@ def cyclic_fit(
validation_data: Data on which to evaluate
callbacks: List of `tf.keras.callbacks` instances.
kwargs:
step_size (int): step size for the Cyclic learning rate. By default it is `2 * len(self.ds)//batch_size`
step_size (int): step size for the Cyclic learning rate.
By default it is `2 * (self.ds.num_files//batch_size)`
scale_mode (str): cycle or exp
            shuffle (bool): Dataset will be shuffled on each epoch if True
"""
self.step_size = 2 * (self.ds.num_files // batch_size)
self.ds.step_size = self.step_size
if not self.cyclic_opt_set:
self.max_lr, self.min_lr = lr_range
step_size = 2 * len(self.ds) // batch_size
lr_schedule = tfa.optimizers.Triangular2CyclicalLearningRate(
initial_learning_rate=lr_range[0],
maximal_learning_rate=lr_range[1],
step_size=kwargs.get("step_size", step_size),
step_size=kwargs.get("step_size", self.step_size),
scale_mode=kwargs.get("scale_mode", "cycle"),
)

Expand All @@ -262,6 +264,7 @@ def cyclic_fit(

return self.model.fit(
self._prepare_dl(batch_size, kwargs.get("shuffle", True)),
steps_per_epoch=self.step_size,
validation_data=validation_data,
epochs=epochs,
callbacks=callbacks,
Expand All @@ -287,7 +290,8 @@ def compile2(
optimizer (str, keras.optimizer.Optimizer): Keras optimizer
kwargs:
step_size (int): step size for the Cyclic learning rate. By default it is `2 * len(self.ds)//batch_size`
step_size (int): step size for the Cyclic learning rate.
By default it is `2 * (self.ds.num_files // batch_size)`
scale_mode (str): cycle or exp
momentum(int): momentum for the optimizer when optimizer is of type str
"""
Expand All @@ -296,12 +300,12 @@ def compile2(
self.max_lr, self.min_lr = lr_range
self.batch_size = batch_size

self.step_size = step_size = 2 * len(self.ds) // batch_size
self.step_size = 2 * (self.ds.num_files // batch_size)

lr_schedule = tfa.optimizers.Triangular2CyclicalLearningRate(
initial_learning_rate=lr_range[0],
maximal_learning_rate=lr_range[1],
step_size=kwargs.get("step_size", step_size),
step_size=kwargs.get("step_size", self.step_size),
scale_mode=kwargs.get("scale_mode", "cycle"),
)

Expand Down Expand Up @@ -394,7 +398,6 @@ def fit(
train_data,
epochs,
val_data=None,
test_data=None,
callbacks=None,
**kwargs,
):
Expand Down

0 comments on commit cddc9fc

Please sign in to comment.