Skip to content

Commit

Permalink
updates on docs et notebook + resolve issue #17
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Dec 2, 2024
1 parent 1fd81f4 commit 58d8fba
Show file tree
Hide file tree
Showing 16 changed files with 1,815 additions and 574 deletions.
2 changes: 1 addition & 1 deletion deel/torchlip/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def forward(self, x):

class BatchCentering(nn.Module):
r"""
Applies Batch Normalization over a 2D, 3D, 4D input.
Applies Batch Centering over a 2D, 3D, 4D input.
.. math::
Expand Down
93 changes: 62 additions & 31 deletions docs/notebooks/wasserstein_classification_MNIST08.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,6 @@
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand Down Expand Up @@ -200,33 +193,33 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"loss: -0.0655 - KR: 3.3978 - acc: 0.9913 - val_loss: -0.0769 - val_KR: 4.2157 - val_acc: 0.9933\n",
"loss: -0.0340 - KR: 1.2288 - acc: 0.8649 - val_loss: -0.0363 - val_KR: 2.3215 - val_acc: 0.9928\n",
"Epoch 2/10\n",
"loss: -0.1013 - KR: 4.7773 - acc: 0.9945 - val_loss: -0.0989 - val_KR: 5.3608 - val_acc: 0.9928\n",
"loss: -0.0630 - KR: 2.8186 - acc: 0.9943 - val_loss: -0.0607 - val_KR: 3.4102 - val_acc: 0.9939\n",
"Epoch 3/10\n",
"loss: -0.0946 - KR: 5.6133 - acc: 0.9951 - val_loss: -0.1112 - val_KR: 5.9211 - val_acc: 0.9949\n",
"loss: -0.0901 - KR: 3.8766 - acc: 0.9960 - val_loss: -0.0805 - val_KR: 4.4241 - val_acc: 0.9939\n",
"Epoch 4/10\n",
"loss: -0.1145 - KR: 6.0779 - acc: 0.9963 - val_loss: -0.1180 - val_KR: 6.2546 - val_acc: 0.9939\n",
"loss: -0.0964 - KR: 4.7411 - acc: 0.9965 - val_loss: -0.0957 - val_KR: 5.1178 - val_acc: 0.9933\n",
"Epoch 5/10\n",
"loss: -0.1133 - KR: 6.2920 - acc: 0.9962 - val_loss: -0.1206 - val_KR: 6.3919 - val_acc: 0.9944\n",
"loss: -0.1084 - KR: 5.3850 - acc: 0.9957 - val_loss: -0.1036 - val_KR: 5.7095 - val_acc: 0.9923\n",
"Epoch 6/10\n",
"loss: -0.1371 - KR: 6.5019 - acc: 0.9965 - val_loss: -0.1255 - val_KR: 6.6471 - val_acc: 0.9939\n",
"loss: -0.1095 - KR: 5.8155 - acc: 0.9954 - val_loss: -0.1126 - val_KR: 6.0285 - val_acc: 0.9944\n",
"Epoch 7/10\n",
"loss: -0.1226 - KR: 6.6214 - acc: 0.9969 - val_loss: -0.1261 - val_KR: 6.7408 - val_acc: 0.9939\n",
"loss: -0.1090 - KR: 6.1108 - acc: 0.9960 - val_loss: -0.1178 - val_KR: 6.3084 - val_acc: 0.9933\n",
"Epoch 8/10\n",
"loss: -0.1395 - KR: 6.7325 - acc: 0.9967 - val_loss: -0.1280 - val_KR: 6.7204 - val_acc: 0.9944\n",
"loss: -0.1266 - KR: 6.3128 - acc: 0.9959 - val_loss: -0.1192 - val_KR: 6.4553 - val_acc: 0.9923\n",
"Epoch 9/10\n",
"loss: -0.1271 - KR: 6.7927 - acc: 0.9971 - val_loss: -0.1255 - val_KR: 6.8759 - val_acc: 0.9898\n",
"loss: -0.1263 - KR: 6.4460 - acc: 0.9966 - val_loss: -0.1208 - val_KR: 6.4837 - val_acc: 0.9939\n",
"Epoch 10/10\n",
"loss: -0.1316 - KR: 6.8134 - acc: 0.9970 - val_loss: -0.1286 - val_KR: 6.8696 - val_acc: 0.9928\n"
"loss: -0.1316 - KR: 6.5416 - acc: 0.9967 - val_loss: -0.1240 - val_KR: 6.6313 - val_acc: 0.9933\n"
]
}
],
Expand Down Expand Up @@ -326,14 +319,14 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.1439)\n"
"tensor(0.1420)\n"
]
}
],
Expand Down Expand Up @@ -361,14 +354,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.8923, dtype=torch.float64)\n"
"tensor(0.8950, dtype=torch.float64)\n"
]
}
],
Expand Down Expand Up @@ -403,7 +396,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -428,7 +421,7 @@
" (1): _BjorckNorm()\n",
" )\n",
" )\n",
"), min=0.9999998211860657, max=1.0000001192092896\n",
"), min=0.9999998211860657, max=1.000000238418579\n",
"ParametrizedSpectralLinear(\n",
" in_features=64, out_features=32, bias=True\n",
" (parametrizations): ModuleDict(\n",
Expand All @@ -437,15 +430,15 @@
" (1): _BjorckNorm()\n",
" )\n",
" )\n",
"), min=0.9999998211860657, max=1.0\n",
"), min=0.9999998807907104, max=1.0\n",
"ParametrizedFrobeniusLinear(\n",
" in_features=32, out_features=1, bias=True\n",
" (parametrizations): ModuleDict(\n",
" (weight): ParametrizationList(\n",
" (0): _FrobeniusNorm()\n",
" )\n",
" )\n",
"), min=0.9999999403953552, max=0.9999999403953552\n"
"), min=0.9999998807907104, max=0.9999998807907104\n"
]
}
],
Expand All @@ -459,9 +452,47 @@
" print(f\"{layer}, min={s.min()}, max={s.max()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2 Model export\n",
"\n",
"Once training is finished, the model can be optimized for inference by using the\n",
"`vanilla_export()` method. The `torchlip` layers are converted to their PyTorch\n",
"counterparts, e.g. `SpectralConv2d` layers will be converted into `torch.nn.Conv2d`\n",
"layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Warnings:\n",
"The `vanilla_export()` method modifies the model in-place.\n",
"\n",
"In order to build and export a new model while keeping the reference one, it is required to follow these steps:\n",
"\n",
"\\# Build a new model, for instance with torchlip.Sequential(torchlip.SpectralConv2d(...), ...)\n",
"\n",
"`wexport = <your_function_to_build_the_model>()`\n",
"\n",
"\\# Copy the parameters from the reference model to the new model\n",
"\n",
"`wexport.load_state_dict(wass.state_dict())`\n",
"\n",
"\\# One forward pass is required to initialize the parametrizations\n",
"\n",
"`wexport(one_input)`\n",
"\n",
"\\# vanilla_export the new model\n",
"\n",
"`wexport = wexport.vanilla_export()`"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -470,9 +501,9 @@
"text": [
"=== After export ===\n",
"Linear(in_features=784, out_features=128, bias=True), min=0.9999998211860657, max=1.0\n",
"Linear(in_features=128, out_features=64, bias=True), min=0.9999998211860657, max=1.0000001192092896\n",
"Linear(in_features=64, out_features=32, bias=True), min=0.9999998211860657, max=1.0\n",
"Linear(in_features=32, out_features=1, bias=True), min=0.9999999403953552, max=0.9999999403953552\n"
"Linear(in_features=128, out_features=64, bias=True), min=0.9999998211860657, max=1.000000238418579\n",
"Linear(in_features=64, out_features=32, bias=True), min=0.9999998807907104, max=1.0\n",
"Linear(in_features=32, out_features=1, bias=True), min=0.9999998807907104, max=0.9999998807907104\n"
]
}
],
Expand Down Expand Up @@ -505,7 +536,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "deel-pt1.10",
"language": "python",
"name": "python3"
},
Expand Down
Loading

0 comments on commit 58d8fba

Please sign in to comment.