Skip to content

Commit

Permalink
Harmonise VQVAE with AutoencoderKL (#248)
Browse files Browse the repository at this point in the history
* [WIP] Remove num_levels

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Remove num_levels

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add Encoder and Decoder classes

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Remove unused dropout_dim

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add more checks for the parameters and tests

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add annotations and remove __constants__

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add docstring for Encoder and Decoder

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add docstring for Encoder and Decoder

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Set dropout as float value and update docstring

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Fix torchscript error

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Update tutorials

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Update tutorials

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* remove adn_ordering

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add ActType

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Remove text about subpixel layers and ActType.

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

---------

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
  • Loading branch information
Warvito authored Feb 15, 2023
1 parent b075d86 commit 06a57bc
Show file tree
Hide file tree
Showing 10 changed files with 403 additions and 350 deletions.
450 changes: 258 additions & 192 deletions generative/networks/nets/vqvae.py

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions tests/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
"upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
"num_res_layers": 1,
"num_channels": [4, 4],
"num_res_layers": 1,
"num_res_channels": [4, 4],
"downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
"upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
"num_embeddings": 16,
"embedding_dim": 3,
},
Expand Down
237 changes: 120 additions & 117 deletions tests/test_vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,146 +26,78 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": [(2, 4, 1, 1)] * 2,
"upsample_parameters": [(2, 4, 1, 1, 0)] * 2,
"num_channels": (4, 4),
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"num_res_channels": (4, 4),
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_embeddings": 8,
"embedding_dim": 8,
"embedding_init": "normal",
"commitment_cost": 0.25,
"decay": 0.5,
"epsilon": 1e-5,
"adn_ordering": "NDA",
"dropout": 0.1,
"act": "RELU",
"output_act": None,
},
(1, 1, 16, 16),
(1, 1, 16, 16),
(1, 1, 8, 8),
(1, 1, 8, 8),
],
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": [(2, 4, 1, 1)] * 2,
"upsample_parameters": [(2, 4, 1, 1, 0)] * 2,
"num_channels": (4, 4),
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": 8,
"num_embeddings": 16,
"num_res_channels": 4,
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_embeddings": 8,
"embedding_dim": 8,
"embedding_init": "normal",
"commitment_cost": 0.25,
"decay": 0.5,
"epsilon": 1e-5,
"adn_ordering": "NDA",
"dropout": 0.1,
"act": "RELU",
"output_act": None,
},
(1, 1, 16, 16),
(1, 1, 16, 16),
(1, 1, 8, 8),
(1, 1, 8, 8),
],
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": [(2, 4, 1, 1)] * 2,
"upsample_parameters": [(2, 4, 1, 1, 0)] * 2,
"num_channels": (4, 4),
"num_res_layers": 1,
"num_channels": [8, 8],
"num_res_channels": [8, 8],
"num_embeddings": 16,
"num_res_channels": (4, 4),
"downsample_parameters": (2, 4, 1, 1),
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_embeddings": 8,
"embedding_dim": 8,
"embedding_init": "normal",
"commitment_cost": 0.25,
"decay": 0.5,
"epsilon": 1e-5,
"adn_ordering": "NDA",
"dropout": 0.1,
"act": "RELU",
"output_act": None,
},
(1, 1, 16, 16),
(1, 1, 16, 16),
(1, 1, 8, 8),
(1, 1, 8, 8),
],
[
{
"spatial_dims": 3,
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": [(2, 4, 1, 1)] * 2,
"upsample_parameters": [(2, 4, 1, 1, 0)] * 2,
"num_channels": (4, 4),
"num_res_layers": 1,
"num_channels": [8, 8],
"num_res_channels": [8, 8],
"num_embeddings": 16,
"num_res_channels": (4, 4),
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": (2, 4, 1, 1, 0),
"num_embeddings": 8,
"embedding_dim": 8,
"embedding_init": "normal",
"commitment_cost": 0.25,
"decay": 0.5,
"epsilon": 1e-5,
"adn_ordering": "NDA",
"dropout": 0.1,
"act": "RELU",
"output_act": None,
},
(1, 1, 16, 16, 16),
(1, 1, 16, 16, 16),
(1, 1, 8, 8),
(1, 1, 8, 8),
],
]

# 1-channel 3D config (spatial_dims is 3), should fail because num_levels (3), the number of
# downsamplings (2) and the number of upsamplings (4) do not match.
TEST_CASE_FAIL = {
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 1,
"num_levels": 3,
"downsample_parameters": [(2, 4, 1, 1)] * 2,
"upsample_parameters": [(2, 4, 1, 1, 0)] * 4,
"num_res_layers": 1,
"num_channels": [8, 8],
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
"embedding_init": "normal",
"commitment_cost": 0.25,
"decay": 0.5,
"epsilon": 1e-5,
"adn_ordering": "NDA",
"dropout": 0.1,
"act": "RELU",
"output_act": None,
}

# 2D, single-channel configuration dict; presumably unpacked into VQVAE(**...) by the
# latent-shape tests -- TODO confirm against the test bodies, which are not shown here.
# NOTE(review): keys listed twice below (downsample_parameters, upsample_parameters,
# num_channels, num_res_channels) look like the removed list-based line followed by its
# tuple-based replacement in this rendered diff -- confirm against the actual file.
TEST_LATENT_SHAPE = {
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": [(2, 4, 1, 1)] * 2,
"upsample_parameters": [(2, 4, 1, 1, 0)] * 2,
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": [8, 8],
"num_res_channels": [8, 8],
"num_channels": (8, 8),
"num_res_channels": (8, 8),
"num_embeddings": 16,
"embedding_dim": 8,
"embedding_init": "normal",
"commitment_cost": 0.25,
"decay": 0.5,
"epsilon": 1e-5,
"adn_ordering": "NDA",
"dropout": 0.1,
"act": "RELU",
"output_act": None,
}


Expand All @@ -186,30 +118,101 @@ def test_script(self):
spatial_dims=2,
in_channels=1,
out_channels=1,
num_levels=2,
downsample_parameters=tuple([(2, 4, 1, 1)] * 2),
upsample_parameters=tuple([(2, 4, 1, 1, 0)] * 2),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
num_res_layers=1,
num_channels=[8, 8],
num_res_channels=[8, 8],
num_channels=(8, 8),
num_res_channels=(8, 8),
num_embeddings=16,
embedding_dim=8,
embedding_init="normal",
commitment_cost=0.25,
decay=0.5,
epsilon=1e-5,
adn_ordering="NDA",
dropout=0.1,
act="RELU",
output_act=None,
ddp_sync=False,
)
test_data = torch.randn(1, 1, 16, 16)
test_script_save(net, test_data)

def test_level_upsample_downsample_difference(self):
# Constructing VQVAE from TEST_CASE_FAIL is expected to raise AssertionError.
with self.assertRaises(AssertionError):
VQVAE(**TEST_CASE_FAIL)
def test_num_channels_not_same_size_of_num_res_channels(self):
# VQVAE construction must raise ValueError when num_res_channels (3 entries)
# and num_channels (2 entries) have different lengths.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16, 16),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
)

def test_num_channels_not_same_size_of_downsample_parameters(self):
# VQVAE construction must raise ValueError when downsample_parameters has
# 3 entries while num_channels has 2; every other argument has 2 entries.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16),
downsample_parameters=((2, 4, 1, 1),) * 3,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
)

def test_num_channels_not_same_size_of_upsample_parameters(self):
# VQVAE construction must raise ValueError when upsample_parameters has
# 3 entries while num_channels has 2; every other argument has 2 entries.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 3,
)

def test_downsample_parameters_not_sequence_or_int(self):
# Expects ValueError when each downsample tuple starts with the string "test"
# instead of an integer.
# NOTE(review): num_res_channels has 3 entries while num_channels has 2, the same
# mismatch exercised by test_num_channels_not_same_size_of_num_res_channels. The
# ValueError may therefore come from that length check rather than from the
# downsample_parameters validation -- confirm against VQVAE.__init__.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16, 16),
downsample_parameters=(("test", 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
)

def test_upsample_parameters_not_sequence_or_int(self):
# Expects ValueError when each upsample tuple starts with the string "test"
# instead of an integer.
# NOTE(review): num_res_channels has 3 entries while num_channels has 2, the same
# mismatch exercised by test_num_channels_not_same_size_of_num_res_channels. The
# ValueError may therefore come from that length check rather than from the
# upsample_parameters validation -- confirm against VQVAE.__init__.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16, 16),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=(("test", 4, 1, 1, 0),) * 2,
)

def test_downsample_parameter_length_different_4(self):
# Expects ValueError when each downsample tuple has 3 values instead of 4.
# NOTE(review): two other arguments are also inconsistent with the 2-entry
# num_channels here: num_res_channels has 3 entries and upsample_parameters has
# 3 entries. Either could raise the ValueError before the per-tuple length of
# downsample_parameters is checked -- confirm against VQVAE.__init__.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16, 16),
downsample_parameters=((2, 4, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 3,
)

def test_upsample_parameter_length_different_5(self):
# Expects ValueError when each upsample tuple has 6 values instead of 5.
# NOTE(review): two other things are also inconsistent with the 2-entry
# num_channels here: num_res_channels has 3 entries and upsample_parameters
# itself has 3 entries. Either could raise the ValueError before the per-tuple
# length of upsample_parameters is checked -- confirm against VQVAE.__init__.
with self.assertRaises(ValueError):
VQVAE(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(16, 16),
num_res_channels=(16, 16, 16),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3,
)

def test_encode_shape(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down
15 changes: 6 additions & 9 deletions tests/test_vqvaetransformer_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"num_channels": (8, 8),
"num_res_channels": (8, 8),
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
},
Expand All @@ -52,12 +51,11 @@
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"num_channels": (8, 8),
"num_res_channels": (8, 8),
"downsample_parameters": ((2, 4, 1, 1),) * 2,
"upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
"num_res_layers": 1,
"num_channels": 8,
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 8,
},
Expand Down Expand Up @@ -100,12 +98,11 @@ def test_sample(self):
spatial_dims=2,
in_channels=1,
out_channels=1,
num_levels=2,
num_channels=(8, 8),
num_res_channels=(8, 8),
downsample_parameters=((2, 4, 1, 1),) * 2,
upsample_parameters=((2, 4, 1, 1, 0),) * 2,
num_res_layers=1,
num_channels=8,
num_res_channels=(8, 8),
num_embeddings=16,
embedding_dim=8,
)
Expand Down
5 changes: 2 additions & 3 deletions tutorials/generative/2d_vqgan/2d_vqgan_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,11 @@
" spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" num_channels=(256, 512),\n",
" num_res_channels=512,\n",
" num_res_layers=2,\n",
" num_levels=2,\n",
" downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n",
" upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n",
" num_channels=[256, 512],\n",
" num_res_channels=[256, 512],\n",
" num_embeddings=256,\n",
" embedding_dim=32,\n",
")\n",
Expand Down
5 changes: 2 additions & 3 deletions tutorials/generative/2d_vqgan/2d_vqgan_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,11 @@
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=(256, 512),
num_res_channels=512,
num_res_layers=2,
num_levels=2,
downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),
upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
num_channels=[256, 512],
num_res_channels=[256, 512],
num_embeddings=256,
embedding_dim=32,
)
Expand Down
Loading

0 comments on commit 06a57bc

Please sign in to comment.