Skip to content

Commit

Permalink
Updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Dec 6, 2024
1 parent f52655b commit 67d8f10
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 31 deletions.
13 changes: 11 additions & 2 deletions tests/test_ConstantBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import unittest

from bs_scheduler import ConstantBS
from bs_scheduler import ConstantBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest, rint
get_batch_sizes_across_epochs, BSTest, rint, batched_dataset


class TestConstantBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'factor': 5.0,
'milestone': 5
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, ConstantBS, batch_size_manager, **kwargs)

def test_dataloader_lengths(self):
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
factor = 5.0
Expand Down
13 changes: 11 additions & 2 deletions tests/test_CosineAnnealingBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import unittest

from bs_scheduler import CosineAnnealingBS
from bs_scheduler import CosineAnnealingBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest
get_batch_sizes_across_epochs, BSTest, batched_dataset


class TestCosineAnnealingBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'total_iters': 5,
'max_batch_size': 100,
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, CosineAnnealingBS, batch_size_manager, **kwargs)

def test_dataloader_lengths(self):
base_batch_size = 10
total_iters = 5
Expand Down
13 changes: 11 additions & 2 deletions tests/test_CosineAnnealingBSWithWarmRestarts.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import unittest

from bs_scheduler import CosineAnnealingBSWithWarmRestarts
from bs_scheduler import CosineAnnealingBSWithWarmRestarts, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest
get_batch_sizes_across_epochs, BSTest, batched_dataset


class TestCosineAnnealingBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
't_0': 5,
'max_batch_size': 100,
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, CosineAnnealingBSWithWarmRestarts, batch_size_manager, **kwargs)

def test_dataloader_lengths(self):
base_batch_size = 10
t_0 = 5
Expand Down
15 changes: 13 additions & 2 deletions tests/test_CyclicBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@
import random
import unittest

from bs_scheduler import CyclicBS
from tests.test_utils import create_dataloader, fashion_mnist, get_batch_sizes_across_epochs, BSTest, rint
from bs_scheduler import CyclicBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, fashion_mnist, get_batch_sizes_across_epochs, BSTest, rint, \
batched_dataset


class TestConstantBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'base_batch_size': 100,
'step_size_down': 10,
'mode': 'triangular',
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, CyclicBS, batch_size_manager, **kwargs)

def test_dataloader_batch_size_triangular(self):
base_batch_size = 100
dataloader = create_dataloader(self.dataset, batch_size=base_batch_size)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_ExponentialBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import os
import unittest

from bs_scheduler import ExponentialBS
from bs_scheduler import ExponentialBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest
get_batch_sizes_across_epochs, BSTest, batched_dataset


class TestExponentialBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'gamma': 1.01,
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, ExponentialBS, batch_size_manager, **kwargs)

def test_dataloader_lengths(self):
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)

Expand Down
13 changes: 11 additions & 2 deletions tests/test_IncreaseBSOnPlateau.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import unittest

from bs_scheduler import IncreaseBSOnPlateau
from bs_scheduler import IncreaseBSOnPlateau, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
BSTest, get_batch_sizes_across_epochs
BSTest, get_batch_sizes_across_epochs, batched_dataset


class TestIncreaseBSOnPlateau(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'mode': 'min',
'threshold_mode': 'rel',
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, IncreaseBSOnPlateau, batch_size_manager, **kwargs)

def test_constant_metric(self):
base_batch_size = 10
max_batch_size = 100
Expand Down
14 changes: 12 additions & 2 deletions tests/test_LinearBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,26 @@

import torch

from bs_scheduler import LinearBS
from bs_scheduler import LinearBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest, clip, rint
get_batch_sizes_across_epochs, BSTest, clip, rint, batched_dataset


class TestLinearBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'start_factor': 10.0,
'end_factor': 5.0,
'milestone': 5
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, LinearBS, batch_size_manager, **kwargs)

@staticmethod
def compute_expected_batch_sizes(epochs, base_batch_size, start_factor, end_factor, milestone, min_batch_size,
max_batch_size):
Expand Down
13 changes: 11 additions & 2 deletions tests/test_MultiStepBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import unittest

from bs_scheduler import MultiStepBS
from bs_scheduler import MultiStepBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest, clip, rint
get_batch_sizes_across_epochs, BSTest, clip, rint, batched_dataset


class TestMultiStepBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'gamma': 1.1,
'milestones': [70, 70, 80, 10, 50]
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, MultiStepBS, batch_size_manager, **kwargs)

@staticmethod
def compute_expected_batch_sizes(epochs, base_batch_size, milestones, gamma, min_batch_size, max_batch_size):
expected_batch_sizes = [base_batch_size] # Base batch size is added as a boundary condition.
Expand Down
16 changes: 14 additions & 2 deletions tests/test_OneCycleBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
import os
import unittest

from bs_scheduler import OneCycleBS
from bs_scheduler import OneCycleBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest, rint
get_batch_sizes_across_epochs, BSTest, rint, batched_dataset


class TestOneCycleBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'max_batch_size': 100,
'min_batch_size': 10,
'total_steps': 100,
'decay_percentage': 0.3,
'strategy': 'linear',
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, OneCycleBS, batch_size_manager, **kwargs)

def test_dataloader_batch_size_linear(self):
base_batch_size = 40
n_epochs = 120
Expand Down
13 changes: 11 additions & 2 deletions tests/test_PolynomialBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import os
import unittest

from bs_scheduler import PolynomialBS
from bs_scheduler import PolynomialBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest
get_batch_sizes_across_epochs, BSTest, batched_dataset


class TestPolynomialBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'total_iters': 5,
'power': 1.0,
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, PolynomialBS, batch_size_manager, **kwargs)

def test_dataloader_lengths(self):
base_batch_size = 10
total_iters = 5
Expand Down
56 changes: 46 additions & 10 deletions tests/test_StepBS.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import os
import unittest

from bs_scheduler import StepBS
from bs_scheduler import StepBS, CustomBatchSizeManager
from tests.test_utils import create_dataloader, simulate_n_epochs, fashion_mnist, \
get_batch_sizes_across_epochs, BSTest, clip, rint
get_batch_sizes_across_epochs, BSTest, clip, rint, batched_dataset


class TestStepBS(BSTest):
def setUp(self):
self.base_batch_size = 64
self.dataset = fashion_mnist()

def test_create(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
kwargs = {
'step_size': 50,
'gamma': 1.1
}
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
self.create_scheduler(dataloader, StepBS, batch_size_manager, **kwargs)


@staticmethod
def compute_expected_batch_sizes(epochs, base_batch_size, step_size, gamma, min_batch_size, max_batch_size):
expected_batch_sizes = [base_batch_size] # Base batch size is added as a boundary condition.
Expand All @@ -23,6 +33,14 @@ def compute_expected_batch_sizes(epochs, base_batch_size, step_size, gamma, min_
expected_batch_sizes.pop(0) # Removing base batch size.
return expected_batch_sizes

def run_scheduler(self, dataloader, scheduler, n_epochs):
batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, scheduler.step_size,
scheduler.gamma, scheduler.min_batch_size,
scheduler.max_batch_size)

self.assertEqual(batch_sizes, expected_batch_sizes)

def test_dataloader_lengths(self):
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
step_size = 50
Expand All @@ -31,8 +49,9 @@ def test_dataloader_lengths(self):
n_epochs = 300

epoch_lengths = simulate_n_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, step_size, gamma,
scheduler.min_batch_size, scheduler.max_batch_size)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, scheduler.step_size,
scheduler.gamma, scheduler.min_batch_size,
scheduler.max_batch_size)
expected_lengths = self.compute_epoch_lengths(expected_batch_sizes, len(self.dataset), drop_last=False)

self.assertEqual(epoch_lengths, expected_lengths)
Expand All @@ -43,12 +62,29 @@ def test_dataloader_batch_size(self):
gamma = 3.0
scheduler = StepBS(dataloader, step_size=step_size, gamma=gamma, max_batch_size=5000, verbose=False)
n_epochs = 15

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, step_size, gamma,
scheduler.min_batch_size, scheduler.max_batch_size)

self.assertEqual(batch_sizes, expected_batch_sizes)
self.run_scheduler(dataloader, scheduler, n_epochs)

def test_batched_dataset(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
step_size = 2
gamma = 2.0
n_epochs = 10

scheduler = StepBS(dataloader, step_size=step_size, gamma=gamma, max_batch_size=5000, verbose=False,
batch_size_manager=batch_size_manager)
self.run_scheduler(dataloader, scheduler, n_epochs)

def test_dataloader_none(self):
dataloader = create_dataloader(batched_dataset(batch_size=self.base_batch_size), batch_size=None)
batch_size_manager = CustomBatchSizeManager(dataloader.dataset)
step_size = 2
gamma = 2.0
n_epochs = 10

scheduler = StepBS(None, step_size=step_size, gamma=gamma, max_batch_size=5000, verbose=False,
batch_size_manager=batch_size_manager)
self.run_scheduler(dataloader, scheduler, n_epochs)

def test_loading_and_unloading(self):
dataloader = create_dataloader(self.dataset)
Expand Down
Loading

0 comments on commit 67d8f10

Please sign in to comment.