Skip to content

Commit

Permalink
add guides
Browse files Browse the repository at this point in the history
  • Loading branch information
innat committed Apr 3, 2024
1 parent d8b5c52 commit 7492cd8
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"none","dataSources":[],"dockerImageVersionId":30673,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import os\nimport warnings","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2024-04-03T11:39:47.752433Z","iopub.execute_input":"2024-04-03T11:39:47.752872Z","iopub.status.idle":"2024-04-03T11:39:47.786695Z","shell.execute_reply.started":"2024-04-03T11:39:47.752835Z","shell.execute_reply":"2024-04-03T11:39:47.785647Z"},"trusted":true},"execution_count":1,"outputs":[]},{"cell_type":"code","source":"os.environ[\"KERAS_BACKEND\"] = \"torch\"\nwarnings.simplefilter(action=\"ignore\")","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:39:49.500606Z","iopub.execute_input":"2024-04-03T11:39:49.501022Z","iopub.status.idle":"2024-04-03T11:39:49.507410Z","shell.execute_reply.started":"2024-04-03T11:39:49.500990Z","shell.execute_reply":"2024-04-03T11:39:49.506209Z"},"trusted":true},"execution_count":2,"outputs":[]},{"cell_type":"code","source":"!git clone --branch video_swin https://github.com/innat/keras-cv.git\n%cd keras-cv\n!pip install -q -e .\n!pip install -q onnxruntime","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:39:51.228384Z","iopub.execute_input":"2024-04-03T11:39:51.228781Z","iopub.status.idle":"2024-04-03T11:40:46.556937Z","shell.execute_reply.started":"2024-04-03T11:39:51.228751Z","shell.execute_reply":"2024-04-03T11:40:46.555279Z"},"trusted":true},"execution_count":3,"outputs":[{"name":"stdout","text":"Cloning into 'keras-cv'...\nremote: Enumerating objects: 13782, done.\u001b[K\nremote: Counting objects: 100% (1919/1919), done.\u001b[K\nremote: Compressing objects: 100% (769/769), done.\u001b[K\nremote: Total 13782 (delta 1337), reused 1628 (delta 1134), pack-reused 11863\u001b[K\nReceiving objects: 100% (13782/13782), 25.65 MiB | 20.19 MiB/s, done.\nResolving deltas: 100% (9788/9788), done.\n/kaggle/working/keras-cv\n","output_type":"stream"}]},{"cell_type":"code","source":"import numpy as np\n\nimport onnx\nimport onnxruntime\n\nimport torch\nimport keras\nfrom keras import ops\nfrom keras_cv.models import VideoSwinBackbone\nfrom keras_cv.models import VideoClassifier\n\nkeras.__version__, torch.__version__, onnx.__version__, onnxruntime.__version__","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:40:46.559832Z","iopub.execute_input":"2024-04-03T11:40:46.560220Z","iopub.status.idle":"2024-04-03T11:41:14.708819Z","shell.execute_reply.started":"2024-04-03T11:40:46.560186Z","shell.execute_reply":"2024-04-03T11:41:14.707479Z"},"trusted":true},"execution_count":4,"outputs":[{"name":"stderr","text":"2024-04-03 11:40:58.890566: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n2024-04-03 11:40:58.890789: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n2024-04-03 11:40:59.090899: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n","output_type":"stream"},{"execution_count":4,"output_type":"execute_result","data":{"text/plain":"('3.0.5', '2.1.2+cpu', '1.15.0', '1.17.1')"},"metadata":{}}]},{"cell_type":"code","source":"def vswin_tiny():\n !wget https://github.com/innat/VideoSwin/releases/download/v2.0/videoswin_tiny_kinetics400_classifier.weights.h5 -q\n backbone=VideoSwinBackbone(\n input_shape=(32, 224, 224, 3), \n embed_dim=96,\n depths=[2, 2, 6, 2],\n num_heads=[3, 6, 12, 24],\n include_rescaling=False, \n )\n model = VideoClassifier(\n backbone=backbone,\n num_classes=400,\n activation=None,\n pooling='avg',\n )\n model.load_weights(\n 'videoswin_tiny_kinetics400_classifier.weights.h5'\n )\n return model","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:41:14.710550Z","iopub.execute_input":"2024-04-03T11:41:14.711421Z","iopub.status.idle":"2024-04-03T11:41:14.724711Z","shell.execute_reply.started":"2024-04-03T11:41:14.711352Z","shell.execute_reply":"2024-04-03T11:41:14.723686Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"code","source":"model = vswin_tiny()\nmodel.eval()\nmodel.summary()","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:41:14.728735Z","iopub.execute_input":"2024-04-03T11:41:14.729335Z","iopub.status.idle":"2024-04-03T11:41:19.565786Z","shell.execute_reply.started":"2024-04-03T11:41:14.729291Z","shell.execute_reply":"2024-04-03T11:41:19.564529Z"},"trusted":true},"execution_count":6,"outputs":[{"output_type":"display_data","data":{"text/plain":"\u001b[1mModel: \"video_classifier\"\u001b[0m\n","text/html":"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"video_classifier\"</span>\n</pre>\n"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n│ videos (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n│ │ \u001b[38;5;34m3\u001b[0m) │ │\n├─────────────────────────────────┼────────────────────────┼───────────────┤\n│ video_swin_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m27,850,470\u001b[0m │\n│ (\u001b[38;5;33mVideoSwinBackbone\u001b[0m) │ │ │\n├─────────────────────────────────┼────────────────────────┼───────────────┤\n│ avg_pool │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n│ (\u001b[38;5;33mGlobalAveragePooling3D\u001b[0m) │ │ │\n├─────────────────────────────────┼────────────────────────┼───────────────┤\n│ predictions (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m400\u001b[0m) │ \u001b[38;5;34m307,600\u001b[0m │\n└─────────────────────────────────┴────────────────────────┴───────────────┘\n","text/html":"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n│ videos (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">224</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">224</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3</span>) │ │\n├─────────────────────────────────┼────────────────────────┼───────────────┤\n│ video_swin_backbone │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">16</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">7</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">7</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">768</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">27,850,470</span> │\n│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">VideoSwinBackbone</span>) │ │ │\n├─────────────────────────────────┼────────────────────────┼───────────────┤\n│ avg_pool │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">768</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GlobalAveragePooling3D</span>) │ │ │\n├─────────────────────────────────┼────────────────────────┼───────────────┤\n│ predictions (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">400</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">307,600</span> │\n└─────────────────────────────────┴────────────────────────┴───────────────┘\n</pre>\n"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n","text/html":"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">28,158,070</span> (107.41 MB)\n</pre>\n"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m28,158,070\u001b[0m (107.41 MB)\n","text/html":"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">28,158,070</span> (107.41 MB)\n</pre>\n"},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n","text/html":"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n</pre>\n"},"metadata":{}}]},{"cell_type":"code","source":"def to_numpy(tensor):\n if tensor.requires_grad:\n tensor = tensor.detach()\n tensor = tensor.cpu()\n numpy_array = tensor.numpy()\n return numpy_array","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:41:19.567311Z","iopub.execute_input":"2024-04-03T11:41:19.567696Z","iopub.status.idle":"2024-04-03T11:41:19.575010Z","shell.execute_reply.started":"2024-04-03T11:41:19.567662Z","shell.execute_reply":"2024-04-03T11:41:19.573495Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"code","source":"batch_size = 1\nx = torch.randn(batch_size, 32, 224, 224, 3, requires_grad=True)\ntorch_out = model(x)","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:41:19.576604Z","iopub.execute_input":"2024-04-03T11:41:19.577040Z","iopub.status.idle":"2024-04-03T11:41:28.127731Z","shell.execute_reply.started":"2024-04-03T11:41:19.577002Z","shell.execute_reply":"2024-04-03T11:41:28.126396Z"},"trusted":true},"execution_count":8,"outputs":[]},{"cell_type":"code","source":"torch.onnx.export(\n model, # model being run\n x, # model input (or a tuple for multiple inputs)\n \"vswin_tiny.onnx\", \n export_params=True, \n opset_version=10, \n do_constant_folding=True, \n input_names = ['input'], # the model's input names\n output_names = ['output'], # the model's output names\n dynamic_axes={\n 'input' : {0 : 'batch_size'}, \n 'output' : {0 : 'batch_size'}\n }\n)","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:41:28.129359Z","iopub.execute_input":"2024-04-03T11:41:28.129801Z","iopub.status.idle":"2024-04-03T11:42:00.453165Z","shell.execute_reply.started":"2024-04-03T11:41:28.129765Z","shell.execute_reply":"2024-04-03T11:42:00.452097Z"},"trusted":true},"execution_count":9,"outputs":[]},{"cell_type":"code","source":"onnx_model = onnx.load(\"vswin_tiny.onnx\")\nonnx.checker.check_model(onnx_model)\nort_session = onnxruntime.InferenceSession(\n \"vswin_tiny.onnx\", providers=[\"CPUExecutionProvider\"]\n)\n\n# compute ONNX Runtime output prediction\nort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}\nort_outs = ort_session.run(None, ort_inputs)","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:42:00.461211Z","iopub.execute_input":"2024-04-03T11:42:00.462252Z","iopub.status.idle":"2024-04-03T11:42:06.197878Z","shell.execute_reply.started":"2024-04-03T11:42:00.462190Z","shell.execute_reply":"2024-04-03T11:42:06.196447Z"},"trusted":true},"execution_count":10,"outputs":[]},{"cell_type":"code","source":"np.testing.assert_allclose(\n to_numpy(torch_out), ort_outs[0], rtol=1e-05, atol=1e-05\n)","metadata":{"execution":{"iopub.status.busy":"2024-04-03T11:42:06.199756Z","iopub.execute_input":"2024-04-03T11:42:06.201579Z","iopub.status.idle":"2024-04-03T11:42:06.238447Z","shell.execute_reply.started":"2024-04-03T11:42:06.201514Z","shell.execute_reply":"2024-04-03T11:42:06.237213Z"},"trusted":true},"execution_count":11,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]}

0 comments on commit 7492cd8

Please sign in to comment.