Skip to content

Commit

Permalink
[Minor] reduce batch size for small datasets, reduce overall epochs for most datasets (#1533)
Browse files Browse the repository at this point in the history

* reduce minimum batch size from 32 to 8
* reduce epochs by half
* fix tests, round epochs
  • Loading branch information
ourownstory authored Feb 14, 2024
1 parent 587b374 commit e0c95bf
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
12 changes: 6 additions & 6 deletions neuralprophet/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,21 @@ def set_quantiles(self):
def set_auto_batch_epoch(
self,
n_data: int,
min_batch: int = 32,
max_batch: int = 1024,
min_epoch: int = 10,
max_epoch: int = 1000,
min_batch: int = 8,
max_batch: int = 2048,
min_epoch: int = 20,
max_epoch: int = 500,
):
assert n_data >= 1
self.n_data = n_data
if self.batch_size is None:
self.batch_size = int(2 ** (3 + int(np.log10(n_data))))
self.batch_size = int(2 ** (1 + int(1.5 * np.log10(int(n_data)))))
self.batch_size = min(max_batch, max(min_batch, self.batch_size))
self.batch_size = min(self.n_data, self.batch_size)
log.info(f"Auto-set batch_size to {self.batch_size}")
if self.epochs is None:
# this should (with auto batch size) yield about 1000 steps minimum and 100,000 steps at upper cutoff
self.epochs = int(2 ** (2.5 * np.log10(100 + n_data)) / (n_data / 1000.0))
self.epochs = 10 * int(np.ceil(100 / n_data * 2 ** (2.25 * np.log10(10 + n_data))))
self.epochs = min(max_epoch, max(min_epoch, self.epochs))
log.info(f"Auto-set epochs to {self.epochs}")
# also set lambda_delay:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,14 @@ def test_auto_batch_epoch():
# for epochs = int(2 ** (2.3 * np.log10(100 + n_data)) / (n_data / 1000.0))
# for epochs = int(2 ** (2.5 * np.log10(100 + n_data)) / (n_data / 1000.0))
check = {
"1": (1, 1000),
"10": (10, 1000),
"100": (32, 539),
"1000": (64, 194),
"10000": (128, 103),
"100000": (256, 57),
"1000000": (512, 32),
"10000000": (1024, 18),
"1": (1, 500),
"10": (8, 500),
"100": (16, 250),
"1000": (32, 110),
"10000": (128, 60),
"100000": (256, 30),
"1000000": (1024, 20),
"10000000": (2048, 20),
}

for n_data, (batch_size, epochs) in check.items():
Expand Down

0 comments on commit e0c95bf

Please sign in to comment.