diff --git a/tests/test_CosineAnnealingBS.py b/tests/test_CosineAnnealingBS.py index 163a6c6..a0fd82b 100644 --- a/tests/test_CosineAnnealingBS.py +++ b/tests/test_CosineAnnealingBS.py @@ -37,7 +37,7 @@ def test_dataloader_lengths(self): def test_dataloader_batch_size(self): base_batch_size = 10 total_iters = 50 - n_epochs = 500 + n_epochs = 200 max_batch_size = 100 dataloader = create_dataloader(self.dataset, batch_size=base_batch_size) scheduler = CosineAnnealingBS(dataloader, total_iters=total_iters, max_batch_size=max_batch_size) @@ -47,7 +47,7 @@ def test_dataloader_batch_size(self): 47, 49, 52, 55, 58, 61, 63, 66, 69, 72, 74, 77, 79, 81, 84, 86, 88, 90, 91, 93, 94, 96, 97, 98, 99, 99, 100, 100, 100, 100, 100, 99, 99, 98, 97, 96, 94, 93, 91, 90, 88, 86, 84, 81, 79, 77, 74, 72, 69, 66, 63, 61, 58, 55, 52, 49, 47, 44, 41, 38, 36, 33, 31, 29, 26, - 24, 22, 20, 19, 17, 16, 14, 13, 12, 11, 11, 10, 10] * 5 + 24, 22, 20, 19, 17, 16, 14, 13, 12, 11, 11, 10, 10] * 2 self.assertEqual(batch_sizes, expected_batch_sizes) diff --git a/tests/test_CyclicBS.py b/tests/test_CyclicBS.py index 1fde1ea..0b902cf 100644 --- a/tests/test_CyclicBS.py +++ b/tests/test_CyclicBS.py @@ -201,11 +201,11 @@ def test_graphic_exp_range(self): base_batch_size = 100 dataloader = create_dataloader(self.dataset, batch_size=base_batch_size) max_batch_size = 200 - step_size_down = 50 + step_size_down = 25 gamma = 0.9 scheduler = CyclicBS(dataloader, base_batch_size=base_batch_size, max_batch_size=max_batch_size, step_size_down=step_size_down, mode='exp_range', gamma=gamma) - n_epochs = 10 * step_size_down + n_epochs = 6 * step_size_down batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) plt.plot(batch_sizes) diff --git a/tests/test_LambdaBS.py b/tests/test_LambdaBS.py index 34a7f21..ae30d86 100644 --- a/tests/test_LambdaBS.py +++ b/tests/test_LambdaBS.py @@ -22,7 +22,7 @@ def test_sanity(self): "always be equal to the inferred length except for Iterable Datasets for " "which the __len__ could be inaccurate.") - dataloader.batch_sampler.batch_size = 526 + dataloader.batch_sampler.batch_size = 256 real, inferred = iterate(dataloader) self.assertEqual(real, inferred, "Dataloader __len__ does not return the real length. The real length should " "always be equal to the inferred length except for Iterable Datasets for " @@ -46,7 +46,7 @@ def test_dataloader_batch_size(self): dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size) fn = lambda epoch: 10 * epoch # noqa: E731 scheduler = LambdaBS(dataloader, fn) - n_epochs = 15 + n_epochs = 10 batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, fn, diff --git a/tests/test_MultiStepBS.py b/tests/test_MultiStepBS.py index ee6523c..ef41deb 100644 --- a/tests/test_MultiStepBS.py +++ b/tests/test_MultiStepBS.py @@ -47,10 +47,10 @@ def test_dataloader_lengths(self): def test_dataloader_batch_size(self): dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size) - milestones = [5, 10, 10, 12] + milestones = [5, 7, 7, 9] gamma = 3.0 scheduler = MultiStepBS(dataloader, milestones=milestones, gamma=gamma, max_batch_size=5000, verbose=False) - n_epochs = 15 + n_epochs = 10 batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, milestones, gamma, @@ -60,7 +60,7 @@ def test_dataloader_batch_size(self): def test_loading_and_unloading(self): dataloader = create_dataloader(self.dataset) - milestones = [5, 10, 10, 12] + milestones = [5, 7, 7, 9] gamma = 3.0 scheduler = MultiStepBS(dataloader, milestones=milestones, gamma=gamma, max_batch_size=5000, verbose=False) @@ -76,10 +76,10 @@ def test_graphic(self): warnings.filterwarnings("ignore", category=UserWarning) dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size) - milestones = [5, 10, 10, 12] + milestones = [5, 7, 7, 9] gamma = 3.0 scheduler = MultiStepBS(dataloader, milestones=milestones, gamma=gamma, max_batch_size=5000, verbose=False) - n_epochs = 15 + n_epochs = 10 batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) plt.plot(batch_sizes) diff --git a/tests/test_MultiplicativeBS.py b/tests/test_MultiplicativeBS.py index e79d7ba..7ec27e4 100644 --- a/tests/test_MultiplicativeBS.py +++ b/tests/test_MultiplicativeBS.py @@ -37,7 +37,7 @@ def test_dataloader_batch_size(self): dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size) fn = lambda epoch: epoch / 100 + 2 # noqa: E731 scheduler = MultiplicativeBS(dataloader, fn, max_batch_size=5000, verbose=False) - n_epochs = 15 + n_epochs = 10 batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, fn, diff --git a/tests/test_StepBS.py b/tests/test_StepBS.py index 43b4222..0896dd5 100644 --- a/tests/test_StepBS.py +++ b/tests/test_StepBS.py @@ -107,7 +107,7 @@ def test_graphic(self): step_size = 50 gamma = 1.1 scheduler = StepBS(dataloader, step_size=step_size, gamma=gamma) - n_epochs = 300 + n_epochs = 200 batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) plt.plot(batch_sizes)