Vision Transformer(ViT)は、2020年にGoogleの研究チームが発表した画像認識モデルです。論文 “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”(Dosovitskiy et al.)で提案され、CNNを使わずにTransformerのみで画像分類タスクにおいて最高水準の性能を達成しました。
それまで画像認識ではCNN(畳み込みニューラルネットワーク)が支配的でしたが、ViTの登場により、自然言語処理で成功を収めたTransformerアーキテクチャが画像領域にも適用できることが示されました。現在では、CLIP、DALL-E、Stable Diffusionなど、多くの最先端モデルがViTをベースとしています。
本記事の内容
- ViTの全体アーキテクチャと設計思想
- 画像のパッチ分割と線形埋め込み
- CLSトークンと位置埋め込み
- Transformer Encoderによる特徴抽出
- PyTorchによるスクラッチ実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
ViTのアイデア:画像を単語列のように扱う
CNNとの違い
CNNは局所的な畳み込みフィルタを階層的に適用することで、画像の空間的な特徴を抽出します。畳み込みは画像の局所性(近くのピクセル同士は関連が強い)と平行移動不変性(パターンが画像のどこに現れても同じように検出できる)という帰納的バイアスを持っています。
一方、Transformerは入力系列の全要素間の関連度をSelf-Attentionで計算します。局所性の仮定を持たないため、遠く離れた要素間の関連も直接的に捉えられます。ただし、その分だけ学習に必要なデータ量が多くなります。
画像をパッチに分割する
ViTの核心的なアイデアは、画像を小さなパッチに分割し、各パッチを1つのトークン(単語)として扱うことです。
具体的には、$H \times W$ の画像を $P \times P$ のパッチに分割します。パッチの数は以下のようになります。
$$ N = \frac{H \times W}{P^2} $$
例えば、$224 \times 224$ の画像を $16 \times 16$ のパッチに分割すると、
$$ N = \frac{224 \times 224}{16 \times 16} = \frac{50176}{256} = 196 $$
となり、196個のパッチ(トークン)が得られます。これは、196単語の文章をTransformerに入力するのと同じ構造です。
ViTの全体アーキテクチャ
ViTの処理フローを段階的に説明します。
[入力画像] (H × W × C)
↓
[パッチ分割] → N個のパッチ (P × P × C)
↓
[パッチ埋め込み] → (N, D)
↓
[CLSトークンを先頭に追加] → (N+1, D)
↓
[位置埋め込みを加算] → (N+1, D)
↓
┌─────────────────────────┐
│ Transformer Encoder ×L │
│ ├ Layer Normalization │
│ ├ Multi-Head Self-Attn │
│ ├ 残差接続 │
│ ├ Layer Normalization │
│ ├ MLP │
│ └ 残差接続 │
└─────────────────────────┘
↓
[CLSトークンの出力を取得] → (D,)
↓
[分類ヘッド (MLP)] → (num_classes,)
パッチ埋め込み
各パッチは $P \times P \times C$ のテンソルです($C$ はチャンネル数、RGBなら3)。これを1次元に平坦化すると $P^2 \cdot C$ 次元のベクトルになります。
このベクトルを線形変換により $D$ 次元の埋め込みベクトルに変換します。
$$ \bm{z}_i = \bm{W}_E \bm{x}_i^{\text{patch}} + \bm{b}_E $$
ここで $\bm{W}_E \in \mathbb{R}^{D \times (P^2 C)}$、$\bm{b}_E \in \mathbb{R}^D$ は学習可能なパラメータです。
実装上は、カーネルサイズとストライドを $P$ に設定した畳み込み層で効率的に計算できます。
CLSトークン
BERTと同様に、ViTでは系列の先頭に学習可能なCLSトークン(Classification Token)を追加します。
$$ \bm{z}_0 = \bm{x}_{\text{cls}} $$
Transformer Encoderを通過した後、CLSトークンの出力ベクトルは画像全体の情報を集約した表現となります。最終的な分類は、このCLSトークンの出力を分類ヘッド(MLP)に入力することで行います。
なぜCLSトークンを使うのでしょうか。Self-Attentionにより、CLSトークンは全パッチからの情報を集約できます。各パッチの出力を平均する方法(Global Average Pooling)も考えられますが、CLSトークンを使うことで、モデルが「分類に重要な情報を集約する方法」を学習できます。
位置埋め込み
Transformerは位置情報を持たないため、各パッチの空間的な位置を伝える必要があります。ViTでは、学習可能な位置埋め込みを使用します。
$$ \bm{z}_i’ = \bm{z}_i + \bm{E}_{\text{pos}}[i] $$
ここで $\bm{E}_{\text{pos}} \in \mathbb{R}^{(N+1) \times D}$ は学習可能なパラメータです。CLSトークンを含めて $N+1$ 個の位置埋め込みを持ちます。
原論文のTransformerではsin/cos関数による固定の位置エンコーディングを使いましたが、ViTでは学習可能な位置埋め込みを使います。実験の結果、2D構造を明示的にエンコードする2D位置埋め込みと1D位置埋め込みの性能差はほとんどないことが報告されています。
Transformer Encoder
ViTのTransformer Encoderは、原論文のEncoderと基本的に同じ構造ですが、Layer Normalizationの位置が異なります(Pre-Normalization)。
$$ \begin{align} \bm{z}’ &= \bm{z} + \text{MSA}(\text{LN}(\bm{z})) \\ \bm{z}” &= \bm{z}’ + \text{MLP}(\text{LN}(\bm{z}’)) \end{align} $$
ここで、MSAはMulti-head Self-Attention、LNはLayer Normalization、MLPは2層の全結合ネットワークです。
原論文のTransformerではサブレイヤーの後にLayer Normalizationを適用しますが(Post-Normalization)、ViTではサブレイヤーの前に適用します(Pre-Normalization)。Pre-Normalizationは学習の安定性が高く、深いネットワークでも勾配が安定することが知られています。
MLPは以下の構造を持ちます。
$$ \text{MLP}(\bm{x}) = \bm{W}_2 \cdot \text{GELU}(\bm{W}_1 \bm{x} + \bm{b}_1) + \bm{b}_2 $$
GELU(Gaussian Error Linear Unit)は、BERTやGPTで使われる活性化関数です。
$$ \text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right] $$
ここで $\Phi(x)$ は標準正規分布の累積分布関数です。
分類ヘッド
Transformer Encoderを通過した後、CLSトークンの出力ベクトル $\bm{z}_0^L$($L$ はEncoder層の数)を分類ヘッドに入力します。
$$ \hat{\bm{y}} = \text{softmax}(\bm{W}_{\text{head}} \bm{z}_0^L + \bm{b}_{\text{head}}) $$
事前学習時には隠れ層を持つMLPを使い、ファインチューニング時には単一の線形層を使うことが多いです。
ViTのモデルバリアント
原論文では、以下の3つのモデルサイズが定義されています。
| モデル | 層数 (L) | 隠れ次元 (D) | MLPサイズ | ヘッド数 | パラメータ数 |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
パッチサイズは16または14が一般的で、ViT-B/16は「ViT-Baseでパッチサイズ16」を意味します。
PyTorchによる実装
パッチ埋め込み
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
"""画像をパッチに分割して埋め込みベクトルに変換"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 畳み込み層でパッチ分割と線形変換を同時に実行
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
"""
Args:
x: (batch_size, channels, height, width)
Returns:
(batch_size, num_patches, embed_dim)
"""
# (B, C, H, W) -> (B, D, H/P, W/P)
x = self.proj(x)
# (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
x = x.flatten(2).transpose(1, 2)
return x
Transformer Encoderブロック
class TransformerEncoderBlock(nn.Module):
"""ViTのTransformer Encoderブロック(Pre-Normalization)"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(
embed_dim, num_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(embed_dim)
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x):
"""
Args:
x: (batch_size, seq_len, embed_dim)
Returns:
(batch_size, seq_len, embed_dim)
"""
# Pre-Normalization + Multi-Head Self-Attention + 残差接続
x_norm = self.norm1(x)
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_out
# Pre-Normalization + MLP + 残差接続
x = x + self.mlp(self.norm2(x))
return x
Vision Transformer全体
class VisionTransformer(nn.Module):
"""Vision Transformer (ViT)"""
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
dropout=0.0,
):
super().__init__()
# パッチ埋め込み
self.patch_embed = PatchEmbedding(
img_size, patch_size, in_channels, embed_dim
)
num_patches = self.patch_embed.num_patches
# CLSトークン(学習可能)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 位置埋め込み(学習可能)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
# Dropout
self.pos_drop = nn.Dropout(p=dropout)
# Transformer Encoderブロックを積み重ねる
self.blocks = nn.ModuleList([
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
# 最終のLayer Normalization
self.norm = nn.LayerNorm(embed_dim)
# 分類ヘッド
self.head = nn.Linear(embed_dim, num_classes)
# 重みの初期化
self._init_weights()
def _init_weights(self):
# 位置埋め込みとCLSトークンの初期化
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
# 線形層とLayerNormの初期化
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
"""
Args:
x: (batch_size, channels, height, width)
Returns:
(batch_size, num_classes)
"""
batch_size = x.shape[0]
# パッチ埋め込み: (B, C, H, W) -> (B, N, D)
x = self.patch_embed(x)
# CLSトークンを先頭に追加: (B, N, D) -> (B, N+1, D)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# 位置埋め込みを加算
x = x + self.pos_embed
x = self.pos_drop(x)
# Transformer Encoderブロックを適用
for block in self.blocks:
x = block(x)
# 最終のLayer Normalization
x = self.norm(x)
# CLSトークンの出力を取得して分類
cls_output = x[:, 0] # (B, D)
logits = self.head(cls_output) # (B, num_classes)
return logits
動作確認
# モデル作成(ViT-Base/16相当)
model = VisionTransformer(
img_size=224,
patch_size=16,
in_channels=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
)
# ダミー入力
batch_size = 2
x = torch.randn(batch_size, 3, 224, 224)
# 順伝播
logits = model(x)
print(f"入力形状: {x.shape}") # (2, 3, 224, 224)
print(f"出力形状: {logits.shape}") # (2, 1000)
# パラメータ数を確認
total_params = sum(p.numel() for p in model.parameters())
print(f"総パラメータ数: {total_params:,}") # 約86M
Attention重みの可視化
ViTのAttention重みを可視化することで、モデルがどの領域に注目しているかを確認できます。
import matplotlib.pyplot as plt
import numpy as np
def visualize_attention(model, img, patch_size=16):
"""Attention重みを可視化"""
model.eval()
with torch.no_grad():
# パッチ埋め込み
x = model.patch_embed(img)
batch_size = x.shape[0]
# CLSトークンと位置埋め込み
cls_tokens = model.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
x = x + model.pos_embed
# 最後のブロックからAttention重みを取得
for i, block in enumerate(model.blocks):
x_norm = block.norm1(x)
if i == len(model.blocks) - 1:
# 最後のブロックでAttention重みを取得
_, attn_weights = block.attn(
x_norm, x_norm, x_norm, need_weights=True
)
else:
attn_out, _ = block.attn(x_norm, x_norm, x_norm)
x = x + attn_out
x = x + block.mlp(block.norm2(x))
# CLSトークンから各パッチへのAttention重み
# attn_weights: (B, num_heads, seq_len, seq_len)
attn_cls = attn_weights[0, :, 0, 1:] # (num_heads, num_patches)
attn_mean = attn_cls.mean(dim=0) # 全ヘッドの平均
# パッチグリッドに変形
num_patches_per_side = int(np.sqrt(attn_mean.shape[0]))
attn_map = attn_mean.reshape(num_patches_per_side, num_patches_per_side)
return attn_map.numpy()
# 可視化例(ランダム画像で)
torch.manual_seed(42)
dummy_img = torch.randn(1, 3, 224, 224)
model_small = VisionTransformer(
img_size=224, patch_size=16, embed_dim=192, depth=4, num_heads=4, num_classes=10
)
attn_map = visualize_attention(model_small, dummy_img)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# 入力画像(正規化を戻してRGB表示用に変換)
img_display = dummy_img[0].permute(1, 2, 0).numpy()
img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min())
axes[0].imshow(img_display)
axes[0].set_title("Input Image")
axes[0].axis("off")
# Attention重み
im = axes[1].imshow(attn_map, cmap="viridis")
axes[1].set_title("Attention from CLS token")
axes[1].axis("off")
plt.colorbar(im, ax=axes[1], shrink=0.8)
plt.tight_layout()
plt.show()
ViTの特性と限界
大規模データでの有効性
ViTはCNNのような帰納的バイアス(局所性、平行移動不変性)を持たないため、小規模データセットではCNNに劣ることが報告されています。しかし、JFT-300M(3億枚の画像)のような大規模データセットで事前学習すると、ImageNetでCNNを上回る性能を達成します。
これは、帰納的バイアスが少ないことで、データから柔軟にパターンを学習できる一方、十分なデータがないと汎化性能が低下するためです。
計算効率
Self-Attentionの計算量は系列長の2乗に比例するため、パッチサイズを小さくする(解像度を上げる)と計算コストが急増します。$224 \times 224$ の画像をパッチサイズ16で分割すると196パッチですが、パッチサイズ8では784パッチとなり、Attentionの計算量は約16倍になります。
後続の発展
ViT以降、DeiT(Data-efficient Image Transformer)、Swin Transformer、BEiTなど、多くの改良が提案されています。特にSwin Transformerは、階層的な構造とシフトウィンドウによる局所Attentionを導入し、様々な視覚タスクで高い性能を達成しています。
まとめ
本記事では、Vision Transformer(ViT)の仕組みを解説しました。
- パッチ分割: 画像を $P \times P$ のパッチに分割し、各パッチを1つのトークンとして扱う
- パッチ埋め込み: 各パッチを線形変換で $D$ 次元のベクトルに変換
- CLSトークン: 画像全体の情報を集約するための学習可能なトークン
- 位置埋め込み: 学習可能な1D位置埋め込みを使用
- Pre-Normalization: サブレイヤーの前にLayer Normalizationを適用し、学習を安定化
- 大規模データの重要性: 帰納的バイアスが少ないため、大規模データでの事前学習が性能向上に不可欠
ViTは、Transformerの画像領域への適用可能性を示した画期的なモデルであり、CLIPやStable Diffusionなど後続の多くのモデルの基盤となっています。
次のステップとして、以下の記事も参考にしてください。