Mixture of Experts (MoE) の仕組みとゲーティング機構

Mixture of Experts(MoE)は、ニューラルネットワークの効率を大幅に向上させるアーキテクチャです。モデルに複数の「エキスパート」ネットワークを持たせ、入力に応じて一部のエキスパートのみを活性化することで、パラメータ数を増やしながら計算コストを抑えます。

近年、Mixtral 8x7B や GPT-4 など、最先端のLLMでMoEが採用されています。本記事では、MoEの原理、ゲーティング機構、学習の課題と解決策、実装の考え方を解説します。

本記事の内容

  • MoEの基本概念と動機
  • ゲーティング機構とルーティング
  • 負荷分散の課題と対策
  • Mixtralのアーキテクチャ
  • PyTorchでの実装の考え方

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

MoEの基本概念

動機:パラメータ数と計算量のトレードオフ

通常の密(dense)モデルでは、パラメータ数を増やすと計算量も比例して増加します。

$$ \text{計算量} \propto \text{パラメータ数} $$

しかし、モデルの性能はパラメータ数に依存することが知られています(スケーリング則)。より大きなモデルを効率的に作れないでしょうか?

MoEの解決策

MoEは、疎な活性化(sparse activation)により、パラメータ数と計算量を分離します。

$$ \text{総パラメータ数} = N \times \text{エキスパートあたりのパラメータ数} $$

$$ \text{推論時の計算量} \propto k \times \text{エキスパートあたりのパラメータ数} $$

ここで $N$ は総エキスパート数、$k$ は活性化されるエキスパート数($k \ll N$)です。

例えば、8つのエキスパートがあり、各入力で2つだけ活性化する場合: – パラメータ数: 8倍 – 計算量: 2倍

MoEの構造

TransformerにおけるMoEは、各レイヤーのFFN(Feed-Forward Network)を複数のエキスパートに置き換えます。

標準Transformer:
    Self-Attention → FFN → Output

MoE Transformer:
    Self-Attention → Router → Expert 1, 2, ..., N → Combine → Output

ゲーティング機構

ソフトマックスゲーティング

最も基本的なゲーティングは、線形変換 + ソフトマックスです。

入力 $\bm{x} \in \mathbb{R}^d$ に対して:

$$ \bm{g} = \text{softmax}(\bm{W}_g \bm{x}) $$

ここで $\bm{W}_g \in \mathbb{R}^{N \times d}$ はゲーティング重み、$\bm{g} \in \mathbb{R}^N$ は各エキスパートへのルーティング確率です。

出力: $$ \bm{y} = \sum_{i=1}^{N} g_i \cdot E_i(\bm{x}) $$

ここで $E_i$ は $i$ 番目のエキスパートネットワークです。

Top-kゲーティング

すべてのエキスパートを使うと計算量が増えるため、上位 $k$ 個のエキスパートのみを活性化します。

$$ \text{TopK}(\bm{g}, k) = \{i : g_i \text{ is among the top-}k \text{ values}\} $$

$$ \bm{y} = \sum_{i \in \text{TopK}(\bm{g}, k)} \tilde{g}_i \cdot E_i(\bm{x}) $$

ここで $\tilde{g}_i$ は選択されたエキスパート間で再正規化されたゲート値:

$$ \tilde{g}_i = \frac{g_i}{\sum_{j \in \text{TopK}(\bm{g}, k)} g_j} $$

ノイズ付きTop-k(Noisy Top-k)

Shazeer et al. (2017) は、ゲーティングにノイズを追加することで探索を促進する手法を提案しました。

$$ \bm{h} = \bm{W}_g \bm{x} + \text{StandardNormal}() \cdot \text{Softplus}(\bm{W}_{\text{noise}} \bm{x}) $$

$$ \bm{g} = \text{softmax}(\text{KeepTopK}(\bm{h}, k)) $$

ノイズにより、学習中に様々なエキスパートが試されます。

負荷分散

問題:エキスパートの崩壊

ナイーブなMoEでは、一部のエキスパートに入力が集中し、他のエキスパートが使われなくなる崩壊(collapse)が発生します。

原因: 1. ゲーティングは微分可能なので、最初に強くなったエキスパートがさらに強化される 2. 使われないエキスパートは学習されず、さらに使われなくなる悪循環

解決策1:補助損失(Auxiliary Loss)

負荷分散を促進するための追加の損失関数を導入します。

重要度損失(Importance Loss):

各エキスパートへの総ルーティング重みのばらつきを抑制:

$$ L_{\text{importance}} = \text{CV}\left(\sum_{x \in \text{batch}} \bm{g}(x)\right)^2 $$

ここで CV は変動係数(coefficient of variation)= 標準偏差 / 平均 です。

負荷損失(Load Loss):

各エキスパートに実際にルーティングされるトークン数のばらつきを抑制:

$$ L_{\text{load}} = N \cdot \sum_{i=1}^{N} f_i \cdot P_i $$

ここで: – $f_i$: エキスパート $i$ に送られるトークンの割合 – $P_i$: エキスパート $i$ のゲート確率の平均

理想的には $f_i = P_i = 1/N$ です。

解決策2:Expert Choice(EC)

Zhou et al. (2022) は、トークンがエキスパートを選ぶのではなく、エキスパートがトークンを選ぶ手法を提案しました。

各エキスパートは、バッチ内から固定数のトークンを選びます:

  1. ゲートスコア $\bm{S} = \bm{W}_g \bm{X}^\top \in \mathbb{R}^{N \times B}$ を計算
  2. 各エキスパート(行)について、上位 $C$ 個のトークンを選択
  3. 選択されたトークンに対してエキスパートを適用

$C$ はキャパシティ(各エキスパートが処理するトークン数)です。

利点: – 各エキスパートの負荷が自動的に均等化 – ドロップされるトークンがない

Mixtralのアーキテクチャ

概要

Mixtral 8x7B は、Mistral AI が開発したMoEモデルです。

  • 総パラメータ数: 約47B
  • 活性化パラメータ数: 約13B(推論時)
  • エキスパート数: 8
  • 活性化エキスパート数: 2
  • ベース: Mistral 7B と同じアーキテクチャ

構造

各Transformerレイヤー: – Self-Attention: 標準(非MoE) – FFN: MoE(8エキスパート、Top-2ルーティング)

Input → RMSNorm → Self-Attention → Residual
      → RMSNorm → MoE(8 experts, Top-2) → Residual → Output

ルーターの設計

Mixtralのルーター:

$$ \bm{g} = \text{softmax}(\bm{W}_g \bm{x}) $$

$$ \text{output} = \sum_{i \in \text{Top2}(\bm{g})} g_i \cdot E_i(\bm{x}) $$

シンプルなソフトマックス + Top-2 選択で、ノイズは使用していません。

PyTorchでの実装の考え方

MoEレイヤーの実装

import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """単一のエキスパート(FFN)"""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff, bias=False)
        self.fc2 = nn.Linear(d_ff, d_model, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class MoELayer(nn.Module):
    """Mixture of Expertsレイヤー"""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int,
        top_k: int = 2,
        aux_loss_weight: float = 0.01,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        self.aux_loss_weight = aux_loss_weight

        # エキスパートネットワーク
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_experts)
        ])

        # ルーター(ゲーティングネットワーク)
        self.router = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: (batch_size, seq_len, d_model)

        Returns:
            output: (batch_size, seq_len, d_model)
            aux_loss: 補助損失(負荷分散用)
        """
        batch_size, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)  # (batch * seq, d_model)
        num_tokens = x_flat.shape[0]

        # ルーターでゲート確率を計算
        router_logits = self.router(x_flat)  # (num_tokens, num_experts)
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-k 選択
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        # (num_tokens, top_k)

        # 再正規化
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # 出力の初期化
        output = torch.zeros_like(x_flat)

        # 各エキスパートを適用
        for i in range(self.num_experts):
            # このエキスパートに送られるトークンを特定
            # どのトークンのどの位置でこのエキスパートが選ばれたか
            expert_mask = (top_k_indices == i)  # (num_tokens, top_k)

            if not expert_mask.any():
                continue

            # マスクされた位置のトークンと重みを取得
            token_indices = expert_mask.any(dim=-1).nonzero(as_tuple=True)[0]
            expert_input = x_flat[token_indices]

            # エキスパートを適用
            expert_output = self.experts[i](expert_input)

            # 重みを計算
            weights = (top_k_probs * expert_mask.float()).sum(dim=-1)
            weights = weights[token_indices].unsqueeze(-1)

            # 出力に加算
            output[token_indices] += weights * expert_output

        output = output.view(batch_size, seq_len, d_model)

        # 補助損失(負荷分散)
        aux_loss = self._compute_aux_loss(router_probs, top_k_indices)

        return output, aux_loss

    def _compute_aux_loss(
        self,
        router_probs: torch.Tensor,
        top_k_indices: torch.Tensor,
    ) -> torch.Tensor:
        """負荷分散のための補助損失を計算"""
        num_tokens = router_probs.shape[0]

        # 各エキスパートの平均ルーティング確率
        mean_probs = router_probs.mean(dim=0)  # (num_experts,)

        # 各エキスパートに実際にルーティングされる頻度
        one_hot = F.one_hot(top_k_indices, num_classes=self.num_experts).float()
        # (num_tokens, top_k, num_experts)
        tokens_per_expert = one_hot.sum(dim=(0, 1)) / (num_tokens * self.top_k)
        # (num_experts,)

        # 損失: 確率と頻度の積の和(均等な場合に最小)
        aux_loss = self.num_experts * (mean_probs * tokens_per_expert).sum()

        return aux_loss * self.aux_loss_weight


# 使用例
d_model = 256
d_ff = 512
num_experts = 8
top_k = 2

moe = MoELayer(d_model, d_ff, num_experts, top_k)

x = torch.randn(2, 10, d_model)  # (batch=2, seq=10, d_model)
output, aux_loss = moe(x)

print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")
print(f"補助損失: {aux_loss.item():.6f}")

効率的な実装(バッチ処理)

上記の実装は教育的ですが、非効率です。実際には、各エキスパートへのトークンをバッチ処理します。

class EfficientMoELayer(nn.Module):
    """より効率的なMoE実装"""

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int,
        top_k: int = 2,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k

        # 全エキスパートの重みを1つのテンソルに
        self.w1 = nn.Parameter(torch.randn(num_experts, d_model, d_ff) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, d_ff, d_model) * 0.02)

        self.router = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape
        x_flat = x.view(-1, d_model)

        # ルーティング
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # 各トップエキスパートの出力を計算
        # これはまだ最適化の余地がありますが、概念的な実装として
        outputs = []
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]  # (num_tokens,)
            weights = top_k_probs[:, k]  # (num_tokens,)

            # 各エキスパートの出力を集める
            expert_out = torch.zeros_like(x_flat)
            for e in range(self.num_experts):
                mask = (expert_indices == e)
                if mask.any():
                    exp_input = x_flat[mask]
                    # FFN: SiLU(x @ w1) @ w2
                    h = F.silu(exp_input @ self.w1[e])
                    out = h @ self.w2[e]
                    expert_out[mask] = out

            outputs.append(weights.unsqueeze(-1) * expert_out)

        output = sum(outputs)
        return output.view(batch_size, seq_len, d_model)

可視化

エキスパートの選択パターン

import matplotlib.pyplot as plt
import numpy as np

# シミュレーションデータ
np.random.seed(42)
num_tokens = 100
num_experts = 8
top_k = 2

# ルーター出力をシミュレート
router_logits = np.random.randn(num_tokens, num_experts)
router_probs = np.exp(router_logits) / np.exp(router_logits).sum(axis=1, keepdims=True)

# Top-2 選択
top_k_indices = np.argsort(-router_probs, axis=1)[:, :top_k]

# 各エキスパートの選択回数をカウント
expert_counts = np.zeros(num_experts)
for indices in top_k_indices:
    for idx in indices:
        expert_counts[idx] += 1

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# エキスパートの選択分布
ax = axes[0]
ax.bar(range(num_experts), expert_counts / num_tokens * 100 / top_k, color='steelblue', alpha=0.7)
ax.axhline(y=100 / num_experts, color='red', linestyle='--', label='Ideal uniform')
ax.set_xlabel('Expert Index', fontsize=12)
ax.set_ylabel('Selection Frequency (%)', fontsize=12)
ax.set_title('Expert Selection Distribution', fontsize=14)
ax.legend()
ax.set_xticks(range(num_experts))

# ルーティング確率のヒートマップ
ax = axes[1]
im = ax.imshow(router_probs[:20], cmap='Blues', aspect='auto')
ax.set_xlabel('Expert Index', fontsize=12)
ax.set_ylabel('Token Index', fontsize=12)
ax.set_title('Router Probabilities (first 20 tokens)', fontsize=14)
ax.set_xticks(range(num_experts))
plt.colorbar(im, ax=ax, label='Probability')

plt.tight_layout()
plt.show()

MoE vs Dense モデルの比較

import matplotlib.pyplot as plt
import numpy as np

# モデルサイズと計算量の関係
params_dense = np.array([7, 13, 30, 65, 180])  # billion
flops_dense = params_dense  # 密モデルはパラメータ数に比例

# MoE: 8エキスパート、Top-2活性化
num_experts = 8
top_k = 2
# 共有部分(Attention等)が約半分、MoE部分が半分と仮定
shared_fraction = 0.5
params_moe = params_dense * (shared_fraction + (1 - shared_fraction) * num_experts)
flops_moe = params_dense * (shared_fraction + (1 - shared_fraction) * top_k)

fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(params_dense, flops_dense, s=150, marker='o', color='steelblue',
           label='Dense Model', zorder=3)
ax.scatter(params_moe, flops_moe, s=150, marker='s', color='coral',
           label='MoE Model (8 experts, Top-2)', zorder=3)

# 線で結ぶ
for pd, fd, pm, fm in zip(params_dense, flops_dense, params_moe, flops_moe):
    ax.plot([pd, pm], [fd, fm], 'k--', alpha=0.3)

# 注釈
for pd, fd in zip(params_dense, flops_dense):
    ax.annotate(f'{pd}B', (pd, fd), textcoords="offset points",
                xytext=(0, 10), ha='center', fontsize=9)
for pm, fm in zip(params_moe, flops_moe):
    ax.annotate(f'{pm:.0f}B\nparams', (pm, fm), textcoords="offset points",
                xytext=(15, -10), ha='left', fontsize=8, color='coral')

ax.set_xlabel('Total Parameters (Billions)', fontsize=12)
ax.set_ylabel('Compute (relative to 7B dense)', fontsize=12)
ax.set_title('MoE: More Parameters, Less Compute', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xlim(0, 350)
ax.set_ylim(0, 200)

plt.tight_layout()
plt.show()

MoEの課題と今後

課題

  1. 負荷分散の難しさ: 補助損失のチューニングが必要
  2. 通信コスト: 分散学習時、トークンを異なるGPUのエキスパートに送る必要がある
  3. バッチ効率: エキスパートごとのバッチサイズが不均一
  4. メモリ: 総パラメータ数が多く、推論時もすべてのエキスパートをメモリに載せる必要がある

最近の発展

  1. Expert Choice: エキスパートがトークンを選ぶことで負荷分散を改善
  2. Switch Transformer: Top-1 ルーティングで効率化
  3. Soft MoE: 離散的なルーティングを連続化
  4. MoE + 量子化: メモリ効率の改善

まとめ

本記事では、Mixture of Experts(MoE)のアーキテクチャについて解説しました。

  • 基本概念: 複数のエキスパートから一部のみを活性化し、パラメータ数と計算量を分離
  • ゲーティング機構: ソフトマックス + Top-k 選択で、入力に応じたエキスパートを選択
  • 負荷分散: 補助損失やExpert Choice手法で、エキスパートの崩壊を防ぐ
  • Mixtral: 8エキスパート、Top-2の実用的なMoEモデル
  • トレードオフ: より多くのパラメータ、効率的な計算、だが負荷分散や通信の課題あり

MoEは、大規模モデルを効率的に学習・推論するための有力なアーキテクチャです。今後も、より効率的なルーティングや負荷分散手法の研究が進むことが期待されます。

次のステップとして、以下の記事も参考にしてください。