Mixture of Experts(MoE)の理論と実装 — 条件付き計算でLLMを効率化する

GPT-4やGeminiのような最高性能のLLMは、数千億〜数兆のパラメータを持つと言われています。しかし、全てのパラメータを全ての入力に対して使う必要はあるのでしょうか。「数学の質問には数学が得意なパラメータだけを使い、プログラミングの質問にはコードが得意なパラメータだけを使う」ことができれば、同じ性能をより少ない計算で達成できるはずです。

この発想を実現するのがMixture of Experts(MoE, 混合エキスパート)です。MoEは、TransformerのFFN(Feed-Forward Network)層を複数のエキスパートに分割し、各入力トークンに対して少数のエキスパートのみを選択的に活性化する仕組みです。

例えば、Mixtral 8x7Bは合計47Bのパラメータを持ちますが、各トークンの処理では8つのエキスパートのうち2つのみを使用するため、実効的な計算量は13B相当です。にもかかわらず、ベンチマークではLLaMA-2 70Bに匹敵する性能を達成しています。

MoEを理解することは、以下のような場面で直接役立ちます。

  • 最新LLMの理解: GPT-4、Gemini 1.5、Mixtral、DBRX、Grok-1など、最先端のLLMの多くがMoEアーキテクチャを採用しています
  • 推論コストの最適化: MoEモデルの推論は、同等性能のdenseモデルより数分の1の計算量で済みます
  • モデル設計の指針: いつMoEを使うべきか、エキスパート数やTop-kの選択など、設計判断の基盤になります

本記事の内容

  • 条件付き計算(Conditional Computation)の概念
  • MoEの基本構造とルーティング機構の数式
  • Top-kゲーティングとソフトマックスルーティング
  • 負荷分散損失(Load Balancing Loss)
  • Expert Capacity と Token Dropping
  • Mixtralの設計と工夫
  • Pythonによるスクラッチ実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

画像なし
LLaMAアーキテクチャの設計思想
MoEが適用されるTransformerの基本構造を理解します
画像なし
Self-Attention機構の理論と実装
MoEと組み合わされるAttention機構を理解します

条件付き計算

Dense vs Sparse

従来のTransformerはdenseモデルです。全ての入力トークンに対して全てのパラメータが計算に参加します。モデルサイズを大きくすると性能は向上しますが、計算量もパラメータ数に比例して増加します。

$$ \text{FLOPs}_{\text{dense}} \propto N_{\text{params}} $$

Sparseモデル(MoE)では、各入力トークンに対して一部のパラメータのみが活性化されます。モデルの総パラメータ数を大きくしても、1トークンあたりの計算量は制御可能です。

$$ \text{FLOPs}_{\text{MoE}} \propto \frac{k}{E} \times N_{\text{params}} $$

ここで $E$ はエキスパート数、$k$ はトークンあたりに選択されるエキスパート数です。$E = 8, k = 2$ の場合、1トークンあたりの計算量は dense の約 $\frac{2}{8} = 25\%$ です。

これにより、より多くのパラメータ(= より多くの知識を格納する容量)を持ちながら、推論コストを抑えることが可能になります。

MoEの歴史

MoEの概念は1991年のJacobs et al.に遡りますが、現代的なTransformer MoEの系譜は以下の通りです。

手法/モデル 主な貢献
2017 Shazeer et al. Transformer + MoE (Top-2 routing)
2021 Switch Transformer Top-1 routing, 大規模MoE
2022 ST-MoE 安定化技術のまとめ
2024 Mixtral 8x7B 初のオープンソース高品質MoE LLM
2024 DeepSeek-V2 Shared Expert + Fine-grained Expert

では、MoEの具体的な構造を見ていきましょう。

MoEの基本構造

TransformerのどこにMoEを入れるか

MoEは通常、TransformerブロックのFFN層を置き換えます。Attention層は全トークン間の関係を捉える必要があるためdenseのまま残し、FFN層だけをMoE化します。

標準的なTransformerブロック:

x → Attention → Add & Norm → FFN → Add & Norm → y

MoE Transformerブロック:

x → Attention → Add & Norm → MoE(FFN₁, FFN₂, ..., FFNₑ) → Add & Norm → y

ゲーティングネットワーク(ルーター)

MoE層の核心はゲーティングネットワーク(ルーター)です。入力トークン $\bm{x} \in \mathbb{R}^d$ に対して、どのエキスパートを使うかを決定します。

$$ \bm{g}(\bm{x}) = \text{softmax}(\bm{W}_g \bm{x}) \in \mathbb{R}^E $$

ここで $\bm{W}_g \in \mathbb{R}^{E \times d}$ はゲーティングの重み行列、$E$ はエキスパート数です。$\bm{g}(\bm{x})_i$ は「トークン $\bm{x}$ がエキスパート $i$ にルーティングされる確率」を表します。

Top-k ゲーティング

全てのエキスパートの出力を重み付き和で計算すると、計算量はdenseモデルと変わりません。そこで、上位 $k$ 個のエキスパートのみを選択します。

$$ \text{TopK}(\bm{g}(\bm{x})) = \begin{cases} g_i(\bm{x}) & \text{if } g_i(\bm{x}) \in \text{Top-}k \\ 0 & \text{otherwise} \end{cases} $$

正規化して重みの和を1にします:

$$ \tilde{g}_i(\bm{x}) = \frac{\text{TopK}(g_i(\bm{x}))}{\sum_{j} \text{TopK}(g_j(\bm{x}))} $$

MoE層の出力

MoE層の出力は、選択されたエキスパートの出力の重み付き和です:

$$ \text{MoE}(\bm{x}) = \sum_{i=1}^{E} \tilde{g}_i(\bm{x}) \cdot \text{FFN}_i(\bm{x}) $$

実際には $\tilde{g}_i(\bm{x}) = 0$ のエキスパートは計算をスキップするので、計算量は $k$ 個のエキスパートの分だけです。

具体例

Mixtral 8x7Bの場合: – $E = 8$(8個のエキスパート) – $k = 2$(各トークンで2個を選択) – 各エキスパートは LLaMA-7B の FFN と同じサイズ – 総パラメータ: 8 × 7B(FFN) + Attention等の共有部 ≈ 47B – 1トークンあたりの計算量: 2 × 7B(FFN) + Attention等 ≈ 13B 相当

しかし、ルーティングには根本的な問題が潜んでいます。全てのトークンが同じエキスパートに集中する「ルーティング崩壊」が起きやすいのです。次にこの問題とその解決策を見ましょう。

負荷分散の問題と解決

ルーティング崩壊

学習初期に、あるエキスパートが偶然他のエキスパートより良い重みを持つと、多くのトークンがそのエキスパートにルーティングされます。すると、そのエキスパートはさらに多くのデータで学習されて性能が上がり、ますます多くのトークンが集中します。この正のフィードバックループにより、最終的に少数のエキスパートだけが使われ、残りは死んだ状態になります。

これがルーティング崩壊(routing collapse)であり、MoEの最大の課題です。

負荷分散損失(Auxiliary Load Balancing Loss)

この問題を解決するために、負荷分散損失(auxiliary load balancing loss)を追加します。

各エキスパート $i$ について、以下の2つの量を定義します。

ディスパッチ率 $f_i$: バッチ内の全トークンのうち、エキスパート $i$ にルーティングされた割合:

$$ f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}[i \in \text{TopK}(\bm{g}(\bm{x}_t))] $$

ルーティング確率の平均 $p_i$: バッチ内の全トークンに対するゲート確率の平均:

$$ p_i = \frac{1}{T} \sum_{t=1}^{T} g_i(\bm{x}_t) $$

負荷分散損失は:

$$ \mathcal{L}_{\text{balance}} = \alpha \cdot E \cdot \sum_{i=1}^{E} f_i \cdot p_i $$

$\alpha$ はバランス損失の係数(通常 $0.01$ 程度)です。

この損失の意味を理解しましょう。完全に均一にルーティングされている場合、$f_i = k/E$、$p_i = 1/E$ で、$\sum f_i p_i = k/E^2$ が最小値になります。あるエキスパートにトークンが集中すると、そのエキスパートの $f_i \cdot p_i$ が大きくなり、損失が増加します。$\sum f_i p_i$ は内積なので、コーシー・シュワルツの不等式から均一分布が最小値を達成します。

Expert Capacity

もう一つの負荷制御手法がExpert Capacity(エキスパート容量)です。各エキスパートが1バッチで処理できるトークン数に上限を設けます。

$$ C = \left\lceil \frac{k \cdot T}{E} \right\rceil \cdot c_f $$

ここで $T$ はバッチ内の総トークン数、$c_f$ はキャパシティ倍率(通常1.0〜1.5)です。容量を超えたトークンはドロップ(残差接続のみでFFNをスキップ)されます。

Mixtralの設計

アーキテクチャの詳細

Mixtral 8x7B(Jiang et al., 2024)は、実用的なMoE LLMとして以下の設計選択をしています。

パラメータ
層数 32
隠れ次元 4096
ヘッド数(Q) 32
ヘッド数(KV) 8 (GQA)
エキスパート数 8
Top-k 2
FFN中間次元 14336
総パラメータ 46.7B
活性パラメータ 12.9B

Mixtralの工夫

1. 全てのFFN層をMoE化: 一部の層だけでなく、全32層のFFN層をMoE化しています。これにより、各層で異なるエキスパートの組み合わせが選択され、多様な専門化パターンが学習されます。

2. Shared components: Attention層(Self-Attention + KVキャッシュ)と埋め込み層はエキスパート間で共有されます。パラメータの大部分(FFN)がMoE化されつつ、Attention の情報統合機能は全トークンで共通です。

3. SwiGLU活性化: 各エキスパートのFFN は LLaMA と同じ SwiGLU 活性化を使用します。

次に、MoEの仕組みをPythonで実装して理解を深めましょう。

Pythonによるスクラッチ実装

MoE層の実装

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

class Expert:
    """1つのFFNエキスパート。"""

    def __init__(self, d_model, d_ff):
        # SwiGLU風の2つの線形層
        self.W1 = np.random.randn(d_model, d_ff) * 0.02
        self.W2 = np.random.randn(d_ff, d_model) * 0.02
        self.W_gate = np.random.randn(d_model, d_ff) * 0.02

    def forward(self, x):
        """SwiGLU: gate(xW_gate) * (xW1) → W2"""
        gate = x @ self.W_gate
        gate = gate * (1 / (1 + np.exp(-gate)))  # SiLU activation
        hidden = gate * (x @ self.W1)
        return hidden @ self.W2


class MoELayer:
    """Mixture of Experts層。"""

    def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
        self.num_experts = num_experts
        self.top_k = top_k
        self.d_model = d_model

        # エキスパート
        self.experts = [Expert(d_model, d_ff) for _ in range(num_experts)]

        # ゲーティングネットワーク
        self.W_gate = np.random.randn(d_model, num_experts) * 0.1

    def route(self, x):
        """トークンをエキスパートにルーティングする。

        Parameters
        ----------
        x : np.ndarray, shape (batch, d_model)

        Returns
        -------
        indices : np.ndarray, shape (batch, top_k) — 選択されたエキスパートの番号
        weights : np.ndarray, shape (batch, top_k) — 正規化された重み
        gate_probs : np.ndarray, shape (batch, num_experts) — 全ゲート確率
        """
        # ゲートロジット
        logits = x @ self.W_gate  # (batch, num_experts)

        # ソフトマックス
        logits_max = np.max(logits, axis=1, keepdims=True)
        exp_logits = np.exp(logits - logits_max)
        gate_probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

        # Top-k選択
        indices = np.argsort(-gate_probs, axis=1)[:, :self.top_k]

        # 選択されたエキスパートの重みを正規化
        weights = np.zeros((x.shape[0], self.top_k))
        for i in range(x.shape[0]):
            selected_probs = gate_probs[i, indices[i]]
            weights[i] = selected_probs / np.sum(selected_probs)

        return indices, weights, gate_probs

    def forward(self, x):
        """MoE層の前方計算。"""
        batch_size = x.shape[0]
        indices, weights, gate_probs = self.route(x)

        output = np.zeros_like(x)
        for i in range(batch_size):
            for j in range(self.top_k):
                expert_idx = indices[i, j]
                expert_output = self.experts[expert_idx].forward(x[i:i+1])
                output[i] += weights[i, j] * expert_output[0]

        return output, indices, gate_probs

    def load_balancing_loss(self, gate_probs, indices):
        """負荷分散損失を計算する。"""
        batch_size = gate_probs.shape[0]

        # ディスパッチ率 f_i
        f = np.zeros(self.num_experts)
        for i in range(batch_size):
            for j in range(self.top_k):
                f[indices[i, j]] += 1
        f /= batch_size

        # ルーティング確率の平均 p_i
        p = np.mean(gate_probs, axis=0)

        # 負荷分散損失
        loss = self.num_experts * np.sum(f * p)
        return loss, f, p


# MoE層の動作確認
d_model = 64
d_ff = 128
num_experts = 8
top_k = 2
batch_size = 256

moe = MoELayer(d_model, d_ff, num_experts, top_k)
x = np.random.randn(batch_size, d_model).astype(np.float64) * 0.5

output, indices, gate_probs = moe.forward(x)
lb_loss, f, p = moe.load_balancing_loss(gate_probs, indices)

print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")
print(f"負荷分散損失: {lb_loss:.4f} (理想値: {top_k:.4f})")
print(f"エキスパートごとのディスパッチ率: {f}")

ルーティングパターンの可視化

# ルーティングパターンの分析
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 左: エキスパートごとのトークン割り当て数
expert_counts = np.zeros(num_experts)
for i in range(batch_size):
    for j in range(top_k):
        expert_counts[indices[i, j]] += 1

colors = plt.cm.Set3(np.linspace(0, 1, num_experts))
bars = axes[0].bar(range(num_experts), expert_counts, color=colors)
axes[0].axhline(y=batch_size * top_k / num_experts, color='red',
                linestyle='--', label=f'Ideal ({batch_size*top_k/num_experts:.0f})')
axes[0].set_xlabel('Expert Index')
axes[0].set_ylabel('Number of Tokens Assigned')
axes[0].set_title('Token Assignment per Expert')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')

# 中央: ゲーティング確率のヒートマップ(上位50トークン)
im = axes[1].imshow(gate_probs[:50], aspect='auto', cmap='YlOrRd')
axes[1].set_xlabel('Expert Index')
axes[1].set_ylabel('Token Index')
axes[1].set_title('Gating Probabilities (first 50 tokens)')
plt.colorbar(im, ax=axes[1])

# 右: エキスパートペアの共起頻度
pair_counts = np.zeros((num_experts, num_experts))
for i in range(batch_size):
    for j in range(top_k):
        for k_idx in range(j + 1, top_k):
            e1, e2 = sorted([indices[i, j], indices[i, k_idx]])
            pair_counts[e1, e2] += 1

# 対称化
pair_counts = pair_counts + pair_counts.T
im2 = axes[2].imshow(pair_counts, cmap='Blues')
axes[2].set_xlabel('Expert Index')
axes[2].set_ylabel('Expert Index')
axes[2].set_title('Expert Pair Co-occurrence')
plt.colorbar(im2, ax=axes[2])

plt.tight_layout()
plt.savefig('moe_routing.png', dpi=150, bbox_inches='tight')
plt.show()

左のグラフでは、各エキスパートに割り当てられたトークン数を示しています。赤い破線が理想的な均等分配を表しており、実際の割り当てがどの程度偏っているかが確認できます。初期化直後(学習前)でも偏りが生じており、これが負荷分散損失なしで学習すると崩壊に至る原因です。

中央のヒートマップでは、各トークン(行)がどのエキスパート(列)にどの程度の確率でルーティングされるかを色で表しています。トークンによってルーティングパターンが異なっており、MoEが入力に応じてエキスパートを切り替えていることがわかります。

右のグラフは、Top-2で選択されたエキスパートのペアの共起頻度を示しています。特定のペアが頻繁に共選択される傾向があれば、それらのエキスパートが相補的な役割を果たしている可能性があります。

負荷分散損失の効果シミュレーション

def simulate_routing_with_balance(n_steps=500, alpha=0.0):
    """負荷分散損失ありなしでルーティングの偏りをシミュレーション。"""
    W_gate = np.random.randn(d_model, num_experts) * 0.1
    lr = 0.01

    imbalance_history = []

    for step in range(n_steps):
        # ランダムバッチ
        x_batch = np.random.randn(64, d_model) * 0.5

        # ゲーティング
        logits = x_batch @ W_gate
        logits_max = np.max(logits, axis=1, keepdims=True)
        probs = np.exp(logits - logits_max) / np.sum(np.exp(logits - logits_max), axis=1, keepdims=True)

        # Top-k
        top_indices = np.argsort(-probs, axis=1)[:, :top_k]

        # 負荷の不均衡度
        counts = np.zeros(num_experts)
        for i in range(64):
            for j in range(top_k):
                counts[top_indices[i, j]] += 1
        counts /= 64
        imbalance = np.std(counts) / np.mean(counts)  # 変動係数
        imbalance_history.append(imbalance)

        # 勾配更新(簡略化: 負荷分散損失の勾配のみ)
        if alpha > 0:
            f = counts
            p = np.mean(probs, axis=0)
            # 負荷分散損失の勾配(近似)
            grad = alpha * num_experts * np.outer(np.mean(x_batch, axis=0), f * p)
            W_gate -= lr * grad

    return imbalance_history

# 比較
imbalance_no_balance = simulate_routing_with_balance(alpha=0.0)
imbalance_with_balance = simulate_routing_with_balance(alpha=0.1)

plt.figure(figsize=(8, 5))
plt.plot(imbalance_no_balance, label='No load balancing (α=0)', alpha=0.8)
plt.plot(imbalance_with_balance, label='With load balancing (α=0.1)', alpha=0.8)
plt.xlabel('Training Step')
plt.ylabel('Routing Imbalance (CV)')
plt.title('Effect of Load Balancing Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('moe_load_balance.png', dpi=150, bbox_inches='tight')
plt.show()

このグラフから、負荷分散損失なし(α=0)ではルーティングの不均衡度が高いままであるのに対し、負荷分散損失あり(α=0.1)では不均衡度が低く抑えられていることが確認できます。実際の学習では、この負荷分散損失がないとルーティング崩壊に至り、一部のエキスパートが全く使われなくなります。

MoEの課題と最新動向

メモリの壁

MoEモデルは計算量は少ないものの、全エキスパートのパラメータをGPUメモリに載せる必要があります。Mixtral 8x7Bは活性パラメータは13Bですが、メモリ使用量は47Bパラメータ分(FP16で約94GB)です。これは量子化で緩和できますが、同等計算量のdenseモデル(13B, 約26GB)と比べるとメモリ効率は低下します。

Expert Parallelism

MoEモデルの分散学習・推論では、Expert Parallelismが使われます。各GPUに異なるエキスパートを配置し、All-to-All通信でトークンを適切なGPUに送信します。通信コストがボトルネックになるため、GPUのインターコネクト帯域が重要です。

DeepSeek-MoE: Shared Expert

DeepSeek-V2(2024)は、一部のエキスパートを共有エキスパート(Shared Expert)として全トークンで活性化させます。共有エキスパートが汎用的な知識を処理し、ルーテッドエキスパートが専門的な知識を処理する分業構造です。

まとめ

本記事では、MoEの理論とMixtralの設計を解説しました。

  • MoEは条件付き計算により、総パラメータ数を大きくしつつ1トークンあたりの計算量を抑える
  • ゲーティングネットワークがトークンをTop-kのエキスパートにルーティングし、重み付き和で出力を生成する
  • 負荷分散損失がエキスパート間のトークン分配を均一化し、ルーティング崩壊を防ぐ
  • Mixtral 8x7Bは47Bパラメータだが活性パラメータは13B相当で、LLaMA-2 70Bに匹敵する性能を達成
  • メモリ使用量が総パラメータ数に比例する点が、denseモデルと比較した主な制約

次のステップとして、以下の記事も参考にしてください。

画像なし
LLaMAアーキテクチャの設計思想
MoEの基盤となるdense Transformerの設計を理解します
画像なし
LLMの量子化を完全解説
MoEモデルのメモリ制約を緩和する量子化技術を理解します