diff --git a/DroneClassification/models/how_to_build_model.ipynb b/DroneClassification/models/how_to_build_model.ipynb index 9ebda95..14a5ef5 100755 --- a/DroneClassification/models/how_to_build_model.ipynb +++ b/DroneClassification/models/how_to_build_model.ipynb @@ -11,9 +11,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/gage/anaconda3/envs/i2sb/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import numpy as np\n", "import torch\n", @@ -22,70 +31,72 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "model = ResNet18_UNet()\n", + "model = ResNet_UNet()\n", "model.load_state_dict(torch.load('ResNet18_UNet.pth', map_location='cpu', weights_only=True))\n", "wrapper = SegmentModelWrapper(model)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "Single output:\n", "torch.Size([1, 1, 256, 256])\n", - "tensor([[[[1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", + "tensor([[[[0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0]]]], dtype=torch.uint8)\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)\n", "\n", "\n", "\n", + "Batch output:\n", "torch.Size([4, 1, 256, 256])\n", - "tensor([[[[1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", + "tensor([[[[0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0]]],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0]]],\n", "\n", "\n", - " [[[1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", + " [[[0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0]]],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0]]],\n", "\n", "\n", - " [[[1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", + " [[[0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0]]],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0]]],\n", "\n", "\n", - " [[[1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 0, 0],\n", - " [1, 1, 1, ..., 1, 1, 1],\n", + " [[[0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", " ...,\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0],\n", - " [1, 1, 1, ..., 0, 0, 0]]]], dtype=torch.uint8)\n" + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)\n" ] } ], @@ -95,10 +106,12 @@ "\n", "wrapper.eval()\n", "output = wrapper(test_input)\n", + "print(\"Single output:\")\n", "print(output.shape)\n", "print(output)\n", "output = wrapper(batch_test)\n", "print(\"\\n\\n\")\n", + "print(\"Batch output:\")\n", "print(output.shape)\n", "print(output)\n" ] @@ -106,7 +119,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mangrove", + "display_name": "i2sb", "language": "python", "name": "python3" }, @@ -120,7 +133,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.8.18" } }, "nbformat": 4,