
Lightning Lite Examples #9987

Merged Nov 2, 2021 (385 commits)
Changes from 124 commits
8ddb777
move scheduler
awaelchli Oct 19, 2021
c2b4b74
convert
awaelchli Oct 19, 2021
0d430c5
fix precision + dataloader wrapping for DP
awaelchli Oct 19, 2021
5d09298
Merge remote-tracking branch 'origin/lite-poc' into lightning-lite/li…
awaelchli Oct 19, 2021
2a21ad9
remove unused import
awaelchli Oct 19, 2021
23d6786
Merge branch 'master' into lightning-lite/litghtning-lite
awaelchli Oct 19, 2021
eeff843
Update pl_examples/lite_examples/simple/mnist_example.py
kaushikb11 Oct 19, 2021
3f7a2ce
Update pl_examples/lite_examples/simple/mnist_example.py
kaushikb11 Oct 19, 2021
61c825c
Add LightningLite Example (#9991)
tchaton Oct 19, 2021
21ada36
call process_dataloader()
awaelchli Oct 19, 2021
1bc966c
Merge branch 'master' into lightning-lite/litghtning-lite
awaelchli Oct 19, 2021
6dd770f
refactor spawn
awaelchli Oct 19, 2021
724c2a9
update new_process
awaelchli Oct 19, 2021
866e14d
call accelerator.setup_environment()
awaelchli Oct 19, 2021
a2e576b
set_device
awaelchli Oct 19, 2021
ce1b3ce
Merge branch 'master' into lightning-lite/litghtning-lite
awaelchli Oct 19, 2021
6005040
remove unused methods
awaelchli Oct 19, 2021
6674f01
Update TPUSpawn plugin to support Lightning Lite (#9999)
kaushikb11 Oct 19, 2021
5bdde62
move tpu optimizer step
awaelchli Oct 19, 2021
dab9532
Merge branch 'master' into lightning-lite/litghtning-lite
awaelchli Oct 19, 2021
9f4943c
clean up imports
awaelchli Oct 19, 2021
365fc8d
remove distributed_backend
awaelchli Oct 19, 2021
f63b85e
access protected methods
awaelchli Oct 19, 2021
2b77ce2
update precision handling in sharded
awaelchli Oct 19, 2021
db00696
sharded spawn
awaelchli Oct 19, 2021
71a6489
add spawn shaded support
awaelchli Oct 19, 2021
0a06e85
add zero grad stub to LiteOptimizer
awaelchli Oct 19, 2021
fd9a4d6
Merge branch 'master' into lightning-lite/litghtning-lite
awaelchli Oct 20, 2021
a488fd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2021
072fff0
Remove TPUAcc check in setup_dataloaders
kaushikb11 Oct 20, 2021
cd2f0d6
trigger ci
awaelchli Oct 20, 2021
4ed755b
Add parity tests for LightningLite vs. pure PyTorch (#10002)
tchaton Oct 20, 2021
75100ad
update test
tchaton Oct 20, 2021
72b47cd
Add tests for Lite wrappers (#10048)
awaelchli Oct 20, 2021
fe65b74
update closure
awaelchli Oct 20, 2021
b4a0c4a
update zero_grad
awaelchli Oct 20, 2021
e5bd182
tests for device
awaelchli Oct 20, 2021
809767e
Merge branch 'master' into lightning-lite/litghtning-lite
awaelchli Oct 20, 2021
74b11eb
merge conflict fixes
awaelchli Oct 20, 2021
4c81c78
add tests for distributed sampler
awaelchli Oct 21, 2021
b33dda2
update is_distrib access
awaelchli Oct 21, 2021
68c74ac
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 21, 2021
5935ce4
remove comment
awaelchli Oct 21, 2021
12de35c
update spawn() for tpu_spawn plugin
awaelchli Oct 21, 2021
56271bc
update bloat check
awaelchli Oct 21, 2021
0758aff
Merge branch 'lightning-lite/test-setup-data' into lightning-lite/lit…
awaelchli Oct 21, 2021
7b83347
update optimizer step test
awaelchli Oct 21, 2021
4f0d82a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
f86fd6f
add guards to example
awaelchli Oct 21, 2021
eeee4e8
Merge remote-tracking branch 'origin/lite-poc' into lightning-lite/li…
awaelchli Oct 21, 2021
f703112
move scrips to debug folder for removal later on
awaelchli Oct 21, 2021
3421372
test invalid choices
awaelchli Oct 21, 2021
3744ea4
test to_device
awaelchli Oct 21, 2021
d197b49
save checkpoint
awaelchli Oct 21, 2021
806fffc
Merge branch 'lightning-lite/tests' into lightning-lite/lite-poc
awaelchli Oct 21, 2021
2551877
update test description
awaelchli Oct 21, 2021
5c46acd
document public api
awaelchli Oct 21, 2021
98c8066
add api docs for wrappers
awaelchli Oct 21, 2021
76f2d6a
Merge branch 'lightning-lite/tests' into lightning-lite/lite-poc
awaelchli Oct 21, 2021
079fd27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
a4e035c
simple tests
awaelchli Oct 21, 2021
b14f22e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
eadc10d
Add more LightningLite tests (#10047)
tchaton Oct 21, 2021
c74b231
merge all tests together
awaelchli Oct 21, 2021
b4310ad
_num_models checks
awaelchli Oct 21, 2021
143aef8
use decorator for patching
awaelchli Oct 21, 2021
08cd122
docs for deepspeed special case
awaelchli Oct 21, 2021
f8a0f45
rename wrapper for sharded context
awaelchli Oct 21, 2021
0b30c8e
add todo
awaelchli Oct 21, 2021
538f6de
improve typing
tchaton Oct 21, 2021
ae7af78
delete debug examples
awaelchli Oct 21, 2021
fadca44
Fix sampler not being defined bug
kaushikb11 Oct 21, 2021
2074c8b
Add support for auto with the accelerator flag
kaushikb11 Oct 21, 2021
63fd036
failing assert for strategy.model
awaelchli Oct 21, 2021
240f95b
support Accelerator object and TrainingType strategy object to be pas…
awaelchli Oct 21, 2021
56e3b7a
support vararg optimizer sequence input to setup()
awaelchli Oct 21, 2021
a6cf010
remove redundant Iterable annotation from setup_dataloaders since we …
awaelchli Oct 21, 2021
070fa23
to_device overload for mypy
awaelchli Oct 21, 2021
6aa4ac0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2021
0838ea9
harden tests for setup()
awaelchli Oct 21, 2021
e6e2895
simplify test
awaelchli Oct 21, 2021
314da4a
fix mypy for setup() return type
awaelchli Oct 22, 2021
5efbfb3
organize
awaelchli Oct 22, 2021
676f765
remove dataloader type check (already checked above)
awaelchli Oct 22, 2021
f1716f8
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 22, 2021
04e1b41
update examples, setup() syntax
awaelchli Oct 22, 2021
024fa6a
skip test if dependency not available
awaelchli Oct 22, 2021
3e446d9
skip test if deepspeed unavailable
awaelchli Oct 22, 2021
18c58fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2021
186ed8d
test run() input outputs
awaelchli Oct 22, 2021
442a184
re-organize tests
awaelchli Oct 22, 2021
3539f2c
rename
awaelchli Oct 22, 2021
1510661
Add LightningLite documentation (#10043)
tchaton Oct 22, 2021
2d88340
remove "mixed"
awaelchli Oct 23, 2021
3171bee
fix title levels
awaelchli Oct 23, 2021
8d3e33d
fix spacing
awaelchli Oct 23, 2021
ce86e6e
fix title levels
awaelchli Oct 23, 2021
bad8356
add optional name to barrier
awaelchli Oct 25, 2021
8e5ddc3
re-add "mixed" as it is defined in NativeMixedPrecisionPlugin
awaelchli Oct 25, 2021
d987133
add lite flags section
awaelchli Oct 25, 2021
25c5b99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
86c163e
Update docs/source/starter/lightning_lite.rst
awaelchli Oct 25, 2021
ae1d793
Add documentation for essential methods
awaelchli Oct 25, 2021
9382fa4
add spacers
awaelchli Oct 25, 2021
219de45
add save_checkpoint, barrier to docs
awaelchli Oct 25, 2021
66a3f1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
ec7565f
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 25, 2021
16ac4df
update lite with latest master changes
awaelchli Oct 25, 2021
debe472
remove unused method in tpu spawn
awaelchli Oct 25, 2021
f3cb163
remove unused import
awaelchli Oct 25, 2021
1687387
fix precommit formatting issue
awaelchli Oct 25, 2021
5d0a72b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2021
7906eb3
remove reduce_decision and execute_on_rank
awaelchli Oct 25, 2021
4550510
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 25, 2021
0482860
update on comments
tchaton Oct 26, 2021
2359d0c
Add barrier for TPU Spawn
kaushikb11 Oct 26, 2021
915438f
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 26, 2021
889e319
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 26, 2021
d566082
Add Mnist examples with lite (#10131)
tchaton Oct 26, 2021
f22ac90
update
tchaton Oct 26, 2021
b914b74
update
tchaton Oct 26, 2021
2969f68
update
tchaton Oct 26, 2021
57f82e9
update
tchaton Oct 26, 2021
bc082a9
resolve doctest
tchaton Oct 26, 2021
ef6b591
resolve mypy
tchaton Oct 26, 2021
c90aff5
update
tchaton Oct 26, 2021
fd8660c
switch to Any
tchaton Oct 26, 2021
4f8e3a5
update
tchaton Oct 26, 2021
e5fd5b6
update
tchaton Oct 26, 2021
bca53f6
update
tchaton Oct 26, 2021
e0cee6a
resolve bugs
tchaton Oct 26, 2021
144eee4
update
tchaton Oct 26, 2021
2c6214b
Revert "switch to Any"
tchaton Oct 26, 2021
2ccae27
try to fix mypy
awaelchli Oct 26, 2021
0f11f70
x
awaelchli Oct 26, 2021
d2c27dc
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 26, 2021
782c70f
lightning lite package and tests
awaelchli Oct 27, 2021
be39098
update changelog
awaelchli Oct 27, 2021
824c11d
update test to ensure spawn result
awaelchli Oct 27, 2021
81636fe
Add sleep to fix the rendezous error
kaushikb11 Oct 27, 2021
34b0e89
Merge branch 'master' into lightning-lite/lite-core
tchaton Oct 27, 2021
e45f736
update
tchaton Oct 27, 2021
0deceba
Docstrings and CHANGELOG
carmocca Oct 27, 2021
5d14e83
Fixes to previous commit. Mention devices=auto (not yet implemented).…
carmocca Oct 27, 2021
11862e8
Fix test
carmocca Oct 27, 2021
ffed5ce
Fix test
carmocca Oct 27, 2021
538b969
Merge branch 'lightning-lite/lite-core' of https://github.com/PyTorch…
carmocca Oct 27, 2021
c614cf0
Improve Lite Examples (#10195)
tchaton Oct 28, 2021
93b7940
update
tchaton Oct 28, 2021
13587df
Merge branch 'master' into lightning-lite/lite-core
awaelchli Oct 28, 2021
84ac310
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 28, 2021
cf34e7b
Merge branch 'master' into lightning-lite/lite-core
awaelchli Oct 28, 2021
a6414a2
update access to deepspeed internal vars
awaelchli Oct 28, 2021
2e55f6c
Merge branch 'master' into lightning-lite/lite-poc
awaelchli Oct 28, 2021
5e1aeb8
fix check for multiple models in deepspeed
awaelchli Oct 28, 2021
f885b35
fix deepspeed precision
awaelchli Oct 28, 2021
de4ef79
update
tchaton Oct 28, 2021
7a474a7
Merge branch 'master' into lite-poc
tchaton Oct 28, 2021
7f62394
update
tchaton Oct 28, 2021
5546084
Merge branch 'master' into lightning-lite/lite-core
awaelchli Oct 28, 2021
db34e09
fix line too long
awaelchli Oct 28, 2021
992fd45
Minor changes
carmocca Oct 28, 2021
b8d44ce
remove identity wrapper
awaelchli Oct 28, 2021
732de7a
Merge remote-tracking branch 'origin/lightning-lite/lite-core' into l…
awaelchli Oct 28, 2021
04094c3
Same annotations as Lightning which are identical to those in torch
carmocca Oct 28, 2021
1d9920a
Add comment
carmocca Oct 28, 2021
a6df052
Simplify _LiteOptimizer
carmocca Oct 28, 2021
5208e19
Didn't mean to remove this :)
carmocca Oct 28, 2021
31406ae
rename cast to autocast
awaelchli Oct 28, 2021
bda0f8a
test: Remove unused parametrization
carmocca Oct 28, 2021
c34d006
rename save_checkpoint to save
awaelchli Oct 28, 2021
f45c2c8
update docstring
awaelchli Oct 28, 2021
c84acb1
update comment
awaelchli Oct 28, 2021
ddd7c4f
Merge remote-tracking branch 'origin/lightning-lite/lite-core' into l…
awaelchli Oct 28, 2021
92752e6
add load
awaelchli Oct 28, 2021
c0ffc71
tests: update autocast use
carmocca Oct 28, 2021
af40009
add test for autocast
awaelchli Oct 28, 2021
eb9b92e
simplify test
awaelchli Oct 28, 2021
3e261e1
add test description
awaelchli Oct 28, 2021
5754ad7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2021
85fe0cf
remove "mixed" string support
awaelchli Oct 28, 2021
91a6b3c
More mixed references
carmocca Oct 28, 2021
f45a97a
Implement `seed_everything`
carmocca Oct 28, 2021
b543655
Merge branch 'master' into lightning-lite/lite-core
awaelchli Oct 28, 2021
ba7ac5f
add isinstance check
awaelchli Oct 28, 2021
f04b398
add bfloat16
awaelchli Oct 28, 2021
229b024
rename params_on_cpu
awaelchli Oct 28, 2021
95db246
Pass down the barrier name
carmocca Oct 28, 2021
0c8e914
Add back __del__
carmocca Oct 28, 2021
a93278d
Fix mypy
carmocca Oct 28, 2021
65e289b
Fix test
carmocca Oct 28, 2021
d408228
Add worker init fn
carmocca Oct 28, 2021
952e11c
Forgot to pass the global rank
carmocca Oct 28, 2021
50d5124
add back skip of expensive spawn test
awaelchli Oct 28, 2021
13fb58a
resolve todo in _LiteModule
awaelchli Oct 28, 2021
2e92fe1
Merge remote-tracking branch 'origin/lightning-lite/lite-core' into l…
awaelchli Oct 28, 2021
9a1e93f
Add seed everything test
carmocca Oct 28, 2021
f47c2ad
fix type error
awaelchli Oct 28, 2021
4c40d71
Merge branch 'lightning-lite/lite-core' into lightning-lite/lite-poc
awaelchli Oct 29, 2021
d7b430f
Update pl_examples/basic_examples/mnist_examples/image_classifier_2_l…
awaelchli Oct 29, 2021
d51c71c
Update pl_examples/basic_examples/mnist_examples/image_classifier_2_l…
awaelchli Oct 29, 2021
5b5fbd5
update on master
tchaton Oct 30, 2021
ae2fe70
update examples
tchaton Oct 30, 2021
3d4a5ef
update
tchaton Oct 30, 2021
0ffc7d2
update
tchaton Oct 30, 2021
1b3fb60
update
tchaton Oct 30, 2021
357869e
replace links with file paths
awaelchli Nov 1, 2021
70067df
fix link
awaelchli Nov 1, 2021
24f0ff4
typos, grammar, fix links
awaelchli Nov 1, 2021
2cd5ba4
create a sentence
awaelchli Nov 1, 2021
99acc1e
duplicate fixes
awaelchli Nov 1, 2021
68fcf07
Merge branch 'master' into lite-poc
tchaton Nov 1, 2021
b5e9a94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
3e100f7
fix changelog
awaelchli Nov 1, 2021
e01da0b
typos and formatting in mnist lite/lightning examples
awaelchli Nov 1, 2021
c385998
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
eb53884
fixes for loop example mnist_lite
awaelchli Nov 1, 2021
d32f428
undo gitignore changes
awaelchli Nov 1, 2021
ac75983
use strategy arg
awaelchli Nov 1, 2021
b441745
backticks for run
awaelchli Nov 1, 2021
15588c5
Merge remote-tracking branch 'origin/lite-poc' into lightning-lite/li…
awaelchli Nov 1, 2021
8045e0c
auto accelerator and devices
awaelchli Nov 1, 2021
5b6243f
update comments about dp/strategy
awaelchli Nov 1, 2021
51ca1e6
address a couple missed commens from ari
awaelchli Nov 1, 2021
9fd0bba
Update pl_examples/basic_examples/README.md
awaelchli Nov 1, 2021
60a65dc
update links, latest -> stable
awaelchli Nov 1, 2021
f33911d
switch order in run_examples.sh
awaelchli Nov 1, 2021
6e76183
capitalization
awaelchli Nov 1, 2021
d644ba6
Merge remote-tracking branch 'origin/lite-poc' into lightning-lite/li…
awaelchli Nov 1, 2021
6c7e630
Update pl_examples/basic_examples/README.md
awaelchli Nov 1, 2021
be1d820
Update pl_examples/basic_examples/README.md
awaelchli Nov 1, 2021
138f4f5
Merge branch 'master' into lite-poc
tchaton Nov 1, 2021
09ccc0d
update on comments
tchaton Nov 1, 2021
a27d0e3
hotfix
tchaton Nov 1, 2021
8a970ac
update
tchaton Nov 1, 2021
fd3d286
update
tchaton Nov 1, 2021
29bb0c9
update
tchaton Nov 1, 2021
3b9496b
remove test.py
tchaton Nov 1, 2021
8139093
update
tchaton Nov 1, 2021
af13b28
update
tchaton Nov 1, 2021
af1ad85
update
tchaton Nov 1, 2021
389e535
update
tchaton Nov 1, 2021
b4b63fb
update
tchaton Nov 1, 2021
3b82b57
update
tchaton Nov 1, 2021
7a88161
update
tchaton Nov 1, 2021
c05a102
Merge branch 'master' into lite-poc
tchaton Nov 1, 2021
506fc19
update
tchaton Nov 1, 2021
81ea111
update
tchaton Nov 1, 2021
5209601
update
tchaton Nov 1, 2021
33b8758
update
tchaton Nov 2, 2021
180 changes: 180 additions & 0 deletions pl_examples/lite_examples/gan/gan_example.py
"""
DCGAN - Adapted from pytorch/examples

Launch it with this command:

    python -m torch.distributed.run --nproc_per_node=2 gan_example.py

"""

import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DistributedSampler

from pl_examples.lite_examples.gan.models import Discriminator, Generator, weights_init
from pytorch_lightning import seed_everything
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteModule, _LiteOptimizer

parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int, help="number of data loading workers", default=0)
parser.add_argument("--batchSize", type=int, default=64, help="input batch size")
parser.add_argument(
    "--imageSize",
    type=int,
    default=64,
    help="the height / width of the input image to network",
)
parser.add_argument("--niter", type=int, default=25, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate, default=0.0002")
parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam. default=0.5")
parser.add_argument("--ngpu", type=int, default=1, help="number of GPUs to use")
parser.add_argument("--netG", default="", help="path to netG (to continue training)")
parser.add_argument("--netD", default="", help="path to netD (to continue training)")
parser.add_argument("--outf", default="./lightning_logs", help="folder to output images and model checkpoints")
parser.add_argument("--local_rank", type=int, default=0)

opt, _ = parser.parse_known_args()
os.makedirs(opt.outf, exist_ok=True)
ngpu = int(opt.ngpu)

nz = 100


class GANTrainer(LightningLite):
    def run(self):
        print("strategy: ", self._strategy)
        print("precision plugin: ", self._precision_plugin)
        seed_everything(123)

        # TODO: how do we handle this in Accelerator?
        # torch.cuda.set_device(opt.local_rank)
        # TODO: how do we handle this?
        # os.environ["LOCAL_RANK"] = str(opt.local_rank)
        # os.environ["NODE_RANK"] = str(opt.local_rank)

        # download the dataset on rank 0 only, then let all ranks read it
        if self.local_rank == 0:
            dset.MNIST(root=".", download=True)

        self.barrier()
        dataset = dset.MNIST(
            root=".",
            transform=transforms.Compose(
                [
                    transforms.Resize(opt.imageSize),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,)),
                ]
            ),
        )
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=opt.batchSize, shuffle=True, num_workers=opt.workers
        )

        dataloader = self.setup_dataloaders(dataloader)
        # assert isinstance(dataloader.sampler, DistributedSampler)

        netG = Generator()
        netG.apply(weights_init)

        netD = Discriminator()
        netD.apply(weights_init)

        # self.to_device(netG)
        # self.to_device(netD)

        criterion = nn.BCELoss()

        fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=self.device)
        real_label = 1
        fake_label = 0

        # setup optimizers
        optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        netG, optimizerG = self.setup(netG, optimizerG)
        netD, optimizerD = self.setup(netD, optimizerD)

        assert isinstance(optimizerG, _LiteOptimizer)
        assert isinstance(netG, _LiteModule)
        print("parameters dtype", next(netG.parameters()).dtype)

        for epoch in range(opt.niter):
            for i, data in enumerate(dataloader, 0):
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # train with real
                netD.zero_grad()
                real_cpu = data[0]
                batch_size = real_cpu.size(0)
                label = torch.full((batch_size,), real_label, dtype=real_cpu.dtype, device=self.device)

                output = netD(real_cpu)
                errD_real = criterion(output, label)
                self.backward(errD_real)
                D_x = output.mean().item()

                # train with fake
                noise = torch.randn(batch_size, nz, 1, 1, device=self.device)
                fake = netG(noise)
                label.fill_(fake_label)
                output = netD(fake.detach())
                errD_fake = criterion(output, label)
                self.backward(errD_fake)
                D_G_z1 = output.mean().item()
                errD = errD_real + errD_fake
                optimizerD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                netG.zero_grad()
                label.fill_(real_label)  # fake labels are real for generator cost
                output = netD(fake)
                errG = criterion(output, label)
                self.backward(errG)
                D_G_z2 = output.mean().item()
                optimizerG.step()

                print(
                    "[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f"
                    % (
                        epoch,
                        opt.niter,
                        i,
                        len(dataloader),
                        errD.item(),
                        errG.item(),
                        D_x,
                        D_G_z1,
                        D_G_z2,
                    )
                )
                if i % 100 == 0:
                    vutils.save_image(real_cpu, "%s/real_samples.png" % opt.outf, normalize=True)
                    fake = netG(fixed_noise)
                    vutils.save_image(
                        fake.detach(),
                        "%s/fake_samples_epoch_%03d.png" % (opt.outf, epoch),
                        normalize=True,
                    )
            # do checkpointing
            torch.save(netG.state_dict(), "%s/netG_epoch_%d.pth" % (opt.outf, epoch))
            torch.save(netD.state_dict(), "%s/netD_epoch_%d.pth" % (opt.outf, epoch))


if __name__ == "__main__":
    gan = GANTrainer(accelerator="ddp", devices=2)
    gan.run()
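The assertions in `run()` check that `setup()` hands back wrapper objects (`_LiteModule`, `_LiteOptimizer`) rather than the raw model and optimizer. The core idea of that wrapper, delegating most attribute access to the wrapped object while intercepting `step()` so the strategy controls it, can be sketched in plain Python. This is an illustrative toy, not the actual Lightning source; all class names below are hypothetical stand-ins.

```python
class StrategyStub:
    """Stand-in for a training-type strategy; records how many steps it routed."""

    def __init__(self):
        self.steps = 0

    def optimizer_step(self, optimizer):
        # A real strategy would handle precision/distributed concerns here.
        self.steps += 1
        optimizer.step()


class OptimizerProxy:
    """Toy version of the wrapper pattern: delegate everything except step()."""

    def __init__(self, optimizer, strategy):
        self._optimizer = optimizer
        self._strategy = strategy

    def __getattr__(self, name):
        # Called only when normal lookup fails, so zero_grad, param_groups,
        # etc. fall through to the wrapped optimizer untouched.
        return getattr(self._optimizer, name)

    def step(self):
        # The one intercepted method: the strategy decides how the real step runs.
        self._strategy.optimizer_step(self._optimizer)


class ToyOptimizer:
    def __init__(self):
        self.step_count = 0

    def step(self):
        self.step_count += 1

    def zero_grad(self):
        pass


strategy = StrategyStub()
opt = OptimizerProxy(ToyOptimizer(), strategy)
opt.zero_grad()  # delegated straight to the wrapped optimizer
opt.step()       # routed through the strategy
print(strategy.steps, opt.step_count)  # 1 1
```

This is why user code in the example can keep calling `optimizerD.step()` unchanged after `self.setup(...)`: the call site looks like plain PyTorch, but the step is owned by the strategy.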
78 changes: 78 additions & 0 deletions pl_examples/lite_examples/gan/models.py
import torch
from torch import nn as nn

nc = 1
nz = 100
ngf = 64
ndf = 64


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        print("autocast enabled in generator: ", torch.is_autocast_enabled())
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        print("autocast enabled in discriminator: ", torch.is_autocast_enabled())
        output = self.main(input)
        print("double precision: ", input.dtype == torch.double)
        return output.view(-1, 1).squeeze(1)
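The "state size" comments in `Generator` and `Discriminator` follow from the standard PyTorch output-size formulas for `ConvTranspose2d` and `Conv2d` (with no dilation or output padding). A quick sanity check, pure arithmetic rather than Lightning-specific code, confirms the spatial sizes annotated above:

```python
def convtranspose2d_out(size, kernel, stride, padding):
    # PyTorch ConvTranspose2d: out = (in - 1) * stride - 2 * padding + kernel
    return (size - 1) * stride - 2 * padding + kernel


def conv2d_out(size, kernel, stride, padding):
    # PyTorch Conv2d: out = floor((in + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


# Generator: the 1x1 latent grows 1 -> 4 -> 8 -> 16 -> 32 -> 64.
size = 1
gen_sizes = []
for kernel, stride, padding in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = convtranspose2d_out(size, kernel, stride, padding)
    gen_sizes.append(size)
print(gen_sizes)  # [4, 8, 16, 32, 64]

# Discriminator: a 64x64 image shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1.
size = 64
disc_sizes = []
for kernel, stride, padding in [(4, 2, 1)] * 4 + [(4, 1, 0)]:
    size = conv2d_out(size, kernel, stride, padding)
    disc_sizes.append(size)
print(disc_sizes)  # [32, 16, 8, 4, 1]
```

The final 1x1 discriminator map is why `forward` can flatten its output to one score per image with `view(-1, 1).squeeze(1)`.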
15 changes: 15 additions & 0 deletions pl_examples/lite_examples/gan/run_examples.py
import argparse

from pl_examples.lite_examples.gan.gan_example import GANTrainer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--accelerator", type=str, default=None)
    parser.add_argument("--strategy", type=str, default=None)
    parser.add_argument("--gpus", type=int, default=None)
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--precision", type=int, default=32)
    args = parser.parse_args()

    trainer = GANTrainer(**vars(args))
    trainer.run()
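Note that `gan_example.py` parses with `parse_known_args()` while this runner uses `parse_args()`. The distinction matters when a launcher such as `torch.distributed.run` injects flags the script may not have declared: `parse_known_args` tolerates them, `parse_args` exits with an error. A minimal illustration (the flag names here are made up for the demo):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--devices", type=int, default=1)

# Simulated command line: one declared flag plus one launcher-injected unknown flag.
argv = ["--devices", "2", "--unexpected_flag", "0"]

# parse_known_args returns the parsed namespace plus the leftover argv tokens.
args, unknown = parser.parse_known_args(argv)
print(args.devices)  # 2
print(unknown)       # ['--unexpected_flag', '0']

# parser.parse_args(argv) would instead abort with "unrecognized arguments".
```

Using `parse_args` here is fine because the runner owns its full command line; the GAN example hedges with `parse_known_args` since it is also meant to run under a distributed launcher.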