Vision Transformer (ViT)の理論と実装

Vision Transformer(ViT)は、2020年にGoogleの研究チームが発表した画像認識モデルです。論文 “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”(Dosovitskiy et al.)で提案され、CNNを使わずにTransformerのみで画像分類タスクにおいて最高水準の性能を達成しました。

それまで画像認識ではCNN(畳み込みニューラルネットワーク)が支配的でしたが、ViTの登場により、自然言語処理で成功を収めたTransformerアーキテクチャが画像領域にも適用できることが示されました。現在では、CLIP、DALL-E、Stable Diffusionなど、多くの最先端モデルがViTをベースとしています。

本記事の内容

  • ViTの全体アーキテクチャと設計思想
  • 画像のパッチ分割と線形埋め込み
  • CLSトークンと位置埋め込み
  • Transformer Encoderによる特徴抽出
  • PyTorchによるスクラッチ実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

ViTのアイデア:画像を単語列のように扱う

CNNとの違い

CNNは局所的な畳み込みフィルタを階層的に適用することで、画像の空間的な特徴を抽出します。畳み込みは画像の局所性(近くのピクセル同士は関連が強い)と平行移動不変性(パターンが画像のどこに現れても同じように検出できる)という帰納的バイアスを持っています。

一方、Transformerは入力系列の全要素間の関連度をSelf-Attentionで計算します。局所性の仮定を持たないため、遠く離れた要素間の関連も直接的に捉えられます。ただし、その分だけ学習に必要なデータ量が多くなります。

画像をパッチに分割する

ViTの核心的なアイデアは、画像を小さなパッチに分割し、各パッチを1つのトークン(単語)として扱うことです。

具体的には、$H \times W$ の画像を $P \times P$ のパッチに分割します。パッチの数は以下のようになります。

$$ N = \frac{H \times W}{P^2} $$

例えば、$224 \times 224$ の画像を $16 \times 16$ のパッチに分割すると、

$$ N = \frac{224 \times 224}{16 \times 16} = \frac{50176}{256} = 196 $$

となり、196個のパッチ(トークン)が得られます。これは、196単語の文章をTransformerに入力するのと同じ構造です。

ViTの全体アーキテクチャ

ViTの処理フローを段階的に説明します。

[入力画像] (H × W × C)
    ↓
[パッチ分割] → N個のパッチ (P × P × C)
    ↓
[パッチ埋め込み] → (N, D)
    ↓
[CLSトークンを先頭に追加] → (N+1, D)
    ↓
[位置埋め込みを加算] → (N+1, D)
    ↓
┌─────────────────────────┐
│ Transformer Encoder ×L  │
│  ├ Layer Normalization  │
│  ├ Multi-Head Self-Attn │
│  ├ 残差接続             │
│  ├ Layer Normalization  │
│  ├ MLP                  │
│  └ 残差接続             │
└─────────────────────────┘
    ↓
[CLSトークンの出力を取得] → (D,)
    ↓
[分類ヘッド (MLP)] → (num_classes,)

パッチ埋め込み

各パッチは $P \times P \times C$ のテンソルです($C$ はチャンネル数、RGBなら3)。これを1次元に平坦化すると $P^2 \cdot C$ 次元のベクトルになります。

このベクトルを線形変換により $D$ 次元の埋め込みベクトルに変換します。

$$ \bm{z}_i = \bm{W}_E \bm{x}_i^{\text{patch}} + \bm{b}_E $$

ここで $\bm{W}_E \in \mathbb{R}^{D \times (P^2 C)}$、$\bm{b}_E \in \mathbb{R}^D$ は学習可能なパラメータです。

実装上は、カーネルサイズとストライドを $P$ に設定した畳み込み層で効率的に計算できます。

CLSトークン

BERTと同様に、ViTでは系列の先頭に学習可能なCLSトークン(Classification Token)を追加します。

$$ \bm{z}_0 = \bm{x}_{\text{cls}} $$

Transformer Encoderを通過した後、CLSトークンの出力ベクトルは画像全体の情報を集約した表現となります。最終的な分類は、このCLSトークンの出力を分類ヘッド(MLP)に入力することで行います。

なぜCLSトークンを使うのでしょうか。Self-Attentionにより、CLSトークンは全パッチからの情報を集約できます。各パッチの出力を平均する方法(Global Average Pooling)も考えられますが、CLSトークンを使うことで、モデルが「分類に重要な情報を集約する方法」を学習できます。

位置埋め込み

Transformerは位置情報を持たないため、各パッチの空間的な位置を伝える必要があります。ViTでは、学習可能な位置埋め込みを使用します。

$$ \bm{z}_i’ = \bm{z}_i + \bm{E}_{\text{pos}}[i] $$

ここで $\bm{E}_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ は学習可能なパラメータです。CLSトークンを含めて $N+1$ 個の位置埋め込みを持ちます。

原論文のTransformerではsin/cos関数による固定の位置エンコーディングを使いましたが、ViTでは学習可能な位置埋め込みを使います。実験の結果、2D構造を明示的にエンコードする2D位置埋め込みと1D位置埋め込みの性能差はほとんどないことが報告されています。

Transformer Encoder

ViTのTransformer Encoderは、原論文のEncoderと基本的に同じ構造ですが、Layer Normalizationの位置が異なります(Pre-Normalization)。

$$ \begin{align} \bm{z}’ &= \bm{z} + \text{MSA}(\text{LN}(\bm{z})) \\ \bm{z}” &= \bm{z}’ + \text{MLP}(\text{LN}(\bm{z}’)) \end{align} $$

ここで、MSAはMulti-head Self-Attention、LNはLayer Normalization、MLPは2層の全結合ネットワークです。

原論文のTransformerではサブレイヤーの後にLayer Normalizationを適用しますが(Post-Normalization)、ViTではサブレイヤーの前に適用します(Pre-Normalization)。Pre-Normalizationは学習の安定性が高く、深いネットワークでも勾配が安定することが知られています。

MLPは以下の構造を持ちます。

$$ \text{MLP}(\bm{x}) = \bm{W}_2 \cdot \text{GELU}(\bm{W}_1 \bm{x} + \bm{b}_1) + \bm{b}_2 $$

GELU(Gaussian Error Linear Unit)は、BERTやGPTで使われる活性化関数です。

$$ \text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right] $$

ここで $\Phi(x)$ は標準正規分布の累積分布関数です。

分類ヘッド

Transformer Encoderを通過した後、CLSトークンの出力ベクトル $\bm{z}_0^L$($L$ はEncoder層の数)を分類ヘッドに入力します。

$$ \hat{\bm{y}} = \text{softmax}(\bm{W}_{\text{head}} \bm{z}_0^L + \bm{b}_{\text{head}}) $$

事前学習時には隠れ層を持つMLPを使い、ファインチューニング時には単一の線形層を使うことが多いです。

ViTのモデルバリアント

原論文では、以下の3つのモデルサイズが定義されています。

モデル 層数 (L) 隠れ次元 (D) MLPサイズ ヘッド数 パラメータ数
ViT-Base 12 768 3072 12 86M
ViT-Large 24 1024 4096 16 307M
ViT-Huge 32 1280 5120 16 632M

パッチサイズは16または14が一般的で、ViT-B/16は「ViT-Baseでパッチサイズ16」を意味します。

PyTorchによる実装

パッチ埋め込み

import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """画像をパッチに分割して埋め込みベクトルに変換"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # 畳み込み層でパッチ分割と線形変換を同時に実行
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, channels, height, width)
        Returns:
            (batch_size, num_patches, embed_dim)
        """
        # (B, C, H, W) -> (B, D, H/P, W/P)
        x = self.proj(x)
        # (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
        x = x.flatten(2).transpose(1, 2)
        return x

Transformer Encoderブロック

class TransformerEncoderBlock(nn.Module):
    """ViTのTransformer Encoderブロック(Pre-Normalization)"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, embed_dim)
        Returns:
            (batch_size, seq_len, embed_dim)
        """
        # Pre-Normalization + Multi-Head Self-Attention + 残差接続
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out

        # Pre-Normalization + MLP + 残差接続
        x = x + self.mlp(self.norm2(x))
        return x

Vision Transformer全体

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT)"""
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0,
    ):
        super().__init__()

        # パッチ埋め込み
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # CLSトークン(学習可能)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # 位置埋め込み(学習可能)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Dropout
        self.pos_drop = nn.Dropout(p=dropout)

        # Transformer Encoderブロックを積み重ねる
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # 最終のLayer Normalization
        self.norm = nn.LayerNorm(embed_dim)

        # 分類ヘッド
        self.head = nn.Linear(embed_dim, num_classes)

        # 重みの初期化
        self._init_weights()

    def _init_weights(self):
        # 位置埋め込みとCLSトークンの初期化
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # 線形層とLayerNormの初期化
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """
        Args:
            x: (batch_size, channels, height, width)
        Returns:
            (batch_size, num_classes)
        """
        batch_size = x.shape[0]

        # パッチ埋め込み: (B, C, H, W) -> (B, N, D)
        x = self.patch_embed(x)

        # CLSトークンを先頭に追加: (B, N, D) -> (B, N+1, D)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # 位置埋め込みを加算
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer Encoderブロックを適用
        for block in self.blocks:
            x = block(x)

        # 最終のLayer Normalization
        x = self.norm(x)

        # CLSトークンの出力を取得して分類
        cls_output = x[:, 0]  # (B, D)
        logits = self.head(cls_output)  # (B, num_classes)

        return logits

動作確認

# モデル作成(ViT-Base/16相当)
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    num_heads=12,
)

# ダミー入力
batch_size = 2
x = torch.randn(batch_size, 3, 224, 224)

# 順伝播
logits = model(x)

print(f"入力形状: {x.shape}")      # (2, 3, 224, 224)
print(f"出力形状: {logits.shape}") # (2, 1000)

# パラメータ数を確認
total_params = sum(p.numel() for p in model.parameters())
print(f"総パラメータ数: {total_params:,}")  # 約86M

Attention重みの可視化

ViTのAttention重みを可視化することで、モデルがどの領域に注目しているかを確認できます。

import matplotlib.pyplot as plt
import numpy as np


def visualize_attention(model, img, patch_size=16):
    """Attention重みを可視化"""
    model.eval()

    with torch.no_grad():
        # パッチ埋め込み
        x = model.patch_embed(img)
        batch_size = x.shape[0]

        # CLSトークンと位置埋め込み
        cls_tokens = model.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + model.pos_embed

        # 最後のブロックからAttention重みを取得
        for i, block in enumerate(model.blocks):
            x_norm = block.norm1(x)
            if i == len(model.blocks) - 1:
                # 最後のブロックでAttention重みを取得
                _, attn_weights = block.attn(
                    x_norm, x_norm, x_norm, need_weights=True
                )
            else:
                attn_out, _ = block.attn(x_norm, x_norm, x_norm)
                x = x + attn_out
                x = x + block.mlp(block.norm2(x))

    # CLSトークンから各パッチへのAttention重み
    # attn_weights: (B, num_heads, seq_len, seq_len)
    attn_cls = attn_weights[0, :, 0, 1:]  # (num_heads, num_patches)
    attn_mean = attn_cls.mean(dim=0)  # 全ヘッドの平均

    # パッチグリッドに変形
    num_patches_per_side = int(np.sqrt(attn_mean.shape[0]))
    attn_map = attn_mean.reshape(num_patches_per_side, num_patches_per_side)

    return attn_map.numpy()


# 可視化例(ランダム画像で)
torch.manual_seed(42)
dummy_img = torch.randn(1, 3, 224, 224)

model_small = VisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=4, num_heads=4, num_classes=10
)

attn_map = visualize_attention(model_small, dummy_img)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# 入力画像(正規化を戻してRGB表示用に変換)
img_display = dummy_img[0].permute(1, 2, 0).numpy()
img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min())
axes[0].imshow(img_display)
axes[0].set_title("Input Image")
axes[0].axis("off")

# Attention重み
im = axes[1].imshow(attn_map, cmap="viridis")
axes[1].set_title("Attention from CLS token")
axes[1].axis("off")
plt.colorbar(im, ax=axes[1], shrink=0.8)

plt.tight_layout()
plt.show()

ViTの特性と限界

大規模データでの有効性

ViTはCNNのような帰納的バイアス(局所性、平行移動不変性)を持たないため、小規模データセットではCNNに劣ることが報告されています。しかし、JFT-300M(3億枚の画像)のような大規模データセットで事前学習すると、ImageNetでCNNを上回る性能を達成します。

これは、帰納的バイアスが少ないことで、データから柔軟にパターンを学習できる一方、十分なデータがないと汎化性能が低下するためです。

計算効率

Self-Attentionの計算量は系列長の2乗に比例するため、パッチサイズを小さくする(解像度を上げる)と計算コストが急増します。$224 \times 224$ の画像をパッチサイズ16で分割すると196パッチですが、パッチサイズ8では784パッチとなり、Attentionの計算量は約16倍になります。

後続の発展

ViT以降、DeiT(Data-efficient Image Transformer)、Swin Transformer、BEiTなど、多くの改良が提案されています。特にSwin Transformerは、階層的な構造とシフトウィンドウによる局所Attentionを導入し、様々な視覚タスクで高い性能を達成しています。

まとめ

本記事では、Vision Transformer(ViT)の仕組みを解説しました。

  • パッチ分割: 画像を $P \times P$ のパッチに分割し、各パッチを1つのトークンとして扱う
  • パッチ埋め込み: 各パッチを線形変換で $D$ 次元のベクトルに変換
  • CLSトークン: 画像全体の情報を集約するための学習可能なトークン
  • 位置埋め込み: 学習可能な1D位置埋め込みを使用
  • Pre-Normalization: サブレイヤーの前にLayer Normalizationを適用し、学習を安定化
  • 大規模データの重要性: 帰納的バイアスが少ないため、大規模データでの事前学習が性能向上に不可欠

ViTは、Transformerの画像領域への適用可能性を示した画期的なモデルであり、CLIPやStable Diffusionなど後続の多くのモデルの基盤となっています。

次のステップとして、以下の記事も参考にしてください。