Updated README and extended documentation #14

Merged: 6 commits, Sep 5, 2024
27 changes: 18 additions & 9 deletions README.md
@@ -4,15 +4,26 @@ A Batch Size Scheduler library compatible with PyTorch DataLoaders.

***

Documentation: [API Reference](https://ancestor-mithril.github.io/bs-scheduler/).
## Documentation

<!--Examples: TODO. -->
* [API Reference](https://ancestor-mithril.github.io/bs-scheduler).

* [Examples](https://ancestor-mithril.github.io/bs-scheduler/tutorials).

<!--For Release Notes, see TODO. -->

***

## Why use a Batch Size Scheduler?

<!--TODO: Cite papers and explain why. -->
* Using a large batch size has several advantages:
  * Better hardware utilization.
  * Enhanced parallelism.
  * Faster training.
* However, using a large batch size from the start may lead to a generalization gap.
* Therefore, the solution is to gradually increase the batch size, similar to a learning rate decay policy.
* See [Don't Decay the Learning Rate, Increase the Batch Size](https://arxiv.org/abs/1711.00489).
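The analogy to learning rate decay can be made concrete with a minimal pure-Python sketch (illustrative only, not part of the library): dividing the learning rate by a factor every `step_size` epochs and multiplying the batch size by the same factor follow the same geometric schedule.

```python
base_lr, base_bs, gamma = 0.1, 16, 2

def decayed_lr(epoch, step_size=30):
    # StepLR-style decay: the learning rate is divided by gamma
    # every step_size epochs
    return base_lr / gamma ** (epoch // step_size)

def increased_bs(epoch, step_size=30):
    # the batch-size analogue: the batch size is multiplied by gamma
    # every step_size epochs
    return base_bs * gamma ** (epoch // step_size)

for epoch in (0, 30, 60, 90):
    print(epoch, decayed_lr(epoch), increased_bs(epoch))
```

Both schedules shrink the effective noise scale of SGD over training; the batch-size version does so while keeping every gradient step at the same learning rate.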


## Available Schedulers

@@ -57,14 +68,12 @@

Please install [PyTorch](https://github.com/pytorch/pytorch) first before installing bs-scheduler:

```
pip install bs-scheduler
```

Or from git:

```
pip install git+https://github.com/ancestor-mithril/bs-scheduler.git@master
```

## Licensing

The library is licensed under the [BSD-3-Clause license](LICENSE).

## Citation

To be added...

<!--Citation: TODO. -->
31 changes: 30 additions & 1 deletion docs/README.md
@@ -1 +1,30 @@
::: bs_scheduler
# bs-scheduler

A Batch Size Scheduler library compatible with PyTorch DataLoaders.

***

### Batch Size Schedulers

1. [LambdaBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.LambdaBS) - sets the batch size to the base batch size times a given lambda.
2. [MultiplicativeBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.MultiplicativeBS) - sets the batch size to the current batch size times a given lambda.
3. [StepBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.StepBS) - multiplies the batch size by a given factor at a given number of steps.
4. [MultiStepBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.MultiStepBS) - multiplies the batch size by a given factor each time a milestone is reached.
5. [ConstantBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.ConstantBS) - multiplies the batch size by a given factor once and decreases it again to its base value after a
given number of steps.
6. [LinearBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.LinearBS) - increases the batch size by a linearly changing multiplicative factor for a given number of steps.
7. [ExponentialBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.ExponentialBS) - increases the batch size by a given $\gamma$ each step.
8. [PolynomialBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.PolynomialBS) - increases the batch size using a polynomial function in a given number of steps.
9. [CosineAnnealingBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.CosineAnnealingBS) - increases the batch size to a maximum batch size and decreases it again following a cyclic
cosine curve.
10. [IncreaseBSOnPlateau](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.IncreaseBSOnPlateau) - increases the batch size each time a given metric has stopped improving for a given number
of steps.
11. [CyclicBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.CyclicBS) - cycles the batch size between two boundaries with a constant frequency, while also scaling the
distance between boundaries.
12. [CosineAnnealingBSWithWarmRestarts](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.CosineAnnealingBSWithWarmRestarts) - increases the batch size to a maximum batch size following a cosine curve,
then restarts while also scaling the number of iterations until the next restart.
13. [OneCycleBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.OneCycleBS) - decreases the batch size to a minimum batch size then increases it to a given maximum batch size,
following a linear or cosine annealing strategy.
14. [SequentialBS](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.SequentialBS) - calls a list of schedulers sequentially given a list of milestone points which reflect which
scheduler should be called when.
15. [ChainedBSScheduler](https://ancestor-mithril.github.io/bs-scheduler/reference/#bs_scheduler.ChainedBSScheduler) - chains a list of batch size schedulers and calls them together each step.
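The difference between the first two schedulers is subtle. Here is a hypothetical pure-Python sketch of their rules (the real classes wrap a DataLoader; the function names and lambdas below are illustrative, not the library API):

```python
base_bs = 16

def lambda_bs(epoch, fn=lambda e: e + 1):
    # LambdaBS rule: base batch size times the lambda evaluated at the epoch
    return base_bs * fn(epoch)

def multiplicative_bs(epoch, fn=lambda e: 2):
    # MultiplicativeBS rule: the lambda is applied cumulatively to the
    # current batch size, once per epoch
    bs = base_bs
    for e in range(1, epoch + 1):
        bs *= fn(e)
    return bs
```

In short, LambdaBS is stateless with respect to the previous batch size, while MultiplicativeBS compounds its factor step after step.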
1 change: 1 addition & 0 deletions docs/reference.md
@@ -0,0 +1 @@
::: bs_scheduler
115 changes: 115 additions & 0 deletions docs/tutorials.md
@@ -0,0 +1,115 @@
## Basic usage

Integrating a Batch Size Scheduler inside a PyTorch training script is simple:

```python
from torch.utils.data import DataLoader
from bs_scheduler import StepBS
# We use StepBS in this example, but we can use any BS Scheduler

# Define the Dataset and the DataLoader
dataset = ...
dataloader = DataLoader(..., batch_size=16)
scheduler = StepBS(dataloader, step_size=30, gamma=2)
# Activates every 30 epochs and doubles the batch size.

for _ in range(100):
    train(...)
    validate(...)
    scheduler.step()

# We expect the batch size to have the following values:
# epoch 0 - 29: 16
# epoch 30 - 59: 32
# epoch 60 - 89: 64
# epoch 90 - 99: 128
```

Full example:

```python
import timm
import torch.cuda
import torchvision.datasets
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from bs_scheduler import StepBS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
train_loader = DataLoader(
    torchvision.datasets.CIFAR10(
        root="../data",
        train=True,
        download=True,
        transform=transforms,
    ),
    batch_size=100,
)
val_loader = DataLoader(
    torchvision.datasets.CIFAR10(root="../data", train=False, transform=transforms),
    batch_size=500,
)
scheduler = StepBS(train_loader, step_size=10)

model = timm.create_model("hf_hub:grodino/resnet18_cifar10", pretrained=False).to(
    device
)
criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.001)


def train():
    correct = 0
    total = 0

    model.train()
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        predicted = outputs.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return correct / total


@torch.inference_mode()
def val():
    correct = 0
    total = 0

    model.eval()
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        predicted = outputs.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return correct / total


def main():
    for epoch in range(100):
        train_accuracy = train()
        val_accuracy = val()

        scheduler.step()

        print(train_accuracy, val_accuracy)


if __name__ == "__main__":
    main()
```