{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Vision Transformer" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Original paper: *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale* (Dosovitskiy et al., ICLR 2021)  \n", "https://openreview.net/forum?id=YicbFdNTTy" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Understanding the algorithm" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For image classification, ViT consists of three main parts:\n", "\n", "- Input Layer\n", "    - Splits the input image into patches\n", "    - Outputs one vector for the class token and one for each patch\n", "- Encoder\n", "    - Applies Self-Attention\n", "    - Outputs the class token\n", "- MLPHead\n", "    - Predicts the label of the input image (the classifier)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_01](image/vit_01.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Input Layer" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "1. Split into patches\n", "2. Embedding\n", "3. CLS (Class Token)\n", "4. Positional Embedding" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For illustration, the figures below use 4 patches." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Splitting into patches" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_patch](image/vit_patch.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Embedding" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_patch2emb](image/vit_patch2emb.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Class token and positional embedding" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_patch2clsemb](image/vit_patch2clsemb.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Implementation" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "gather": { "logged": 1667634927256 } }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "gather": { "logged": 1667634927440 } }, "outputs": [], "source": [ "class VitInputLayer(nn.Module):\n", "    def __init__(self,\n", "                 in_channels:int=3,\n", "                 emb_dim:int=384,\n", "                 num_patch_row:int=2,\n", "                 image_size:int=32):\n", "\n", "        super(VitInputLayer, self).__init__()\n", "        self.in_channels = in_channels\n", "        self.emb_dim = emb_dim\n", "        self.num_patch_row = num_patch_row\n", "        self.image_size = image_size\n", "\n", "        # number of patches\n", "        self.num_patch = self.num_patch_row**2\n", "        # patch size: e.g. a 32-pixel image side with 2 patches per row gives patch_size = 16\n", "        self.patch_size = int(self.image_size // self.num_patch_row)\n", "\n", "        # split the input image into patches and embed each one\n", "        # (a Conv2d whose kernel size and stride both equal the patch size)\n", "        self.patch_emb_layer = nn.Conv2d(\n", "            in_channels=self.in_channels,\n", "            out_channels=self.emb_dim,\n", "            kernel_size=self.patch_size,\n", "            stride=self.patch_size)\n", "\n", "        # class token (CLS)\n", "        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))\n", "\n", "        # Position Embedding\n", "        # CLS is prepended, so prepare (num_patch + 1) positional embedding vectors of length emb_dim\n", "        self.pos_emb = nn.Parameter(torch.randn(1, self.num_patch+1, emb_dim))\n", "\n", "    def forward(self, x:torch.Tensor) -> torch.Tensor:\n", "        # patch embedding & flatten\n", "\n", "        ## patch embedding: (B, C, H, W) -> (B, D, H/P, W/P)\n", "        z_0 = self.patch_emb_layer(x)\n", "\n", "        ## flatten the patches: (B, D, H/P, W/P) -> (B, D, Np)\n", "        ## Np is the number of patches (= H*W/P^2)\n", "        z_0 = z_0.flatten(2)\n", "\n", "        ## swap axes: (B, D, Np) -> (B, Np, D)\n", "        z_0 = z_0.transpose(1, 2)\n", "\n", "        # prepend CLS to the patch embeddings\n", "        ## (B, Np, D) -> (B, N, D), where N = Np + 1\n", "        ## cls_token has shape (1, 1, D), so repeat (copy) it to (B, 1, D) before concatenating\n", "        z_0 = torch.cat([self.cls_token.repeat(repeats=(x.size(0), 1, 1)), z_0], dim=1)\n", "\n", "        # add the positional embedding\n", "        ## (B, N, D) -> (B, N, D)\n", "        z_0 = z_0 + self.pos_emb\n", "\n", "        return z_0" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "gather": { "logged": 1667634927589 } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 5, 384])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check\n", "x = torch.randn(1, 3, 32, 32)\n", "input_layer = VitInputLayer()\n", "z_0 =
input_layer(x)\n", "z_0.shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder\n", "#### Self-Attention\n", "1. Extract the information in each patch\n", "    - -> embedding\n", "2. Measure similarity within the input itself (each token against every token, including itself)\n", "    - -> dot products between vectors\n", "3. Combine the tokens according to their similarity\n", "    - -> a weighted sum that uses the similarities as coefficients\n", "    - -> coefficients: the softmax of the dot products\n", "    - -> these coefficients are called the Attention Weight" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Self-Attention, illustrated" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_self_attention_image](image/vit_self_attention_image.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Self-Attention also extracts information through embeddings. Three linear layers are prepared, and the vectors produced by each of them are called:\n", "\n", "- q (query)\n", "- k (key)\n", "- v (value)\n", "\n", "q, k and v are all embedded from exactly the same vectors, but because each one uses a different linear layer, they take different values.\n", "\n", "The split into q, k and v is easiest to picture as a search on a video-sharing site:\n", "\n", "- q : the search keyword\n", "- k : the video title\n", "- v : the video\n", "\n", "To find videos from a keyword, the site checks how well the keyword matches each video title. Self-Attention works the same way: it computes the similarity between q and k, then uses those similarities to take a weighted sum of v.\n", "\n", "Computed for all tokens at once, the dot products form the matrix product $qk^\\top$. Dividing it by $\\sqrt{D_h}$ ($D_h$: the dimension of each head's vectors) and applying a softmax to each row turns it into similarities, giving $\\mathrm{softmax}(qk^\\top / \\sqrt{D_h})\\,v$." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_self_attention](image/vit_self_attention.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-Head Self-Attention\n", "One Attention Weight holds one set of relationships between the patches. With several Attention Weights, the model can learn as many different relationships between the patches as there are Attention Weights. To get them, embed several sets of q, k, v from each patch and compute one Attention Weight per set.\n", "\n", "The number of sets is the hyperparameter number of heads (`head` in the code below). In this implementation, the `emb_dim`-dimensional q, k, v are each split into `head` chunks of size `emb_dim // head`." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "gather": { "logged": 1667634927869 } }, "outputs": [], "source": [ "class MultiHeadSelfAttention(nn.Module):\n", "    def __init__(self,\n", "                 emb_dim:int=384,\n", "                 head:int=3,\n", "                 dropout:float=0.):\n", "\n", "        super(MultiHeadSelfAttention, self).__init__()\n", "        assert emb_dim % head == 0, 'emb_dim must be divisible by head'\n", "        self.head = head\n", "        self.emb_dim = emb_dim\n", "        self.head_dim = emb_dim // head\n", "        # square root of D_h (the per-head dimension), used to scale qk^T\n", "        self.sqrt_dh = self.head_dim**0.5\n", "\n", "        # linear layers that embed the input into query, key and value\n", "        self.w_q = nn.Linear(emb_dim, emb_dim, bias=False)\n", "        self.w_k = nn.Linear(emb_dim, emb_dim, bias=False)\n", "        self.w_v = nn.Linear(emb_dim, emb_dim, bias=False)\n", "\n", "        # Dropout\n", "        self.attn_drop = nn.Dropout(dropout)\n", "\n", "        # linear layer that maps the MHSA result to the output\n", "        self.w_o = nn.Sequential(\n", "            nn.Linear(emb_dim, emb_dim),\n", "            nn.Dropout(dropout))\n", "\n", "    def forward(self, z:torch.Tensor) -> torch.Tensor:\n", "        batch_size, num_patch, _ = z.size()\n", "\n", "        # embedding: (B, N, D) -> (B, N, D)\n", "        q = self.w_q(z)\n", "        k = self.w_k(z)\n", "        v = self.w_v(z)\n", "\n", "        # split (q, k, v) into heads\n", "        ## first split each vector into `head` pieces\n", "        ## (B, N, D) -> (B, N, h, D//h)\n", "        q = q.view(batch_size, num_patch, self.head, self.head_dim)\n", "        k = k.view(batch_size, num_patch, self.head, self.head_dim)\n", "        v = v.view(batch_size, num_patch, self.head, self.head_dim)\n", "        ## reshape to (batch size, heads, tokens, per-head vector) for Self-Attention\n", "        ## (B, N, h, D//h) -> (B, h, N, D//h)\n", "        q = q.transpose(1, 2)\n", "        k = k.transpose(1, 2)\n", "        v = v.transpose(1, 2)\n", "\n", "        # dot products\n", "        ## (B, h, N, D//h) -> (B, h, D//h, N)\n", "        k_T = k.transpose(2, 3)\n", "        ## (B, h, N, D//h) x (B, h, D//h, N) -> (B, h, N, N)\n", "        dots = (q @ k_T) / self.sqrt_dh\n", "        ## softmax over the last dimension: each row of the attention matrix sums to 1\n", "        attn = F.softmax(dots, dim=-1)\n", "        attn = self.attn_drop(attn)\n", "\n", "        # weighted sum\n", "        ## (B, h, N, N) x (B, h, N, D//h) -> (B, h, N, D//h)\n", "        out = attn @ v\n", "        ## (B, h, N, D//h) -> (B, N, h, D//h)\n", "        out = out.transpose(1, 2)\n", "        ## (B, N, h, D//h) -> (B, N, D)\n", "        out = out.reshape(batch_size, num_patch, self.emb_dim)\n", "\n", "        # output layer\n", "        ## (B, N, D) -> (B, N, D)\n", "        out = self.w_o(out)\n", "\n", "        return out" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "gather": { "logged": 1667634928013 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 5, 384])\n", "torch.Size([1, 5, 384])\n" ] } ],
"source": [ "# check\n", "print(z_0.shape)\n", "mhsa = MultiHeadSelfAttention()\n", "out = mhsa(z_0)\n", "print(out.shape)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder Block" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Layer Normalization\n", "- Multi-Head Self-Attention\n", "- MLP (activation function: GELU)\n", "\n", "LayerNorm is applied before each of the two sub-layers (MHSA and MLP), and each sub-layer is wrapped in a residual connection." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit_encoder_block](image/vit_encoder_block.png)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "gather": { "logged": 1667634928152 } }, "outputs": [], "source": [ "class VitEncoderBlock(nn.Module):\n", "    def __init__(self,\n", "                 emb_dim:int=384,\n", "                 head:int=8,\n", "                 hidden_dim:int=384*4,\n", "                 dropout:float=0.\n", "                 ):\n", "\n", "        super(VitEncoderBlock, self).__init__()\n", "\n", "        # first LayerNorm\n", "        self.ln1 = nn.LayerNorm(emb_dim)\n", "        # Multi-Head Self-Attention\n", "        self.msa = MultiHeadSelfAttention(\n", "            emb_dim=emb_dim,\n", "            head=head,\n", "            dropout=dropout\n", "        )\n", "\n", "        # second LayerNorm\n", "        self.ln2 = nn.LayerNorm(emb_dim)\n", "        # MLP\n", "        self.mlp = nn.Sequential(\n", "            nn.Linear(emb_dim, hidden_dim),\n", "            nn.GELU(),\n", "            nn.Dropout(dropout),\n", "            nn.Linear(hidden_dim, emb_dim),\n", "            nn.Dropout(dropout)\n", "        )\n", "\n", "    def forward(self, z:torch.Tensor) -> torch.Tensor:\n", "        # first half of the Encoder Block: LayerNorm -> MHSA -> residual connection\n", "        out = self.msa(self.ln1(z)) + z\n", "        # second half: LayerNorm -> MLP -> residual connection\n", "        out = self.mlp(self.ln2(out)) + out\n", "\n", "        return out" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "gather": { "logged": 1667634928319 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 5, 384])\n" ] } ], "source": [ "# check\n", "vit_enc = VitEncoderBlock()\n", "z_1 = vit_enc(z_0)\n", "print(z_1.shape)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### The whole ViT" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Input Layer\n", "- Encoder\n", "- MLP Head" ] },
{ "cell_type": "markdown",
"metadata": {}, "source": [ "![vit_mlp](image/vit_mlp.png)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- Overview" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "![vit](image/vit.png)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "gather": { "logged": 1667634928446 } }, "outputs": [], "source": [ "class Vit(nn.Module):\n", "    def __init__(self,\n", "                 in_channels:int=3,\n", "                 num_classes:int=10,\n", "                 emb_dim:int=384,\n", "                 num_patch_row:int=2,\n", "                 image_size:int=32,\n", "                 num_blocks:int=7,\n", "                 head:int=8,\n", "                 hidden_dim:int=384*4,\n", "                 dropout:float=0.\n", "                 ):\n", "\n", "        super(Vit, self).__init__()\n", "\n", "        # Input Layer\n", "        self.input_layer = VitInputLayer(\n", "            in_channels,\n", "            emb_dim,\n", "            num_patch_row,\n", "            image_size)\n", "\n", "        # Encoder (a stack of num_blocks Encoder Blocks)\n", "        self.encoder = nn.Sequential(*[\n", "            VitEncoderBlock(\n", "                emb_dim=emb_dim,\n", "                head=head,\n", "                hidden_dim=hidden_dim,\n", "                dropout=dropout\n", "            )\n", "            for _ in range(num_blocks)])\n", "\n", "        # MLP Head\n", "        self.mlp_head = nn.Sequential(\n", "            nn.LayerNorm(emb_dim),\n", "            nn.Linear(emb_dim, num_classes)\n", "        )\n", "\n", "    def forward(self, x:torch.Tensor) -> torch.Tensor:\n", "        # Input Layer\n", "        ## (B, C, H, W) -> (B, N, D)\n", "        ## N: number of tokens (patches + 1), D: vector length\n", "        out = self.input_layer(x)\n", "\n", "        # Encoder\n", "        ## (B, N, D) -> (B, N, D)\n", "        out = self.encoder(out)\n", "\n", "        # extract only the class token\n", "        ## (B, N, D) -> (B, D)\n", "        cls_token = out[:, 0]\n", "\n", "        # MLP Head\n", "        ## (B, D) -> (B, M), M: number of classes\n", "        pred = self.mlp_head(cls_token)\n", "\n", "        return pred" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "gather": { "logged": 1667634928590 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 10])\n" ] } ], "source": [ "# check\n", "x = torch.randn(1, 3, 32, 32)\n", "vit = Vit()\n", "pred = vit(x)\n", "print(pred.shape)" ] },
{ "cell_type": "markdown",
"metadata": {}, "source": [ "## Trying dog vs. cat image classification" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "gather": { "logged": 1667634928891 } }, "outputs": [], "source": [ "import os\n", "import random\n", "import numpy as np\n", "import pandas as pd\n", "\n", "# torch\n", "import torch\n", "from torch import nn\n", "from torch.optim import Adam\n", "from torch.optim.optimizer import Optimizer\n", "from torch.utils import data\n", "\n", "# torchvision\n", "from torchvision import transforms as T\n", "\n", "# scikit-learn\n", "# from sklearn.metrics import mean_squared_error\n", "from sklearn.metrics import accuracy_score" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "gather": { "logged": 1667634929027 } }, "outputs": [], "source": [ "def seed_torch(seed=0):\n", "    random.seed(seed)\n", "    os.environ['PYTHONHASHSEED'] = str(seed)\n", "    np.random.seed(seed)\n", "    torch.manual_seed(seed)\n", "    torch.cuda.manual_seed(seed)\n", "    torch.backends.cudnn.deterministic = True\n", "    torch.backends.cudnn.benchmark = False" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "gather": { "logged": 1667634929163 } }, "outputs": [], "source": [ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "gather": { "logged": 1667634929331 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } }, "outputs": [], "source": [ "# !unzip -q ./dog_cat_data.zip" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "gather": { "logged": 1667634929498 } }, "outputs": [], "source": [ "from glob import glob\n", "from PIL import Image" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "gather": { "logged": 1667634929638 } }, "outputs": [ { "data": { "text/plain": [ "300" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dog_filepaths = sorted(glob('./dog_cat_data/train/dog/*.jpg'))\n",
"cat_filepaths = sorted(glob('./dog_cat_data/train/cat/*.jpg'))\n", "paths = dog_filepaths + cat_filepaths\n", "len(paths)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "gather": { "logged": 1667634929789 } }, "outputs": [], "source": [ "class MyDataset(data.Dataset):\n", "    def __init__(self, paths):\n", "        self.paths = paths\n", "        self.transform = T.Compose([\n", "            T.Resize(256),\n", "            T.CenterCrop(224),\n", "            T.ToTensor(),\n", "            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])\n", "        # label from the parent directory name: cat -> 0, dog -> 1\n", "        self.labels = [0 if os.path.basename(os.path.dirname(p)) == 'cat' else 1 for p in self.paths]\n", "\n", "    def __getitem__(self, idx):\n", "        path = self.paths[idx]\n", "        img = Image.open(path).convert('RGB')\n", "        img_transformed = self.transform(img)\n", "        label = self.labels[idx]\n", "        return img_transformed, label\n", "\n", "    def __len__(self):\n", "        return len(self.paths)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "gather": { "logged": 1667634929932 } }, "outputs": [ { "data": { "text/plain": [ "(210, 90)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = MyDataset(paths=paths)\n", "n_train = int(len(dataset) * 0.7)\n", "n_val = len(dataset) - n_train\n", "# named train_ds / val_ds so that the train() function defined later does not overwrite them\n", "train_ds, val_ds = data.random_split(dataset, [n_train, n_val])\n", "len(train_ds), len(val_ds)" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "gather": { "logged": 1667634930150 } }, "outputs": [], "source": [ "batch_size = 32\n", "train_loader = data.DataLoader(train_ds, batch_size, shuffle=True, drop_last=True)\n", "# the validation loader must be built from the validation split, not the training split\n", "val_loader = data.DataLoader(val_ds, batch_size)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "gather": { "logged": 1667634930304 } }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32, 3, 224, 224]), torch.Size([32]))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check\n", "x, t = next(iter(train_loader))\n", "x.shape, t.shape" ] },
{ "cell_type":
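"markdown", "metadata": {}, "source": [ "The batch above has shape (32, 3, 224, 224). As a quick check of the Input Layer arithmetic, the cell below works out what the from-scratch `Vit` receives with the settings used later in this notebook (`image_size=224`, `num_patch_row=8`, default `emb_dim=384`). It is a minimal plain-Python sketch that mirrors the formulas in `VitInputLayer`. It does not build the model, and its variable names exist only for this illustration." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plain-Python mirror of the VitInputLayer arithmetic (no model is built)\n", "image_size = 224     # side length after T.CenterCrop(224)\n", "num_patch_row = 8    # patches per row and per column, as passed to Vit later\n", "emb_dim = 384        # Vit default\n", "\n", "# a Conv2d with kernel = stride = patch_size silently drops leftover border pixels\n", "assert image_size % num_patch_row == 0\n", "\n", "patch_size = image_size // num_patch_row   # 224 // 8 = 28\n", "num_patch = num_patch_row ** 2             # 8 * 8 = 64 patches\n", "num_tokens = num_patch + 1                 # + 1 class token = 65 tokens\n", "patch_values = 3 * patch_size ** 2         # RGB values in one patch: 3 * 28 * 28\n", "\n", "print(f'patch_size={patch_size}, num_patch={num_patch}, num_tokens={num_tokens}')\n", "print(f'each patch: {patch_values} values -> one {emb_dim}-dim vector')\n", "print(f'Input Layer output: (B, {num_tokens}, {emb_dim}); attention matrix per head: ({num_tokens}, {num_tokens})')" ] },
{ "cell_type":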
"code", "execution_count": 21, "metadata": { "gather": { "logged": 1667634931055 } }, "outputs": [], "source": [ "def train(\n", "        model: nn.Module,\n", "        optimizer: Optimizer,\n", "        train_loader: data.DataLoader):\n", "\n", "    model.train()\n", "    # criterion = nn.MSELoss()  # regression\n", "    criterion = nn.CrossEntropyLoss()  # classification\n", "    epoch_loss = 0.0\n", "    epoch_accuracy = 0.0\n", "\n", "    for x_i, y_i in train_loader:\n", "        x_i = x_i.to(DEVICE, dtype=torch.float32)\n", "        # y_i = y_i.to(DEVICE, dtype=torch.float32).reshape(-1, 1)  # regression\n", "        y_i = y_i.to(DEVICE, dtype=torch.int64)  # classification\n", "        output = model(x_i)\n", "        loss = criterion(output, y_i)\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "        accuracy = (output.argmax(dim=1) == y_i).float().mean()\n", "        # .item() returns a Python float, so the computation graph is not kept alive across batches\n", "        epoch_loss += loss.item()\n", "        epoch_accuracy += accuracy.item()\n", "    # average of the per-batch values\n", "    return epoch_loss / len(train_loader), epoch_accuracy / len(train_loader)\n", "\n", "def valid(model: nn.Module, valid_loader: data.DataLoader):\n", "    model.eval()\n", "    criterion = nn.CrossEntropyLoss()\n", "    valid_loss = 0.0\n", "    valid_accuracy = 0.0\n", "    for x_i, y_i in valid_loader:\n", "        x_i = x_i.to(DEVICE, dtype=torch.float32)\n", "        y_i = y_i.to(DEVICE, dtype=torch.int64)\n", "        with torch.no_grad():\n", "            output = model(x_i)\n", "            loss = criterion(output, y_i)\n", "        accuracy = (output.argmax(dim=1) == y_i).float().mean()\n", "        valid_loss += loss.item()\n", "        valid_accuracy += accuracy.item()\n", "    return valid_loss / len(valid_loader), valid_accuracy / len(valid_loader)" ] },
{ "cell_type": "markdown", "metadata": { "nteract": { "transient": { "deleting": false } } }, "source": [ "Aside: learning-rate schedulers" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "gather": { "logged": 1667634931964 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [
"/anaconda/envs/azureml_py38/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n", " warnings.warn(\"Detected call of `lr_scheduler.step()` before `optimizer.step()`. \"\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ExponentialLR\n", "\n", "# the model only supplies parameters for the throwaway optimizers below; every schedule starts from lr = 1\n", "model = Vit(\n", "    in_channels=3,\n", "    num_classes=2,\n", "    num_patch_row=8,\n", "    image_size=224,\n", "    dropout=.1)\n", "\n", "schedulers = [\n", "    lambda optim: CosineAnnealingLR(optim, T_max=10),  # cosine with a half-period of 10 epochs\n", "    lambda optim: StepLR(optim, step_size=30, gamma=.2),  # multiply the lr by 0.2 every 30 epochs\n", "    lambda optim: ExponentialLR(optim, gamma=.95)  # multiply the lr by 0.95 every epoch\n", "]\n", "\n", "epochs = list(range(100))\n", "fig, ax = plt.subplots(figsize=(10, 6))\n", "\n", "for get_scheduler in schedulers:\n", "    rates = []\n", "    sche = get_scheduler(Adam(model.parameters(), lr=1.))\n", "\n", "    for i in epochs:\n", "        rates.append(sche.get_last_lr()[0])\n", "        # stepping the scheduler without optimizer.step() triggers the UserWarning in the output; harmless here\n", "        sche.step()\n", "\n", "    ax.step(epochs, rates, label=type(sche).__name__)\n", "\n", "# ax.set_yscale('log')\n", "ax.grid()\n", "ax.legend();" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "gather": { "logged": 1667634932116 } }, "outputs": [], "source": [ "def run_fold(\n", "        model: nn.Module,\n", "        train_loader: data.DataLoader,\n", "        valid_loader: data.DataLoader,\n", "        n_epochs=50) -> None:\n", "\n", "    optimizer = Adam(model.parameters(), lr=1e-2)\n", "    # cosine schedule: the lr falls from 1e-2 to eta_min=1e-4 over T_max=5 epochs, then rises again\n", "    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-4)\n", "\n", "    for epoch in range(1, n_epochs + 1):\n", "        print(f'epoch: {epoch} lr: {scheduler.get_last_lr()[0]:.4f}')\n", "        train_loss, train_acc = train(model, optimizer, train_loader)\n", "        valid_loss, valid_acc = valid(model=model, valid_loader=valid_loader)\n", "        scheduler.step()\n", "        print(f'score: train_loss {train_loss:.3f} train_acc {train_acc:.3f} valid_loss {valid_loss:.3f} valid_acc {valid_acc:.3f}')" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "gather": { "logged": 1667634933535 } },
"outputs": [], "source": [ "seed_torch(0)\n", "vit = Vit(\n", " in_channels=3,\n", " num_classes=2,\n", " num_patch_row=8,\n", " image_size=224,\n", " dropout=.1)\n", "vit = vit.to(DEVICE)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "gather": { "logged": 1667635159384 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 lr: 0.0100\n", "score: train_loss 2.258 train_acc 0.484 valid_loss 0.861 valid_acc 0.470\n", "epoch: 2 lr: 0.0091\n", "score: train_loss 0.777 train_acc 0.536 valid_loss 0.694 valid_acc 0.522\n", "epoch: 3 lr: 0.0066\n", "score: train_loss 0.701 train_acc 0.531 valid_loss 0.680 valid_acc 0.539\n", "epoch: 4 lr: 0.0035\n", "score: train_loss 0.695 train_acc 0.531 valid_loss 0.681 valid_acc 0.568\n", "epoch: 5 lr: 0.0010\n", "score: train_loss 0.707 train_acc 0.531 valid_loss 0.687 valid_acc 0.535\n", "epoch: 6 lr: 0.0001\n", "score: train_loss 0.676 train_acc 0.573 valid_loss 0.679 valid_acc 0.548\n", "epoch: 7 lr: 0.0010\n", "score: train_loss 0.697 train_acc 0.521 valid_loss 0.675 valid_acc 0.583\n", "epoch: 8 lr: 0.0035\n", "score: train_loss 0.809 train_acc 0.547 valid_loss 0.966 valid_acc 0.470\n", "epoch: 9 lr: 0.0066\n", "score: train_loss 0.803 train_acc 0.536 valid_loss 0.816 valid_acc 0.470\n", "epoch: 10 lr: 0.0091\n", "score: train_loss 0.840 train_acc 0.516 valid_loss 0.679 valid_acc 0.602\n", "epoch: 11 lr: 0.0100\n", "score: train_loss 0.794 train_acc 0.516 valid_loss 0.705 valid_acc 0.530\n", "epoch: 12 lr: 0.0091\n", "score: train_loss 0.846 train_acc 0.484 valid_loss 1.001 valid_acc 0.530\n", "epoch: 13 lr: 0.0066\n", "score: train_loss 0.798 train_acc 0.547 valid_loss 0.914 valid_acc 0.470\n", "epoch: 14 lr: 0.0035\n", "score: train_loss 0.810 train_acc 0.505 valid_loss 0.771 valid_acc 0.530\n", "epoch: 15 lr: 0.0010\n", "score: train_loss 0.723 train_acc 0.547 valid_loss 0.696 valid_acc 0.530\n", "epoch: 16 lr: 0.0001\n", "score: train_loss 0.691 train_acc 0.521 valid_loss 0.691 
valid_acc 0.530\n", "epoch: 17 lr: 0.0010\n", "score: train_loss 0.697 train_acc 0.490 valid_loss 0.704 valid_acc 0.470\n", "epoch: 18 lr: 0.0035\n", "score: train_loss 0.702 train_acc 0.500 valid_loss 0.693 valid_acc 0.530\n", "epoch: 19 lr: 0.0066\n", "score: train_loss 0.705 train_acc 0.490 valid_loss 0.687 valid_acc 0.549\n", "epoch: 20 lr: 0.0091\n", "score: train_loss 0.701 train_acc 0.516 valid_loss 0.690 valid_acc 0.488\n" ] } ], "source": [ "seed_torch()\n", "run_fold(\n", "    model=vit,\n", "    train_loader=train_loader,\n", "    valid_loader=val_loader,\n", "    n_epochs=20)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Training runs, but accuracy stays close to chance level (about 0.5). Improving performance calls for data augmentation or for fine-tuning a pretrained model." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "---" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Using a pretrained model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- PyTorch Image Models (timm)\n", "- https://github.com/rwightman/pytorch-image-models\n", "\n", "Note: ViT models are also available in `torchvision.models`." ] },
{ "cell_type": "code", "execution_count": 26, "metadata": { "gather": { "logged": 1667635159604 } }, "outputs": [], "source": [ "# !pip3 -q install timm" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": { "gather": { "logged": 1667635159830 } }, "outputs": [], "source": [ "import timm" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "gather": { "logged": 1667635160040 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Available Vision Transformer Models: \n" ] }, { "data": { "text/plain": [ "['vit_base_patch8_224',\n", " 'vit_base_patch8_224_dino',\n", " 'vit_base_patch8_224_in21k',\n", " 'vit_base_patch16_18x2_224',\n", " 'vit_base_patch16_224',\n", " 'vit_base_patch16_224_dino',\n", " 'vit_base_patch16_224_in21k',\n", " 'vit_base_patch16_224_miil',\n", " 'vit_base_patch16_224_miil_in21k',\n", " 'vit_base_patch16_224_sam',\n", " 'vit_base_patch16_384',\n", " 'vit_base_patch16_plus_240',\n", "
'vit_base_patch16_rpn_224',\n", " 'vit_base_patch32_224',\n", " 'vit_base_patch32_224_clip_laion2b',\n", " 'vit_base_patch32_224_in21k',\n", " 'vit_base_patch32_224_sam',\n", " 'vit_base_patch32_384',\n", " 'vit_base_patch32_plus_256',\n", " 'vit_base_r26_s32_224',\n", " 'vit_base_r50_s16_224',\n", " 'vit_base_r50_s16_224_in21k',\n", " 'vit_base_r50_s16_384',\n", " 'vit_base_resnet26d_224',\n", " 'vit_base_resnet50_224_in21k',\n", " 'vit_base_resnet50_384',\n", " 'vit_base_resnet50d_224',\n", " 'vit_giant_patch14_224',\n", " 'vit_giant_patch14_224_clip_laion2b',\n", " 'vit_gigantic_patch14_224',\n", " 'vit_huge_patch14_224',\n", " 'vit_huge_patch14_224_clip_laion2b',\n", " 'vit_huge_patch14_224_in21k',\n", " 'vit_large_patch14_224',\n", " 'vit_large_patch14_224_clip_laion2b',\n", " 'vit_large_patch16_224',\n", " 'vit_large_patch16_224_in21k',\n", " 'vit_large_patch16_384',\n", " 'vit_large_patch32_224',\n", " 'vit_large_patch32_224_in21k',\n", " 'vit_large_patch32_384',\n", " 'vit_large_r50_s32_224',\n", " 'vit_large_r50_s32_224_in21k',\n", " 'vit_large_r50_s32_384',\n", " 'vit_relpos_base_patch16_224',\n", " 'vit_relpos_base_patch16_cls_224',\n", " 'vit_relpos_base_patch16_clsgap_224',\n", " 'vit_relpos_base_patch16_plus_240',\n", " 'vit_relpos_base_patch16_rpn_224',\n", " 'vit_relpos_base_patch32_plus_rpn_256',\n", " 'vit_relpos_medium_patch16_224',\n", " 'vit_relpos_medium_patch16_cls_224',\n", " 'vit_relpos_medium_patch16_rpn_224',\n", " 'vit_relpos_small_patch16_224',\n", " 'vit_relpos_small_patch16_rpn_224',\n", " 'vit_small_patch8_224_dino',\n", " 'vit_small_patch16_18x2_224',\n", " 'vit_small_patch16_36x1_224',\n", " 'vit_small_patch16_224',\n", " 'vit_small_patch16_224_dino',\n", " 'vit_small_patch16_224_in21k',\n", " 'vit_small_patch16_384',\n", " 'vit_small_patch32_224',\n", " 'vit_small_patch32_224_in21k',\n", " 'vit_small_patch32_384',\n", " 'vit_small_r26_s32_224',\n", " 'vit_small_r26_s32_224_in21k',\n", " 'vit_small_r26_s32_384',\n", " 
'vit_small_resnet26d_224',\n", " 'vit_small_resnet50d_s16_224',\n", " 'vit_srelpos_medium_patch16_224',\n", " 'vit_srelpos_small_patch16_224',\n", " 'vit_tiny_patch16_224',\n", " 'vit_tiny_patch16_224_in21k',\n", " 'vit_tiny_patch16_384',\n", " 'vit_tiny_r_s16_p8_224',\n", " 'vit_tiny_r_s16_p8_224_in21k',\n", " 'vit_tiny_r_s16_p8_384']" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print('Available Vision Transformer Models: ')\n", "timm.list_models('vit*')" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": { "gather": { "logged": 1667635161097 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(in_features=768, out_features=1000, bias=True)\n", "Linear(in_features=768, out_features=2, bias=True)\n" ] } ], "source": [ "# get the base model (architecture only, no pretrained weights)\n", "MODEL_NAME = 'vit_base_patch16_224'\n", "_model = timm.create_model(MODEL_NAME, pretrained=False)\n", "print(_model.head)\n", "# freeze all parameters\n", "for param in _model.parameters():\n", "    param.requires_grad = False\n", "# adapt the architecture to the task: replace the 1000-class head with a new, trainable 2-class head\n", "_model.head = nn.Linear(_model.head.in_features, 2)\n", "print(_model.head)" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": { "gather": { "logged": 1667635161230 }, "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } }, "outputs": [], "source": [ "# check the architecture and which parameters will be updated\n", "# !pip3 -q install torchsummary\n", "# from torchsummary import summary\n", "# _model = _model.to(DEVICE)\n", "# summary(_model, input_size=(3, 224, 224))" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": { "gather": { "logged": 1667635161533 } }, "outputs": [], "source": [ "class ViTBased(nn.Module):\n", "    def __init__(self, n_classes:int=2):\n", "\n", "        super(ViTBased, self).__init__()\n", "\n", "        # load the pretrained weights and freeze them\n", "        self.model = timm.create_model(MODEL_NAME, pretrained=True)\n", "        for param in self.model.parameters():\n", "            param.requires_grad = False\n", "        # only the new classification head is trained\n", "        self.model.head = nn.Linear(self.model.head.in_features, n_classes)\n", "\n", "    def forward(self, x:torch.Tensor) -> torch.Tensor:\n", "        out = self.model(x)\n", "        return out" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": { "gather": { "logged": 1667635163497 } }, "outputs": [], "source": [ "seed_torch()\n", "vit = ViTBased(n_classes=2)\n", "vit = vit.to(DEVICE)" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": { "gather": { "logged": 1667635163627 } }, "outputs": [], "source": [ "def run_fold(\n", "        model: nn.Module,\n", "        train_loader: data.DataLoader,\n", "        valid_loader: data.DataLoader,\n", "        n_epochs=50) -> None:\n", "\n", "    # no scheduler here; Adam skips the frozen parameters because they receive no gradients\n", "    optimizer = Adam(model.parameters(), lr=1e-2)\n", "\n", "    for epoch in range(1, n_epochs+1):\n", "        print(f'epoch: {epoch}')\n", "        train_loss, train_acc = train(model, optimizer, train_loader)\n", "        valid_loss, valid_acc = valid(model=model, valid_loader=valid_loader)\n", "        print(f'score: train_loss {train_loss:.3f} train_acc {train_acc:.3f} valid_loss {valid_loss:.3f} valid_acc {valid_acc:.3f}')" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": { "gather": { "logged": 1667635199368 } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1\n", "score: train_loss 0.381 train_acc 0.891 valid_loss 0.072 valid_acc 0.974\n", "epoch: 2\n", "score: train_loss 0.037 train_acc 0.984 valid_loss 0.000 valid_acc 1.000\n", "epoch: 3\n", "score: train_loss 0.000 train_acc 1.000 valid_loss 0.000 valid_acc 1.000\n" ] } ], "source": [ "run_fold(\n", "    model=vit,\n", "    train_loader=train_loader,\n", "    valid_loader=val_loader,\n", "    n_epochs=3)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": false, "source_hidden": false }, "nteract": { "transient": { "deleting": false } } }, "outputs": [], "source": [] } ], "metadata": { "kernel_info": { "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3"
}, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "nteract": { "version": "nteract-front-end@1.0.0" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "195d00c3bc2576aa3aa8d34b1ef69c319bc4c5e1d04057dba8a69b2c34c3aaa0" } } }, "nbformat": 4, "nbformat_minor": 2 }