Skip to content

Commit

Permalink
Changed Example Model
Browse files Browse the repository at this point in the history
  • Loading branch information
gagewrye committed Nov 23, 2024
1 parent c46ff9e commit 15f9fb3
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions DroneClassification/models/how_to_build_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/gage/anaconda3/envs/i2sb/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch\n",
Expand All @@ -22,70 +31,72 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model = ResNet18_UNet()\n",
"model = ResNet_UNet()\n",
"model.load_state_dict(torch.load('ResNet18_UNet.pth', map_location='cpu', weights_only=True))\n",
"wrapper = SegmentModelWrapper(model)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Single output:\n",
"torch.Size([1, 1, 256, 256])\n",
"tensor([[[[1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 1, 1],\n",
"tensor([[[[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]]]], dtype=torch.uint8)\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)\n",
"\n",
"\n",
"\n",
"Batch output:\n",
"torch.Size([4, 1, 256, 256])\n",
"tensor([[[[1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 1, 1],\n",
"tensor([[[[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]]],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]]],\n",
"\n",
"\n",
" [[[1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 1, 1],\n",
" [[[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]]],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]]],\n",
"\n",
"\n",
" [[[1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 1, 1],\n",
" [[[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]]],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]]],\n",
"\n",
"\n",
" [[[1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 0, 0],\n",
" [1, 1, 1, ..., 1, 1, 1],\n",
" [[[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]]]], dtype=torch.uint8)\n"
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]]]], dtype=torch.uint8)\n"
]
}
],
Expand All @@ -95,18 +106,20 @@
"\n",
"wrapper.eval()\n",
"output = wrapper(test_input)\n",
"print(\"Single output:\")\n",
"print(output.shape)\n",
"print(output)\n",
"output = wrapper(batch_test)\n",
"print(\"\\n\\n\")\n",
"print(\"Batch output:\")\n",
"print(output.shape)\n",
"print(output)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mangrove",
"display_name": "i2sb",
"language": "python",
"name": "python3"
},
Expand All @@ -120,7 +133,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.8.18"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 15f9fb3

Please sign in to comment.