{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# YOLO to ONNX" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "### init\n", "# !pip install -r https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt\n", "import json\n", "import base64\n", "from io import BytesIO\n", "import os\n", "from datetime import datetime\n", "from glob import glob\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from PIL import Image, ImageDraw, ImageFont\n", "import torch\n", "import torch.onnx\n", "import onnx\n", "import onnxruntime" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "300\n" ] } ], "source": [ "# 推論用画像\n", "paths = sorted(glob('mask_data/test_answer/images/*.png'))\n", "print(len(paths))\n", "# 画像の読み込み\n", "imgs = []\n", "for p in paths:\n", " img = Image.open(p)\n", " imgs.append(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction with PyTorch" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Size (MB): 14.465205\n" ] } ], "source": [ "# モデルサイズ確認\n", "print('Size (MB):', os.path.getsize('yolov5_best.pt')/1e6)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using cache found in /Users/taichinakabeppu/.cache/torch/hub/ultralytics_yolov5_master\n", "YOLOv5 🚀 2022-4-28 torch 1.9.0 CPU\n", "\n", "Fusing layers... \n", "[W NNPACK.cpp:79] Could not initialize NNPACK! Reason: Unsupported hardware.\n", "Model summary: 213 layers, 7018216 parameters, 0 gradients, 15.8 GFLOPs\n", "Adding AutoShape... \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "# 学習済みモデルの読み込み\n", "model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5_best.pt', device='cpu')\n", "print(model.training)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Runtime = 0:03:01.582600\n" ] } ], "source": [ "# 推論\n", "start = datetime.now()\n", "model.eval()\n", "with torch.no_grad():\n", " # バッチ処理\n", " results = model(imgs)\n", "end = datetime.now()\n", "# 所要時間\n", "print('Runtime =', end-start)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | file | \n", "xmin | \n", "ymin | \n", "xmax | \n", "ymax | \n", "confidence | \n", "class | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "maksssksksss500.png | \n", "110 | \n", "230 | \n", "211 | \n", "369 | \n", "0.959684 | \n", "with_mask | \n", "
1 | \n", "maksssksksss501.png | \n", "354 | \n", "64 | \n", "400 | \n", "132 | \n", "0.948252 | \n", "with_mask | \n", "
2 | \n", "maksssksksss501.png | \n", "45 | \n", "37 | \n", "117 | \n", "126 | \n", "0.948185 | \n", "with_mask | \n", "
3 | \n", "maksssksksss501.png | \n", "301 | \n", "51 | \n", "349 | \n", "117 | \n", "0.945372 | \n", "with_mask | \n", "
4 | \n", "maksssksksss501.png | \n", "164 | \n", "47 | \n", "219 | \n", "123 | \n", "0.943369 | \n", "with_mask | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1422 | \n", "maksssksksss799.png | \n", "86 | \n", "239 | \n", "131 | \n", "284 | \n", "0.914433 | \n", "with_mask | \n", "
1423 | \n", "maksssksksss799.png | \n", "30 | \n", "64 | \n", "46 | \n", "80 | \n", "0.914345 | \n", "with_mask | \n", "
1424 | \n", "maksssksksss799.png | \n", "317 | \n", "248 | \n", "333 | \n", "265 | \n", "0.897670 | \n", "with_mask | \n", "
1425 | \n", "maksssksksss799.png | \n", "340 | \n", "248 | \n", "365 | \n", "274 | \n", "0.892757 | \n", "with_mask | \n", "
1426 | \n", "maksssksksss799.png | \n", "304 | \n", "209 | \n", "324 | \n", "230 | \n", "0.858823 | \n", "with_mask | \n", "
1427 rows × 7 columns
\n", "\n", " | file | \n", "xmin | \n", "ymin | \n", "xmax | \n", "ymax | \n", "confidence | \n", "class | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "maksssksksss500.png | \n", "110 | \n", "230 | \n", "211 | \n", "369 | \n", "0.959685 | \n", "with_mask | \n", "
1 | \n", "maksssksksss501.png | \n", "354 | \n", "64 | \n", "400 | \n", "132 | \n", "0.948251 | \n", "with_mask | \n", "
2 | \n", "maksssksksss501.png | \n", "45 | \n", "37 | \n", "117 | \n", "126 | \n", "0.948185 | \n", "with_mask | \n", "
3 | \n", "maksssksksss501.png | \n", "301 | \n", "51 | \n", "349 | \n", "117 | \n", "0.945373 | \n", "with_mask | \n", "
4 | \n", "maksssksksss501.png | \n", "164 | \n", "47 | \n", "219 | \n", "123 | \n", "0.943369 | \n", "with_mask | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1422 | \n", "maksssksksss799.png | \n", "86 | \n", "239 | \n", "131 | \n", "284 | \n", "0.914434 | \n", "with_mask | \n", "
1423 | \n", "maksssksksss799.png | \n", "30 | \n", "64 | \n", "46 | \n", "80 | \n", "0.914345 | \n", "with_mask | \n", "
1424 | \n", "maksssksksss799.png | \n", "317 | \n", "248 | \n", "333 | \n", "265 | \n", "0.897670 | \n", "with_mask | \n", "
1425 | \n", "maksssksksss799.png | \n", "340 | \n", "248 | \n", "365 | \n", "274 | \n", "0.892757 | \n", "with_mask | \n", "
1426 | \n", "maksssksksss799.png | \n", "304 | \n", "209 | \n", "324 | \n", "230 | \n", "0.858823 | \n", "with_mask | \n", "
1427 rows × 7 columns
\n", "\n", " | file | \n", "xmin | \n", "ymin | \n", "xmax | \n", "ymax | \n", "confidence | \n", "class | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "maksssksksss500.png | \n", "110 | \n", "230 | \n", "211 | \n", "369 | \n", "0.959684 | \n", "with_mask | \n", "
1 | \n", "maksssksksss501.png | \n", "354 | \n", "64 | \n", "400 | \n", "132 | \n", "0.948252 | \n", "with_mask | \n", "
2 | \n", "maksssksksss501.png | \n", "45 | \n", "37 | \n", "117 | \n", "126 | \n", "0.948185 | \n", "with_mask | \n", "
3 | \n", "maksssksksss501.png | \n", "301 | \n", "51 | \n", "349 | \n", "117 | \n", "0.945372 | \n", "with_mask | \n", "
4 | \n", "maksssksksss501.png | \n", "164 | \n", "47 | \n", "219 | \n", "123 | \n", "0.943369 | \n", "with_mask | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1422 | \n", "maksssksksss799.png | \n", "86 | \n", "239 | \n", "131 | \n", "284 | \n", "0.914433 | \n", "with_mask | \n", "
1423 | \n", "maksssksksss799.png | \n", "30 | \n", "64 | \n", "46 | \n", "80 | \n", "0.914345 | \n", "with_mask | \n", "
1424 | \n", "maksssksksss799.png | \n", "317 | \n", "248 | \n", "333 | \n", "265 | \n", "0.897670 | \n", "with_mask | \n", "
1425 | \n", "maksssksksss799.png | \n", "340 | \n", "248 | \n", "365 | \n", "274 | \n", "0.892757 | \n", "with_mask | \n", "
1426 | \n", "maksssksksss799.png | \n", "304 | \n", "209 | \n", "324 | \n", "230 | \n", "0.858823 | \n", "with_mask | \n", "
1427 rows × 7 columns
\n", "\n", " | xmin | \n", "ymin | \n", "xmax | \n", "ymax | \n", "confidence | \n", "with_mask | \n", "mask_weared_incorrect | \n", "without_mask | \n", "
---|---|---|---|---|---|---|---|---|
0 | \n", "4.214570 | \n", "4.537497 | \n", "10.313617 | \n", "15.985667 | \n", "8.672476e-06 | \n", "0.700935 | \n", "0.045178 | \n", "0.297051 | \n", "
1 | \n", "12.776090 | \n", "5.676386 | \n", "24.197996 | \n", "15.200418 | \n", "2.652407e-06 | \n", "0.548981 | \n", "0.055967 | \n", "0.362423 | \n", "
2 | \n", "18.680874 | \n", "4.333120 | \n", "33.507317 | \n", "12.852788 | \n", "2.145767e-06 | \n", "0.561048 | \n", "0.037742 | \n", "0.448910 | \n", "
3 | \n", "26.508099 | \n", "4.274238 | \n", "35.192181 | \n", "12.477221 | \n", "9.238720e-07 | \n", "0.632332 | \n", "0.032041 | \n", "0.417051 | \n", "
4 | \n", "33.960983 | \n", "4.160562 | \n", "35.272701 | \n", "11.803977 | \n", "5.662441e-07 | \n", "0.651438 | \n", "0.028722 | \n", "0.364463 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
25195 | \n", "493.527527 | \n", "604.938293 | \n", "866.446777 | \n", "790.824097 | \n", "7.003546e-06 | \n", "0.018933 | \n", "0.558842 | \n", "0.431098 | \n", "
25196 | \n", "523.272339 | \n", "603.406799 | \n", "581.229065 | \n", "550.291077 | \n", "3.099442e-06 | \n", "0.162965 | \n", "0.322442 | \n", "0.236716 | \n", "
25197 | \n", "570.945251 | \n", "606.444519 | \n", "347.007294 | \n", "260.588562 | \n", "1.093745e-05 | \n", "0.297493 | \n", "0.242286 | \n", "0.226423 | \n", "
25198 | \n", "586.689819 | \n", "603.715576 | \n", "209.623520 | \n", "170.480194 | \n", "1.293421e-05 | \n", "0.382226 | \n", "0.177253 | \n", "0.244768 | \n", "
25199 | \n", "607.628784 | \n", "607.436646 | \n", "283.054199 | \n", "238.678299 | \n", "3.635883e-06 | \n", "0.313055 | \n", "0.229774 | \n", "0.268563 | \n", "
25200 rows × 8 columns
\n", "\n", " | file | \n", "xmin | \n", "ymin | \n", "xmax | \n", "ymax | \n", "confidence | \n", "class | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "maksssksksss500.png | \n", "111 | \n", "231 | \n", "212 | \n", "370 | \n", "0.958664 | \n", "with_mask | \n", "
1 | \n", "maksssksksss501.png | \n", "44 | \n", "37 | \n", "117 | \n", "126 | \n", "0.951905 | \n", "with_mask | \n", "
2 | \n", "maksssksksss501.png | \n", "160 | \n", "47 | \n", "220 | \n", "122 | \n", "0.950687 | \n", "with_mask | \n", "
3 | \n", "maksssksksss501.png | \n", "353 | \n", "64 | \n", "400 | \n", "132 | \n", "0.949953 | \n", "with_mask | \n", "
4 | \n", "maksssksksss501.png | \n", "302 | \n", "52 | \n", "349 | \n", "117 | \n", "0.943344 | \n", "with_mask | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1297 | \n", "maksssksksss799.png | \n", "44 | \n", "41 | \n", "64 | \n", "65 | \n", "0.904429 | \n", "with_mask | \n", "
1298 | \n", "maksssksksss799.png | \n", "85 | \n", "238 | \n", "133 | \n", "283 | \n", "0.899709 | \n", "with_mask | \n", "
1299 | \n", "maksssksksss799.png | \n", "317 | \n", "248 | \n", "332 | \n", "264 | \n", "0.875682 | \n", "with_mask | \n", "
1300 | \n", "maksssksksss799.png | \n", "304 | \n", "211 | \n", "324 | \n", "229 | \n", "0.867625 | \n", "with_mask | \n", "
1301 | \n", "maksssksksss799.png | \n", "341 | \n", "247 | \n", "366 | \n", "275 | \n", "0.386450 | \n", "with_mask | \n", "
1302 rows × 7 columns
\n", "\n", " | file | \n", "xmin | \n", "ymin | \n", "xmax | \n", "ymax | \n", "confidence | \n", "class | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "maksssksksss500.png | \n", "111 | \n", "231 | \n", "212 | \n", "370 | \n", "0.958664 | \n", "with_mask | \n", "
1 | \n", "maksssksksss501.png | \n", "44 | \n", "37 | \n", "117 | \n", "126 | \n", "0.951905 | \n", "with_mask | \n", "
2 | \n", "maksssksksss501.png | \n", "160 | \n", "47 | \n", "220 | \n", "122 | \n", "0.950687 | \n", "with_mask | \n", "
3 | \n", "maksssksksss501.png | \n", "353 | \n", "64 | \n", "400 | \n", "132 | \n", "0.949953 | \n", "with_mask | \n", "
4 | \n", "maksssksksss501.png | \n", "302 | \n", "52 | \n", "349 | \n", "117 | \n", "0.943344 | \n", "with_mask | \n", "
5 | \n", "maksssksksss501.png | \n", "0 | \n", "44 | \n", "42 | \n", "126 | \n", "0.941737 | \n", "with_mask | \n", "
6 | \n", "maksssksksss501.png | \n", "232 | \n", "38 | \n", "297 | \n", "108 | \n", "0.926159 | \n", "with_mask | \n", "
7 | \n", "maksssksksss502.png | \n", "177 | \n", "57 | \n", "244 | \n", "123 | \n", "0.961261 | \n", "without_mask | \n", "
8 | \n", "maksssksksss502.png | \n", "59 | \n", "73 | \n", "118 | \n", "130 | \n", "0.951826 | \n", "without_mask | \n", "
9 | \n", "maksssksksss502.png | \n", "346 | \n", "103 | \n", "398 | \n", "169 | \n", "0.948219 | \n", "with_mask | \n", "