diff --git a/_downloads/315c4c52fb68082a731b192d944e2ede/tutorials_python.zip b/_downloads/315c4c52fb68082a731b192d944e2ede/tutorials_python.zip index d82406b83..0a1c7b62c 100644 Binary files a/_downloads/315c4c52fb68082a731b192d944e2ede/tutorials_python.zip and b/_downloads/315c4c52fb68082a731b192d944e2ede/tutorials_python.zip differ diff --git a/_downloads/a5659940aa3f8f568547d47752a43172/tutorials_jupyter.zip b/_downloads/a5659940aa3f8f568547d47752a43172/tutorials_jupyter.zip index 5788b4cad..cc466e1c2 100644 Binary files a/_downloads/a5659940aa3f8f568547d47752a43172/tutorials_jupyter.zip and b/_downloads/a5659940aa3f8f568547d47752a43172/tutorials_jupyter.zip differ diff --git a/_sources/sg_execution_times.rst.txt b/_sources/sg_execution_times.rst.txt index 8b1cdf166..1ebfef416 100644 --- a/_sources/sg_execution_times.rst.txt +++ b/_sources/sg_execution_times.rst.txt @@ -6,7 +6,7 @@ Computation times ================= -**02:01.082** total execution time for 10 files **from all galleries**: +**01:59.897** total execution time for 10 files **from all galleries**: .. container:: @@ -33,22 +33,22 @@ Computation times - Time - Mem (MB) * - :ref:`sphx_glr_tutorials_tensorclass_fashion.py` (``reference/generated/tutorials/tensorclass_fashion.py``) - - 00:58.981 + - 00:58.373 - 0.0 * - :ref:`sphx_glr_tutorials_data_fashion.py` (``reference/generated/tutorials/data_fashion.py``) - - 00:50.364 + - 00:49.828 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_module.py` (``reference/generated/tutorials/tensordict_module.py``) - - 00:10.144 + - 00:10.137 - 0.0 * - :ref:`sphx_glr_tutorials_tensorclass_imagenet.py` (``reference/generated/tutorials/tensorclass_imagenet.py``) - - 00:01.535 + - 00:01.501 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_memory.py` (``reference/generated/tutorials/tensordict_memory.py``) - - 00:00.023 + - 00:00.025 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_keys.py` (``reference/generated/tutorials/tensordict_keys.py``) - - 00:00.010 + - 00:00.009 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_shapes.py` (``reference/generated/tutorials/tensordict_shapes.py``) - 00:00.008 @@ -57,7 +57,7 @@ Computation times - 00:00.007 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_module_functional.py` (``reference/generated/tutorials/tensordict_module_functional.py``) - - 00:00.007 + - 00:00.006 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_preallocation.py` (``reference/generated/tutorials/tensordict_preallocation.py``) - 00:00.003 diff --git a/_sources/tutorials/data_fashion.rst.txt b/_sources/tutorials/data_fashion.rst.txt index f53654983..bc3a38488 100644 --- a/_sources/tutorials/data_fashion.rst.txt +++ b/_sources/tutorials/data_fashion.rst.txt @@ -423,156 +423,156 @@ adjust how we unpack the data to the more explicit key-based retrieval offered b is_shared=False) Epoch 1 ------------------------- - loss: 2.297689 [ 0/60000] - loss: 2.280892 [ 6400/60000] - loss: 2.274654 [12800/60000] - loss: 2.279074 [19200/60000] - loss: 2.246684 [25600/60000] - loss: 2.236528 [32000/60000] - loss: 2.231100 [38400/60000] - loss: 2.205158 [44800/60000] - loss: 2.201136 [51200/60000] - loss: 2.183861 [57600/60000] + loss: 2.300480 [ 0/60000] + loss: 2.293064 [ 6400/60000] + loss: 2.266259 [12800/60000] + loss: 2.261617 [19200/60000] + loss: 2.261147 [25600/60000] + loss: 2.215564 [32000/60000] + loss: 2.237646 [38400/60000] + loss: 2.195081 [44800/60000] + loss: 2.190470 [51200/60000] + loss: 2.164365 [57600/60000] Test Error: - Accuracy: 46.9%, Avg loss: 2.167354 + Accuracy: 41.1%, Avg loss: 2.155924 
Epoch 2 ------------------------- - loss: 2.171290 [ 0/60000] - loss: 2.158412 [ 6400/60000] - loss: 2.111888 [12800/60000] - loss: 2.133768 [19200/60000] - loss: 2.084503 [25600/60000] - loss: 2.034350 [32000/60000] - loss: 2.050703 [38400/60000] - loss: 1.979632 [44800/60000] - loss: 1.976614 [51200/60000] - loss: 1.923289 [57600/60000] + loss: 2.162857 [ 0/60000] + loss: 2.156318 [ 6400/60000] + loss: 2.098328 [12800/60000] + loss: 2.116032 [19200/60000] + loss: 2.078008 [25600/60000] + loss: 2.007684 [32000/60000] + loss: 2.046304 [38400/60000] + loss: 1.959306 [44800/60000] + loss: 1.963503 [51200/60000] + loss: 1.905034 [57600/60000] Test Error: - Accuracy: 60.4%, Avg loss: 1.910134 + Accuracy: 56.5%, Avg loss: 1.897783 Epoch 3 ------------------------- - loss: 1.936399 [ 0/60000] - loss: 1.905847 [ 6400/60000] - loss: 1.795367 [12800/60000] - loss: 1.838003 [19200/60000] - loss: 1.743419 [25600/60000] - loss: 1.686000 [32000/60000] - loss: 1.699681 [38400/60000] - loss: 1.603609 [44800/60000] - loss: 1.619376 [51200/60000] - loss: 1.525605 [57600/60000] + loss: 1.922749 [ 0/60000] + loss: 1.898134 [ 6400/60000] + loss: 1.786096 [12800/60000] + loss: 1.830101 [19200/60000] + loss: 1.727412 [25600/60000] + loss: 1.672580 [32000/60000] + loss: 1.700479 [38400/60000] + loss: 1.590516 [44800/60000] + loss: 1.615487 [51200/60000] + loss: 1.517848 [57600/60000] Test Error: - Accuracy: 61.5%, Avg loss: 1.536318 + Accuracy: 61.8%, Avg loss: 1.532749 Epoch 4 ------------------------- - loss: 1.596518 [ 0/60000] - loss: 1.559388 [ 6400/60000] - loss: 1.415173 [12800/60000] - loss: 1.491498 [19200/60000] - loss: 1.381014 [25600/60000] - loss: 1.365470 [32000/60000] - loss: 1.370804 [38400/60000] - loss: 1.300947 [44800/60000] - loss: 1.330507 [51200/60000] - loss: 1.238865 [57600/60000] + loss: 1.593874 [ 0/60000] + loss: 1.560890 [ 6400/60000] + loss: 1.416031 [12800/60000] + loss: 1.487803 [19200/60000] + loss: 1.369440 [25600/60000] + loss: 1.358816 [32000/60000] + loss: 1.374556 [38400/60000] + loss: 1.287937 [44800/60000] + loss: 1.326393 [51200/60000] + loss: 1.228221 [57600/60000] Test Error: - Accuracy: 63.2%, Avg loss: 1.261017 + Accuracy: 64.3%, Avg loss: 1.257580 Epoch 5 ------------------------- - loss: 1.331413 [ 0/60000] - loss: 1.310852 [ 6400/60000] - loss: 1.153299 [12800/60000] - loss: 1.264552 [19200/60000] - loss: 1.141564 [25600/60000] - loss: 1.159877 [32000/60000] - loss: 1.169980 [38400/60000] - loss: 1.117384 [44800/60000] - loss: 1.152227 [51200/60000] - loss: 1.074354 [57600/60000] + loss: 1.331571 [ 0/60000] + loss: 1.314824 [ 6400/60000] + loss: 1.154153 [12800/60000] + loss: 1.257705 [19200/60000] + loss: 1.134324 [25600/60000] + loss: 1.153518 [32000/60000] + loss: 1.174747 [38400/60000] + loss: 1.101449 [44800/60000] + loss: 1.145374 [51200/60000] + loss: 1.062342 [57600/60000] Test Error: - Accuracy: 64.7%, Avg loss: 1.091176 + Accuracy: 65.2%, Avg loss: 1.087329 - TensorDict training done! time: 8.2466 s + TensorDict training done! 
time: 8.1166 s Epoch 1 ------------------------- - loss: 2.302497 [ 0/60000] - loss: 2.293708 [ 6400/60000] - loss: 2.275291 [12800/60000] - loss: 2.277487 [19200/60000] - loss: 2.256965 [25600/60000] - loss: 2.226137 [32000/60000] - loss: 2.235041 [38400/60000] - loss: 2.204915 [44800/60000] - loss: 2.205817 [51200/60000] - loss: 2.175438 [57600/60000] + loss: 2.298001 [ 0/60000] + loss: 2.287161 [ 6400/60000] + loss: 2.270800 [12800/60000] + loss: 2.270414 [19200/60000] + loss: 2.245303 [25600/60000] + loss: 2.225265 [32000/60000] + loss: 2.231304 [38400/60000] + loss: 2.200896 [44800/60000] + loss: 2.198071 [51200/60000] + loss: 2.161562 [57600/60000] Test Error: - Accuracy: 41.0%, Avg loss: 2.167767 + Accuracy: 50.7%, Avg loss: 2.157821 Epoch 2 ------------------------- - loss: 2.173955 [ 0/60000] - loss: 2.167189 [ 6400/60000] - loss: 2.113836 [12800/60000] - loss: 2.137049 [19200/60000] - loss: 2.091704 [25600/60000] - loss: 2.024699 [32000/60000] - loss: 2.058104 [38400/60000] - loss: 1.987428 [44800/60000] - loss: 1.997960 [51200/60000] - loss: 1.928781 [57600/60000] + loss: 2.168909 [ 0/60000] + loss: 2.158751 [ 6400/60000] + loss: 2.096658 [12800/60000] + loss: 2.108140 [19200/60000] + loss: 2.061918 [25600/60000] + loss: 2.008899 [32000/60000] + loss: 2.032122 [38400/60000] + loss: 1.955184 [44800/60000] + loss: 1.954244 [51200/60000] + loss: 1.873682 [57600/60000] Test Error: - Accuracy: 57.3%, Avg loss: 1.920751 + Accuracy: 58.9%, Avg loss: 1.876599 Epoch 3 ------------------------- - loss: 1.947450 [ 0/60000] - loss: 1.923763 [ 6400/60000] - loss: 1.809623 [12800/60000] - loss: 1.859243 [19200/60000] - loss: 1.749499 [25600/60000] - loss: 1.684024 [32000/60000] - loss: 1.717253 [38400/60000] - loss: 1.618240 [44800/60000] - loss: 1.651242 [51200/60000] - loss: 1.544094 [57600/60000] + loss: 1.909632 [ 0/60000] + loss: 1.884026 [ 6400/60000] + loss: 1.755143 [12800/60000] + loss: 1.790438 [19200/60000] + loss: 1.690238 [25600/60000] + loss: 1.638952 [32000/60000] + loss: 1.660928 [38400/60000] + loss: 1.561142 [44800/60000] + loss: 1.582468 [51200/60000] + loss: 1.474488 [57600/60000] Test Error: - Accuracy: 61.4%, Avg loss: 1.550630 + Accuracy: 61.6%, Avg loss: 1.500441 Epoch 4 ------------------------- - loss: 1.611791 [ 0/60000] - loss: 1.576542 [ 6400/60000] - loss: 1.422823 [12800/60000] - loss: 1.509073 [19200/60000] - loss: 1.374687 [25600/60000] - loss: 1.361015 [32000/60000] - loss: 1.386337 [38400/60000] - loss: 1.310513 [44800/60000] - loss: 1.355154 [51200/60000] - loss: 1.253993 [57600/60000] + loss: 1.560883 [ 0/60000] + loss: 1.537471 [ 6400/60000] + loss: 1.383097 [12800/60000] + loss: 1.453247 [19200/60000] + loss: 1.343728 [25600/60000] + loss: 1.335195 [32000/60000] + loss: 1.352697 [38400/60000] + loss: 1.274515 [44800/60000] + loss: 1.309380 [51200/60000] + loss: 1.211425 [57600/60000] Test Error: - Accuracy: 63.9%, Avg loss: 1.267246 + Accuracy: 63.2%, Avg loss: 1.241849 Epoch 5 ------------------------- - loss: 1.339752 [ 0/60000] - loss: 1.319095 [ 6400/60000] - loss: 1.151958 [12800/60000] - loss: 1.273998 [19200/60000] - loss: 1.129366 [25600/60000] - loss: 1.151123 [32000/60000] - loss: 1.181103 [38400/60000] - loss: 1.123158 [44800/60000] - loss: 1.170845 [51200/60000] - loss: 1.086041 [57600/60000] + loss: 1.310621 [ 0/60000] + loss: 1.302632 [ 6400/60000] + loss: 1.136563 [12800/60000] + loss: 1.238647 [19200/60000] + loss: 1.124559 [25600/60000] + loss: 1.142481 [32000/60000] + loss: 1.167050 [38400/60000] + loss: 1.099875 [44800/60000] + loss: 
1.140567 [51200/60000] + loss: 1.059313 [57600/60000] Test Error: - Accuracy: 65.0%, Avg loss: 1.093209 + Accuracy: 64.6%, Avg loss: 1.082458 - Training done! time: 32.7524 s + Training done! time: 32.8354 s @@ -580,7 +580,7 @@ adjust how we unpack the data to the more explicit key-based retrieval offered b .. rst-class:: sphx-glr-timing - **Total running time of the script:** (0 minutes 50.364 seconds) + **Total running time of the script:** (0 minutes 49.828 seconds) .. _sphx_glr_download_tutorials_data_fashion.py: diff --git a/_sources/tutorials/sg_execution_times.rst.txt b/_sources/tutorials/sg_execution_times.rst.txt index cb0a8771e..f63d3e7db 100644 --- a/_sources/tutorials/sg_execution_times.rst.txt +++ b/_sources/tutorials/sg_execution_times.rst.txt @@ -6,7 +6,7 @@ Computation times ================= -**02:01.082** total execution time for 10 files **from tutorials**: +**01:59.897** total execution time for 10 files **from tutorials**: .. container:: @@ -33,22 +33,22 @@ Computation times - Time - Mem (MB) * - :ref:`sphx_glr_tutorials_tensorclass_fashion.py` (``tensorclass_fashion.py``) - - 00:58.981 + - 00:58.373 - 0.0 * - :ref:`sphx_glr_tutorials_data_fashion.py` (``data_fashion.py``) - - 00:50.364 + - 00:49.828 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_module.py` (``tensordict_module.py``) - - 00:10.144 + - 00:10.137 - 0.0 * - :ref:`sphx_glr_tutorials_tensorclass_imagenet.py` (``tensorclass_imagenet.py``) - - 00:01.535 + - 00:01.501 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_memory.py` (``tensordict_memory.py``) - - 00:00.023 + - 00:00.025 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_keys.py` (``tensordict_keys.py``) - - 00:00.010 + - 00:00.009 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_shapes.py` (``tensordict_shapes.py``) - 00:00.008 @@ -57,7 +57,7 @@ Computation times - 00:00.007 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_module_functional.py` (``tensordict_module_functional.py``) - - 00:00.007 + - 00:00.006 - 0.0 * - :ref:`sphx_glr_tutorials_tensordict_preallocation.py` (``tensordict_preallocation.py``) - 00:00.003 diff --git a/_sources/tutorials/tensorclass_fashion.rst.txt b/_sources/tutorials/tensorclass_fashion.rst.txt index 340ff1979..f6f03dde4 100644 --- a/_sources/tutorials/tensorclass_fashion.rst.txt +++ b/_sources/tutorials/tensorclass_fashion.rst.txt @@ -97,22 +97,22 @@ the image (e.g. "Bag", "Sneaker" etc.). Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz - 0%| | 0/26421880 [00:00ProbabilisticTensorDictModule>>> td_module = ProbabilisticTensorDictSequential( ... module, normal_params, prob_module ... ) ->>> params = make_functional(td_module, funs_to_decorate=["forward", "get_dist", "log_prob"]) ->>> _ = td_module(td, params=params) +>>> params = TensorDict.from_module(td_module) +>>> with params.to_module(td_module): +... _ = td_module(td) >>> print(td) TensorDict( fields={ @@ -486,13 +487,17 @@

ProbabilisticTensorDictModule batch_size=torch.Size([3]), device=None, is_shared=False) ->>> dist = td_module.get_dist(td, params=params) +>>> with params.to_module(td_module): +... dist = td_module.get_dist(td) >>> print(dist) Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])) >>> # we can also apply the module to the TensorDict with vmap >>> from torch import vmap >>> params = params.expand(4) ->>> td_vmap = vmap(td_module, (None, 0))(td, params) +>>> def func(td, params): +... with params.to_module(td_module): +... return td_module(td) +>>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ diff --git a/reference/generated/tensordict.nn.TensorDictModule.html b/reference/generated/tensordict.nn.TensorDictModule.html index d4e988a46..37b0179ab 100644 --- a/reference/generated/tensordict.nn.TensorDictModule.html +++ b/reference/generated/tensordict.nn.TensorDictModule.html @@ -371,9 +371,6 @@

TensorDictModule

class tensordict.nn.TensorDictModule(*args, **kwargs)

A TensorDictModule is a Python wrapper around an nn.Module that reads from and writes to a TensorDict.

-By default, TensorDictModule subclasses are always functional,
-meaning that they support the td_module(input, params=params) function
-call signature.

Parameters:
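
The doctest changes above track tensordict's move away from ``make_functional`` toward ``TensorDict.from_module`` and ``to_module``. For reference, a minimal, self-contained sketch of the new pattern, assuming illustrative module, key names and shapes (none of which are taken from the changed pages), might look like::

    import torch
    from torch import nn, vmap

    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    # A TensorDictModule reads its inputs from a TensorDict and writes its
    # outputs back into it; "x" and "y" are illustrative key names.
    module = TensorDictModule(nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
    td = TensorDict({"x": torch.randn(8, 3)}, batch_size=[8])

    # Extract the parameters as a TensorDict (replaces make_functional).
    params = TensorDict.from_module(module)

    # Functional call: temporarily swap the extracted parameters into the module.
    with params.to_module(module):
        module(td)
    print(td["y"].shape)  # expected: torch.Size([8, 4])

    # Batched parameters with vmap, mirroring the updated doctest above.
    def call_with_params(td, params):
        with params.to_module(module):
            return module(td)

    stacked_params = params.expand(4)  # four copies of the parameter set
    td_vmap = vmap(call_with_params, (None, 0))(td, stacked_params)
    print(td_vmap["y"].shape)  # expected: torch.Size([4, 8, 4])

The context-manager form restores the module's original parameters on exit, so the same module instance can be reused with several parameter sets, which is what the ``vmap`` example relies on.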