モデル枝刈り(Pruning)は、ニューラルネットワークの不要なパラメータを削除してモデルを軽量化する技術です。推論速度の向上とメモリ使用量の削減を実現しながら、精度の低下を最小限に抑えることができます。
本記事では、枝刈りの理論的背景から実装まで詳しく解説します。
本記事の内容
- 枝刈りの基本概念と分類
- 重要度スコアの計算方法
- PyTorchでの実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
枝刈りとは
基本的なアイデア
ニューラルネットワークには、予測性能にほとんど寄与しない「冗長な」パラメータが多数存在します。枝刈りは、これらの冗長なパラメータを特定して削除(ゼロに設定)することで、モデルを効率化します。
宝くじ仮説(Lottery Ticket Hypothesis)
2019年に発表された「宝くじ仮説」によると:
密なニューラルネットワークには、独立に学習させても元のネットワークと同等の性能を達成できる「当たりくじ」(サブネットワーク)が存在する
この発見は、枝刈りの理論的な裏付けを与えました。
枝刈りの分類
1. 粒度による分類
| 種類 | 説明 | 利点 | 欠点 |
|---|---|---|---|
| 非構造化(Unstructured) | 個々の重みを削除 | 高い圧縮率 | ハードウェア効率が低い |
| 構造化(Structured) | フィルタ/チャネル単位で削除 | ハードウェア効率が高い | 圧縮率は低い |
2. タイミングによる分類
| 種類 | 説明 |
|---|---|
| 学習後枝刈り | 学習済みモデルを枝刈り → 再学習 |
| 学習中枝刈り | 学習しながら徐々に枝刈り |
| 学習前枝刈り | 初期化時に枝刈り(宝くじ仮説) |
重要度スコアの計算
枝刈りの核心は、「どのパラメータを削除するか」を決める重要度スコアの計算です。
Magnitude-based Pruning
最もシンプルな手法は、重みの絶対値を重要度とする方法です:
$$ s_i = |w_i| $$
スコアが閾値 $\theta$ 以下の重みを削除:
$$ w_i’ = \begin{cases} w_i & \text{if } |w_i| > \theta \\ 0 & \text{otherwise} \end{cases} $$
Gradient-based Pruning
勾配情報を使う手法もあります。重みと勾配の積を重要度とします:
$$ s_i = \left| w_i \cdot \frac{\partial \mathcal{L}}{\partial w_i} \right| $$
これは、その重みがどれだけ損失に影響するかを表します。
Taylor Expansion based Pruning
損失関数のテイラー展開に基づく手法です。重み $w_i$ を0にしたときの損失変化を近似:
$$ \Delta \mathcal{L}_i = \mathcal{L}(w_i = 0) – \mathcal{L}(w_i) $$
1次のテイラー展開で近似すると:
$$ \Delta \mathcal{L}_i \approx -w_i \cdot \frac{\partial \mathcal{L}}{\partial w_i} $$
損失増加を最小化したいので、$|\Delta \mathcal{L}_i|$ が小さい重みを削除します。
Hessian-based Pruning
2次の情報(ヘッシアン)を使う手法もあります。Optimal Brain Damage(OBD)では:
$$ \Delta \mathcal{L}_i \approx \frac{1}{2} w_i^2 \cdot \frac{\partial^2 \mathcal{L}}{\partial w_i^2} $$
ヘッシアンの対角成分と重みの2乗の積で重要度を評価します。
枝刈りのアルゴリズム
基本的な枝刈りの流れ
- モデルを学習(または学習済みモデルを読み込み)
- 重要度スコアを計算
- スコアが低いパラメータをゼロに設定
- (オプション)Fine-tuning で精度を回復
- 必要に応じて2-4を繰り返す
反復的枝刈り(Iterative Pruning)
一度に大量のパラメータを削除すると精度が大きく低下します。反復的枝刈りでは:
- 少量(例: 10%)を枝刈り
- Fine-tuning
- 目標の圧縮率に達するまで繰り返し
Pythonでの実装
Magnitude-based Pruning の実装
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class MagnitudePruner:
"""Magnitude-based Pruning の実装"""
def __init__(self, model):
self.model = model
def compute_mask(self, sparsity):
"""
指定されたスパース率に基づいてマスクを計算
Args:
sparsity: 削除する重みの割合 (0-1)
Returns:
masks: 各層のマスク辞書
"""
masks = {}
# 全ての重みを収集
all_weights = []
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
all_weights.append(param.data.abs().view(-1))
# 全重みを結合して閾値を計算
all_weights = torch.cat(all_weights)
threshold = torch.quantile(all_weights, sparsity)
# 各層のマスクを作成
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
mask = (param.data.abs() > threshold).float()
masks[name] = mask
return masks
def apply_mask(self, masks):
"""マスクを適用して重みを枝刈り"""
with torch.no_grad():
for name, param in self.model.named_parameters():
if name in masks:
param.data *= masks[name]
def get_sparsity(self):
"""現在のスパース率を計算"""
total_params = 0
zero_params = 0
for name, param in self.model.named_parameters():
if 'weight' in name and param.dim() > 1:
total_params += param.numel()
zero_params += (param.data == 0).sum().item()
return zero_params / total_params if total_params > 0 else 0
# 動作確認用のシンプルなネットワーク
class SimpleNet(nn.Module):
def __init__(self, input_dim=784, hidden_dims=[512, 256], num_classes=10):
super().__init__()
layers = []
prev_dim = input_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, num_classes))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
# 枝刈りのデモ
torch.manual_seed(42)
model = SimpleNet()
# 枝刈り前のパラメータ分布
weights_before = []
for name, param in model.named_parameters():
if 'weight' in name and param.dim() > 1:
weights_before.extend(param.data.view(-1).numpy())
pruner = MagnitudePruner(model)
# 50%のスパース率で枝刈り
masks = pruner.compute_mask(sparsity=0.5)
pruner.apply_mask(masks)
print(f"Sparsity after pruning: {pruner.get_sparsity():.2%}")
# 枝刈り後のパラメータ分布
weights_after = []
for name, param in model.named_parameters():
if 'weight' in name and param.dim() > 1:
weights_after.extend(param.data.view(-1).numpy())
# 可視化
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].hist(weights_before, bins=100, alpha=0.7, color='blue')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Count')
axes[0].set_title('Weight Distribution (Before Pruning)')
axes[0].set_yscale('log')
axes[1].hist(weights_after, bins=100, alpha=0.7, color='red')
axes[1].set_xlabel('Weight Value')
axes[1].set_ylabel('Count')
axes[1].set_title('Weight Distribution (After 50% Pruning)')
axes[1].set_yscale('log')
plt.tight_layout()
plt.show()
反復的枝刈り + Fine-tuning
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
def create_synthetic_dataset(n_samples=5000, input_dim=100, num_classes=10):
"""合成データセットの作成"""
np.random.seed(42)
X = np.random.randn(n_samples, input_dim).astype(np.float32)
W_true = np.random.randn(input_dim, num_classes).astype(np.float32)
logits = X @ W_true + np.random.randn(n_samples, num_classes).astype(np.float32) * 0.5
y = np.argmax(logits, axis=1)
split = int(0.8 * n_samples)
return (X[:split], y[:split]), (X[split:], y[split:])
class PrunableNet(nn.Module):
"""枝刈り対応ネットワーク"""
def __init__(self, input_dim=100, hidden_dims=[256, 128], num_classes=10):
super().__init__()
self.layers = nn.ModuleList()
prev_dim = input_dim
for hidden_dim in hidden_dims:
self.layers.append(nn.Linear(prev_dim, hidden_dim))
prev_dim = hidden_dim
self.output = nn.Linear(prev_dim, num_classes)
# マスクを保持
self.masks = {}
def forward(self, x):
for layer in self.layers:
x = F.relu(layer(x))
return self.output(x)
def apply_masks(self):
"""保持しているマスクを適用"""
with torch.no_grad():
for name, param in self.named_parameters():
if name in self.masks:
param.data *= self.masks[name]
def train_epoch(model, train_loader, criterion, optimizer):
"""1エポックの学習"""
model.train()
total_loss = 0
for X_batch, y_batch in train_loader:
optimizer.zero_grad()
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
loss.backward()
optimizer.step()
# 枝刈り後はマスクを再適用(勾配更新で復活しないように)
model.apply_masks()
total_loss += loss.item()
return total_loss / len(train_loader)
def evaluate(model, test_loader):
"""評価"""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for X_batch, y_batch in test_loader:
outputs = model(X_batch)
_, predicted = torch.max(outputs, 1)
total += y_batch.size(0)
correct += (predicted == y_batch).sum().item()
return correct / total
def iterative_pruning(model, train_loader, test_loader,
target_sparsity=0.9, pruning_steps=10, finetune_epochs=5):
"""反復的枝刈り"""
criterion = nn.CrossEntropyLoss()
# 初期精度
initial_acc = evaluate(model, test_loader)
print(f"Initial Accuracy: {initial_acc:.4f}")
sparsities = []
accuracies = []
# 各ステップでの枝刈り率
current_sparsity = 0.0
step_sparsity = 1.0 - (1.0 - target_sparsity) ** (1.0 / pruning_steps)
for step in range(pruning_steps):
# 累積スパース率の計算
current_sparsity = 1.0 - (1.0 - step_sparsity) ** (step + 1)
# マスクの計算と適用
pruner = MagnitudePruner(model)
masks = pruner.compute_mask(current_sparsity)
# モデルにマスクを保存
model.masks = masks
model.apply_masks()
actual_sparsity = pruner.get_sparsity()
print(f"\nStep {step + 1}/{pruning_steps}: Target sparsity = {current_sparsity:.2%}, Actual = {actual_sparsity:.2%}")
# Fine-tuning
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(finetune_epochs):
loss = train_epoch(model, train_loader, criterion, optimizer)
acc = evaluate(model, test_loader)
print(f" Accuracy after fine-tuning: {acc:.4f}")
sparsities.append(actual_sparsity)
accuracies.append(acc)
return sparsities, accuracies, initial_acc
# データ準備
(X_train, y_train), (X_test, y_test) = create_synthetic_dataset()
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# モデルの学習
print("Training initial model...")
model = PrunableNet(input_dim=100, hidden_dims=[256, 128], num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(20):
loss = train_epoch(model, train_loader, criterion, optimizer)
if (epoch + 1) % 5 == 0:
acc = evaluate(model, test_loader)
print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}")
# 反復的枝刈り
print("\n" + "="*50)
print("Starting Iterative Pruning...")
print("="*50)
sparsities, accuracies, initial_acc = iterative_pruning(
model, train_loader, test_loader,
target_sparsity=0.9, pruning_steps=5, finetune_epochs=10
)
# 結果の可視化
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot([0] + sparsities, [initial_acc] + accuracies, 'o-', linewidth=2, markersize=8)
plt.xlabel('Sparsity')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Sparsity (Iterative Pruning)')
plt.grid(True, alpha=0.3)
plt.xlim(-0.05, 1.0)
plt.subplot(1, 2, 2)
remaining_params = [1 - s for s in [0] + sparsities]
plt.bar(range(len(remaining_params)), remaining_params, alpha=0.7)
plt.xlabel('Pruning Step')
plt.ylabel('Remaining Parameters (%)')
plt.title('Parameter Reduction')
plt.xticks(range(len(remaining_params)), ['Init'] + [f'Step {i+1}' for i in range(len(sparsities))])
plt.tight_layout()
plt.show()
構造化枝刈り(フィルタ単位)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class StructuredPruner:
"""構造化枝刈り(フィルタ/チャネル単位)"""
def __init__(self, model):
self.model = model
def compute_filter_importance(self, conv_layer):
"""
フィルタの重要度を計算(L1ノルムベース)
Args:
conv_layer: 畳み込み層
Returns:
importance: 各フィルタの重要度スコア
"""
# 重み形状: (out_channels, in_channels, H, W)
weight = conv_layer.weight.data
# 各フィルタのL1ノルム
importance = weight.abs().sum(dim=(1, 2, 3))
return importance
def prune_conv_layer(self, conv_layer, bn_layer, prune_ratio):
"""
畳み込み層を構造化枝刈り
Args:
conv_layer: 畳み込み層
bn_layer: バッチ正規化層(存在する場合)
prune_ratio: 枝刈りするフィルタの割合
Returns:
indices_to_keep: 保持するフィルタのインデックス
"""
importance = self.compute_filter_importance(conv_layer)
n_filters = len(importance)
n_to_prune = int(n_filters * prune_ratio)
n_to_keep = n_filters - n_to_prune
# 重要度の高いフィルタを保持
_, indices_to_keep = torch.topk(importance, n_to_keep)
indices_to_keep = indices_to_keep.sort().values
return indices_to_keep
class SimpleCNN(nn.Module):
"""構造化枝刈りのデモ用CNN"""
def __init__(self, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(64, num_classes)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = self.pool(x)
x = x.view(x.size(0), -1)
return self.fc(x)
# 構造化枝刈りのデモ
torch.manual_seed(42)
model = SimpleCNN()
print("Before Pruning:")
print(f" conv1: {model.conv1.weight.shape}")
print(f" conv2: {model.conv2.weight.shape}")
pruner = StructuredPruner(model)
# conv1のフィルタ重要度
importance = pruner.compute_filter_importance(model.conv1)
print(f"\nconv1 filter importance (first 10): {importance[:10].numpy().round(3)}")
# 上位50%のフィルタを保持
indices = pruner.prune_conv_layer(model.conv1, model.bn1, prune_ratio=0.5)
print(f"\nFilters to keep: {indices.numpy()}")
print(f"Keeping {len(indices)} out of {model.conv1.out_channels} filters")
枝刈りの実践的なヒント
1. スパース率の選択
| タスク | 推奨スパース率 |
|---|---|
| 画像分類(軽い) | 50-70% |
| 画像分類(重い) | 70-90% |
| 物体検出 | 50-80% |
| 言語モデル | 50-90% |
2. 枝刈りの順序
一般的に、以下の層は枝刈りしにくい: – 最初の層(入力に近い) – 最後の層(出力に近い) – バッチ正規化層
3. Fine-tuning の重要性
枝刈り後のFine-tuningは精度回復に不可欠です。学習率は元の1/10程度から開始するのが一般的です。
まとめ
本記事では、モデル枝刈り(Pruning)について解説しました。
- 枝刈りは不要なパラメータを削除してモデルを軽量化する技術
- 非構造化枝刈りは高圧縮率、構造化枝刈りはハードウェア効率が高い
- Magnitude-based Pruningは最もシンプルで効果的な手法の一つ
- 反復的枝刈り + Fine-tuning で精度を維持しながら高い圧縮率を達成
次のステップとして、以下の記事も参考にしてください。