【モデル圧縮】Pruning(枝刈り)の理論と実装

モデル枝刈り(Pruning)は、ニューラルネットワークの不要なパラメータを削除してモデルを軽量化する技術です。推論速度の向上とメモリ使用量の削減を実現しながら、精度の低下を最小限に抑えることができます。

本記事では、枝刈りの理論的背景から実装まで詳しく解説します。

本記事の内容

  • 枝刈りの基本概念と分類
  • 重要度スコアの計算方法
  • PyTorchでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

枝刈りとは

基本的なアイデア

ニューラルネットワークには、予測性能にほとんど寄与しない「冗長な」パラメータが多数存在します。枝刈りは、これらの冗長なパラメータを特定して削除(ゼロに設定)することで、モデルを効率化します。

宝くじ仮説(Lottery Ticket Hypothesis)

2019年に発表された「宝くじ仮説」によると:

密なニューラルネットワークには、独立に学習させても元のネットワークと同等の性能を達成できる「当たりくじ」(サブネットワーク)が存在する

この発見は、枝刈りの理論的な裏付けを与えました。

枝刈りの分類

1. 粒度による分類

種類 説明 利点 欠点
非構造化(Unstructured) 個々の重みを削除 高い圧縮率 ハードウェア効率が低い
構造化(Structured) フィルタ/チャネル単位で削除 ハードウェア効率が高い 圧縮率は低い

2. タイミングによる分類

種類 説明
学習後枝刈り 学習済みモデルを枝刈り → 再学習
学習中枝刈り 学習しながら徐々に枝刈り
学習前枝刈り 初期化時に枝刈り(宝くじ仮説)

重要度スコアの計算

枝刈りの核心は、「どのパラメータを削除するか」を決める重要度スコアの計算です。

Magnitude-based Pruning

最もシンプルな手法は、重みの絶対値を重要度とする方法です:

$$ s_i = |w_i| $$

スコアが閾値 $\theta$ 以下の重みを削除:

$$ w_i’ = \begin{cases} w_i & \text{if } |w_i| > \theta \\ 0 & \text{otherwise} \end{cases} $$

Gradient-based Pruning

勾配情報を使う手法もあります。重みと勾配の積を重要度とします:

$$ s_i = \left| w_i \cdot \frac{\partial \mathcal{L}}{\partial w_i} \right| $$

これは、その重みがどれだけ損失に影響するかを表します。

Taylor Expansion based Pruning

損失関数のテイラー展開に基づく手法です。重み $w_i$ を0にしたときの損失変化を近似:

$$ \Delta \mathcal{L}_i = \mathcal{L}(w_i = 0) – \mathcal{L}(w_i) $$

1次のテイラー展開で近似すると:

$$ \Delta \mathcal{L}_i \approx -w_i \cdot \frac{\partial \mathcal{L}}{\partial w_i} $$

損失増加を最小化したいので、$|\Delta \mathcal{L}_i|$ が小さい重みを削除します。

Hessian-based Pruning

2次の情報(ヘッシアン)を使う手法もあります。Optimal Brain Damage(OBD)では:

$$ \Delta \mathcal{L}_i \approx \frac{1}{2} w_i^2 \cdot \frac{\partial^2 \mathcal{L}}{\partial w_i^2} $$

ヘッシアンの対角成分と重みの2乗の積で重要度を評価します。

枝刈りのアルゴリズム

基本的な枝刈りの流れ

  1. モデルを学習(または学習済みモデルを読み込み)
  2. 重要度スコアを計算
  3. スコアが低いパラメータをゼロに設定
  4. (オプション)Fine-tuning で精度を回復
  5. 必要に応じて2-4を繰り返す

反復的枝刈り(Iterative Pruning)

一度に大量のパラメータを削除すると精度が大きく低下します。反復的枝刈りでは:

  1. 少量(例: 10%)を枝刈り
  2. Fine-tuning
  3. 目標の圧縮率に達するまで繰り返し

Pythonでの実装

Magnitude-based Pruning の実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class MagnitudePruner:
    """Magnitude-based Pruning の実装"""

    def __init__(self, model):
        self.model = model

    def compute_mask(self, sparsity):
        """
        指定されたスパース率に基づいてマスクを計算

        Args:
            sparsity: 削除する重みの割合 (0-1)
        Returns:
            masks: 各層のマスク辞書
        """
        masks = {}

        # 全ての重みを収集
        all_weights = []
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                all_weights.append(param.data.abs().view(-1))

        # 全重みを結合して閾値を計算
        all_weights = torch.cat(all_weights)
        threshold = torch.quantile(all_weights, sparsity)

        # 各層のマスクを作成
        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                mask = (param.data.abs() > threshold).float()
                masks[name] = mask

        return masks

    def apply_mask(self, masks):
        """マスクを適用して重みを枝刈り"""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in masks:
                    param.data *= masks[name]

    def get_sparsity(self):
        """現在のスパース率を計算"""
        total_params = 0
        zero_params = 0

        for name, param in self.model.named_parameters():
            if 'weight' in name and param.dim() > 1:
                total_params += param.numel()
                zero_params += (param.data == 0).sum().item()

        return zero_params / total_params if total_params > 0 else 0

# 動作確認用のシンプルなネットワーク
class SimpleNet(nn.Module):
    def __init__(self, input_dim=784, hidden_dims=[512, 256], num_classes=10):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# 枝刈りのデモ
torch.manual_seed(42)
model = SimpleNet()

# 枝刈り前のパラメータ分布
weights_before = []
for name, param in model.named_parameters():
    if 'weight' in name and param.dim() > 1:
        weights_before.extend(param.data.view(-1).numpy())

pruner = MagnitudePruner(model)

# 50%のスパース率で枝刈り
masks = pruner.compute_mask(sparsity=0.5)
pruner.apply_mask(masks)

print(f"Sparsity after pruning: {pruner.get_sparsity():.2%}")

# 枝刈り後のパラメータ分布
weights_after = []
for name, param in model.named_parameters():
    if 'weight' in name and param.dim() > 1:
        weights_after.extend(param.data.view(-1).numpy())

# 可視化
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(weights_before, bins=100, alpha=0.7, color='blue')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Count')
axes[0].set_title('Weight Distribution (Before Pruning)')
axes[0].set_yscale('log')

axes[1].hist(weights_after, bins=100, alpha=0.7, color='red')
axes[1].set_xlabel('Weight Value')
axes[1].set_ylabel('Count')
axes[1].set_title('Weight Distribution (After 50% Pruning)')
axes[1].set_yscale('log')

plt.tight_layout()
plt.show()

反復的枝刈り + Fine-tuning

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

def create_synthetic_dataset(n_samples=5000, input_dim=100, num_classes=10):
    """合成データセットの作成"""
    np.random.seed(42)

    X = np.random.randn(n_samples, input_dim).astype(np.float32)
    W_true = np.random.randn(input_dim, num_classes).astype(np.float32)
    logits = X @ W_true + np.random.randn(n_samples, num_classes).astype(np.float32) * 0.5
    y = np.argmax(logits, axis=1)

    split = int(0.8 * n_samples)
    return (X[:split], y[:split]), (X[split:], y[split:])

class PrunableNet(nn.Module):
    """枝刈り対応ネットワーク"""

    def __init__(self, input_dim=100, hidden_dims=[256, 128], num_classes=10):
        super().__init__()
        self.layers = nn.ModuleList()

        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Linear(prev_dim, hidden_dim))
            prev_dim = hidden_dim
        self.output = nn.Linear(prev_dim, num_classes)

        # マスクを保持
        self.masks = {}

    def forward(self, x):
        for layer in self.layers:
            x = F.relu(layer(x))
        return self.output(x)

    def apply_masks(self):
        """保持しているマスクを適用"""
        with torch.no_grad():
            for name, param in self.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

def train_epoch(model, train_loader, criterion, optimizer):
    """1エポックの学習"""
    model.train()
    total_loss = 0

    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        # 枝刈り後はマスクを再適用(勾配更新で復活しないように)
        model.apply_masks()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def evaluate(model, test_loader):
    """評価"""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch)
            _, predicted = torch.max(outputs, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

    return correct / total

def iterative_pruning(model, train_loader, test_loader,
                       target_sparsity=0.9, pruning_steps=10, finetune_epochs=5):
    """反復的枝刈り"""
    criterion = nn.CrossEntropyLoss()

    # 初期精度
    initial_acc = evaluate(model, test_loader)
    print(f"Initial Accuracy: {initial_acc:.4f}")

    sparsities = []
    accuracies = []

    # 各ステップでの枝刈り率
    current_sparsity = 0.0
    step_sparsity = 1.0 - (1.0 - target_sparsity) ** (1.0 / pruning_steps)

    for step in range(pruning_steps):
        # 累積スパース率の計算
        current_sparsity = 1.0 - (1.0 - step_sparsity) ** (step + 1)

        # マスクの計算と適用
        pruner = MagnitudePruner(model)
        masks = pruner.compute_mask(current_sparsity)

        # モデルにマスクを保存
        model.masks = masks
        model.apply_masks()

        actual_sparsity = pruner.get_sparsity()
        print(f"\nStep {step + 1}/{pruning_steps}: Target sparsity = {current_sparsity:.2%}, Actual = {actual_sparsity:.2%}")

        # Fine-tuning
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        for epoch in range(finetune_epochs):
            loss = train_epoch(model, train_loader, criterion, optimizer)

        acc = evaluate(model, test_loader)
        print(f"  Accuracy after fine-tuning: {acc:.4f}")

        sparsities.append(actual_sparsity)
        accuracies.append(acc)

    return sparsities, accuracies, initial_acc

# データ準備
(X_train, y_train), (X_test, y_test) = create_synthetic_dataset()

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# モデルの学習
print("Training initial model...")
model = PrunableNet(input_dim=100, hidden_dims=[256, 128], num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(20):
    loss = train_epoch(model, train_loader, criterion, optimizer)
    if (epoch + 1) % 5 == 0:
        acc = evaluate(model, test_loader)
        print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}")

# 反復的枝刈り
print("\n" + "="*50)
print("Starting Iterative Pruning...")
print("="*50)

sparsities, accuracies, initial_acc = iterative_pruning(
    model, train_loader, test_loader,
    target_sparsity=0.9, pruning_steps=5, finetune_epochs=10
)

# 結果の可視化
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot([0] + sparsities, [initial_acc] + accuracies, 'o-', linewidth=2, markersize=8)
plt.xlabel('Sparsity')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Sparsity (Iterative Pruning)')
plt.grid(True, alpha=0.3)
plt.xlim(-0.05, 1.0)

plt.subplot(1, 2, 2)
remaining_params = [1 - s for s in [0] + sparsities]
plt.bar(range(len(remaining_params)), remaining_params, alpha=0.7)
plt.xlabel('Pruning Step')
plt.ylabel('Remaining Parameters (%)')
plt.title('Parameter Reduction')
plt.xticks(range(len(remaining_params)), ['Init'] + [f'Step {i+1}' for i in range(len(sparsities))])

plt.tight_layout()
plt.show()

構造化枝刈り(フィルタ単位)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class StructuredPruner:
    """構造化枝刈り(フィルタ/チャネル単位)"""

    def __init__(self, model):
        self.model = model

    def compute_filter_importance(self, conv_layer):
        """
        フィルタの重要度を計算(L1ノルムベース)

        Args:
            conv_layer: 畳み込み層
        Returns:
            importance: 各フィルタの重要度スコア
        """
        # 重み形状: (out_channels, in_channels, H, W)
        weight = conv_layer.weight.data

        # 各フィルタのL1ノルム
        importance = weight.abs().sum(dim=(1, 2, 3))

        return importance

    def prune_conv_layer(self, conv_layer, bn_layer, prune_ratio):
        """
        畳み込み層を構造化枝刈り

        Args:
            conv_layer: 畳み込み層
            bn_layer: バッチ正規化層(存在する場合)
            prune_ratio: 枝刈りするフィルタの割合
        Returns:
            indices_to_keep: 保持するフィルタのインデックス
        """
        importance = self.compute_filter_importance(conv_layer)
        n_filters = len(importance)
        n_to_prune = int(n_filters * prune_ratio)
        n_to_keep = n_filters - n_to_prune

        # 重要度の高いフィルタを保持
        _, indices_to_keep = torch.topk(importance, n_to_keep)
        indices_to_keep = indices_to_keep.sort().values

        return indices_to_keep

class SimpleCNN(nn.Module):
    """構造化枝刈りのデモ用CNN"""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# 構造化枝刈りのデモ
torch.manual_seed(42)
model = SimpleCNN()

print("Before Pruning:")
print(f"  conv1: {model.conv1.weight.shape}")
print(f"  conv2: {model.conv2.weight.shape}")

pruner = StructuredPruner(model)

# conv1のフィルタ重要度
importance = pruner.compute_filter_importance(model.conv1)
print(f"\nconv1 filter importance (first 10): {importance[:10].numpy().round(3)}")

# 上位50%のフィルタを保持
indices = pruner.prune_conv_layer(model.conv1, model.bn1, prune_ratio=0.5)
print(f"\nFilters to keep: {indices.numpy()}")
print(f"Keeping {len(indices)} out of {model.conv1.out_channels} filters")

枝刈りの実践的なヒント

1. スパース率の選択

タスク 推奨スパース率
画像分類(軽い) 50-70%
画像分類(重い) 70-90%
物体検出 50-80%
言語モデル 50-90%

2. 枝刈りの順序

一般的に、以下の層は枝刈りしにくい: – 最初の層(入力に近い) – 最後の層(出力に近い) – バッチ正規化層

3. Fine-tuning の重要性

枝刈り後のFine-tuningは精度回復に不可欠です。学習率は元の1/10程度から開始するのが一般的です。

まとめ

本記事では、モデル枝刈り(Pruning)について解説しました。

  • 枝刈りは不要なパラメータを削除してモデルを軽量化する技術
  • 非構造化枝刈りは高圧縮率、構造化枝刈りはハードウェア効率が高い
  • Magnitude-based Pruningは最もシンプルで効果的な手法の一つ
  • 反復的枝刈り + Fine-tuning で精度を維持しながら高い圧縮率を達成

次のステップとして、以下の記事も参考にしてください。