GPT-4やGeminiのような最高性能のLLMは、数千億〜数兆のパラメータを持つと言われています。しかし、全てのパラメータを全ての入力に対して使う必要はあるのでしょうか。「数学の質問には数学が得意なパラメータだけを使い、プログラミングの質問にはコードが得意なパラメータだけを使う」ことができれば、同じ性能をより少ない計算で達成できるはずです。
この発想を実現するのがMixture of Experts(MoE, 混合エキスパート)です。MoEは、TransformerのFFN(Feed-Forward Network)層を複数のエキスパートに分割し、各入力トークンに対して少数のエキスパートのみを選択的に活性化する仕組みです。
例えば、Mixtral 8x7Bは合計47Bのパラメータを持ちますが、各トークンの処理では8つのエキスパートのうち2つのみを使用するため、実効的な計算量は13B相当です。にもかかわらず、ベンチマークではLLaMA-2 70Bに匹敵する性能を達成しています。
MoEを理解することは、以下のような場面で直接役立ちます。
- 最新LLMの理解: GPT-4、Gemini 1.5、Mixtral、DBRX、Grok-1など、最先端のLLMの多くがMoEアーキテクチャを採用しています
- 推論コストの最適化: MoEモデルの推論は、同等性能のdenseモデルより数分の1の計算量で済みます
- モデル設計の指針: いつMoEを使うべきか、エキスパート数やTop-kの選択など、設計判断の基盤になります
本記事の内容
- 条件付き計算(Conditional Computation)の概念
- MoEの基本構造とルーティング機構の数式
- Top-kゲーティングとソフトマックスルーティング
- 負荷分散損失(Load Balancing Loss)
- Expert Capacity と Token Dropping
- Mixtralの設計と工夫
- Pythonによるスクラッチ実装
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
条件付き計算
Dense vs Sparse
従来のTransformerはdenseモデルです。全ての入力トークンに対して全てのパラメータが計算に参加します。モデルサイズを大きくすると性能は向上しますが、計算量もパラメータ数に比例して増加します。
$$ \text{FLOPs}_{\text{dense}} \propto N_{\text{params}} $$
Sparseモデル(MoE)では、各入力トークンに対して一部のパラメータのみが活性化されます。モデルの総パラメータ数を大きくしても、1トークンあたりの計算量は制御可能です。
$$ \text{FLOPs}_{\text{MoE}} \propto \frac{k}{E} \times N_{\text{params}} $$
ここで $E$ はエキスパート数、$k$ はトークンあたりに選択されるエキスパート数です。$E = 8, k = 2$ の場合、1トークンあたりの計算量は dense の約 $\frac{2}{8} = 25\%$ です。
これにより、より多くのパラメータ(= より多くの知識を格納する容量)を持ちながら、推論コストを抑えることが可能になります。
MoEの歴史
MoEの概念は1991年のJacobs et al.に遡りますが、現代的なTransformer MoEの系譜は以下の通りです。
| 年 | 手法/モデル | 主な貢献 |
|---|---|---|
| 2017 | Shazeer et al. | Transformer + MoE (Top-2 routing) |
| 2021 | Switch Transformer | Top-1 routing, 大規模MoE |
| 2022 | ST-MoE | 安定化技術のまとめ |
| 2024 | Mixtral 8x7B | 初のオープンソース高品質MoE LLM |
| 2024 | DeepSeek-V2 | Shared Expert + Fine-grained Expert |
では、MoEの具体的な構造を見ていきましょう。
MoEの基本構造
TransformerのどこにMoEを入れるか
MoEは通常、TransformerブロックのFFN層を置き換えます。Attention層は全トークン間の関係を捉える必要があるためdenseのまま残し、FFN層だけをMoE化します。
標準的なTransformerブロック:
x → Attention → Add & Norm → FFN → Add & Norm → y
MoE Transformerブロック:
x → Attention → Add & Norm → MoE(FFN₁, FFN₂, ..., FFNₑ) → Add & Norm → y
ゲーティングネットワーク(ルーター)
MoE層の核心はゲーティングネットワーク(ルーター)です。入力トークン $\bm{x} \in \mathbb{R}^d$ に対して、どのエキスパートを使うかを決定します。
$$ \bm{g}(\bm{x}) = \text{softmax}(\bm{W}_g \bm{x}) \in \mathbb{R}^E $$
ここで $\bm{W}_g \in \mathbb{R}^{E \times d}$ はゲーティングの重み行列、$E$ はエキスパート数です。$\bm{g}(\bm{x})_i$ は「トークン $\bm{x}$ がエキスパート $i$ にルーティングされる確率」を表します。
Top-k ゲーティング
全てのエキスパートの出力を重み付き和で計算すると、計算量はdenseモデルと変わりません。そこで、上位 $k$ 個のエキスパートのみを選択します。
$$ \text{TopK}(\bm{g}(\bm{x})) = \begin{cases} g_i(\bm{x}) & \text{if } g_i(\bm{x}) \in \text{Top-}k \\ 0 & \text{otherwise} \end{cases} $$
正規化して重みの和を1にします:
$$ \tilde{g}_i(\bm{x}) = \frac{\text{TopK}(g_i(\bm{x}))}{\sum_{j} \text{TopK}(g_j(\bm{x}))} $$
MoE層の出力
MoE層の出力は、選択されたエキスパートの出力の重み付き和です:
$$ \text{MoE}(\bm{x}) = \sum_{i=1}^{E} \tilde{g}_i(\bm{x}) \cdot \text{FFN}_i(\bm{x}) $$
実際には $\tilde{g}_i(\bm{x}) = 0$ のエキスパートは計算をスキップするので、計算量は $k$ 個のエキスパートの分だけです。
具体例
Mixtral 8x7Bの場合: – $E = 8$(8個のエキスパート) – $k = 2$(各トークンで2個を選択) – 各エキスパートは LLaMA-7B の FFN と同じサイズ – 総パラメータ: 8 × 7B(FFN) + Attention等の共有部 ≈ 47B – 1トークンあたりの計算量: 2 × 7B(FFN) + Attention等 ≈ 13B 相当
しかし、ルーティングには根本的な問題が潜んでいます。全てのトークンが同じエキスパートに集中する「ルーティング崩壊」が起きやすいのです。次にこの問題とその解決策を見ましょう。
負荷分散の問題と解決
ルーティング崩壊
学習初期に、あるエキスパートが偶然他のエキスパートより良い重みを持つと、多くのトークンがそのエキスパートにルーティングされます。すると、そのエキスパートはさらに多くのデータで学習されて性能が上がり、ますます多くのトークンが集中します。この正のフィードバックループにより、最終的に少数のエキスパートだけが使われ、残りは死んだ状態になります。
これがルーティング崩壊(routing collapse)であり、MoEの最大の課題です。
負荷分散損失(Auxiliary Load Balancing Loss)
この問題を解決するために、負荷分散損失(auxiliary load balancing loss)を追加します。
各エキスパート $i$ について、以下の2つの量を定義します。
ディスパッチ率 $f_i$: バッチ内の全トークンのうち、エキスパート $i$ にルーティングされた割合:
$$ f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}[i \in \text{TopK}(\bm{g}(\bm{x}_t))] $$
ルーティング確率の平均 $p_i$: バッチ内の全トークンに対するゲート確率の平均:
$$ p_i = \frac{1}{T} \sum_{t=1}^{T} g_i(\bm{x}_t) $$
負荷分散損失は:
$$ \mathcal{L}_{\text{balance}} = \alpha \cdot E \cdot \sum_{i=1}^{E} f_i \cdot p_i $$
$\alpha$ はバランス損失の係数(通常 $0.01$ 程度)です。
この損失の意味を理解しましょう。完全に均一にルーティングされている場合、$f_i = k/E$、$p_i = 1/E$ で、$\sum f_i p_i = k/E^2$ が最小値になります。あるエキスパートにトークンが集中すると、そのエキスパートの $f_i \cdot p_i$ が大きくなり、損失が増加します。$\sum f_i p_i$ は内積なので、コーシー・シュワルツの不等式から均一分布が最小値を達成します。
Expert Capacity
もう一つの負荷制御手法がExpert Capacity(エキスパート容量)です。各エキスパートが1バッチで処理できるトークン数に上限を設けます。
$$ C = \left\lceil \frac{k \cdot T}{E} \right\rceil \cdot c_f $$
ここで $T$ はバッチ内の総トークン数、$c_f$ はキャパシティ倍率(通常1.0〜1.5)です。容量を超えたトークンはドロップ(残差接続のみでFFNをスキップ)されます。
Mixtralの設計
アーキテクチャの詳細
Mixtral 8x7B(Jiang et al., 2024)は、実用的なMoE LLMとして以下の設計選択をしています。
| パラメータ | 値 |
|---|---|
| 層数 | 32 |
| 隠れ次元 | 4096 |
| ヘッド数(Q) | 32 |
| ヘッド数(KV) | 8 (GQA) |
| エキスパート数 | 8 |
| Top-k | 2 |
| FFN中間次元 | 14336 |
| 総パラメータ | 46.7B |
| 活性パラメータ | 12.9B |
Mixtralの工夫
1. 全てのFFN層をMoE化: 一部の層だけでなく、全32層のFFN層をMoE化しています。これにより、各層で異なるエキスパートの組み合わせが選択され、多様な専門化パターンが学習されます。
2. Shared components: Attention層(Self-Attention + KVキャッシュ)と埋め込み層はエキスパート間で共有されます。パラメータの大部分(FFN)がMoE化されつつ、Attention の情報統合機能は全トークンで共通です。
3. SwiGLU活性化: 各エキスパートのFFN は LLaMA と同じ SwiGLU 活性化を使用します。
次に、MoEの仕組みをPythonで実装して理解を深めましょう。
Pythonによるスクラッチ実装
MoE層の実装
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)
class Expert:
"""1つのFFNエキスパート。"""
def __init__(self, d_model, d_ff):
# SwiGLU風の2つの線形層
self.W1 = np.random.randn(d_model, d_ff) * 0.02
self.W2 = np.random.randn(d_ff, d_model) * 0.02
self.W_gate = np.random.randn(d_model, d_ff) * 0.02
def forward(self, x):
"""SwiGLU: gate(xW_gate) * (xW1) → W2"""
gate = x @ self.W_gate
gate = gate * (1 / (1 + np.exp(-gate))) # SiLU activation
hidden = gate * (x @ self.W1)
return hidden @ self.W2
class MoELayer:
"""Mixture of Experts層。"""
def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
self.num_experts = num_experts
self.top_k = top_k
self.d_model = d_model
# エキスパート
self.experts = [Expert(d_model, d_ff) for _ in range(num_experts)]
# ゲーティングネットワーク
self.W_gate = np.random.randn(d_model, num_experts) * 0.1
def route(self, x):
"""トークンをエキスパートにルーティングする。
Parameters
----------
x : np.ndarray, shape (batch, d_model)
Returns
-------
indices : np.ndarray, shape (batch, top_k) — 選択されたエキスパートの番号
weights : np.ndarray, shape (batch, top_k) — 正規化された重み
gate_probs : np.ndarray, shape (batch, num_experts) — 全ゲート確率
"""
# ゲートロジット
logits = x @ self.W_gate # (batch, num_experts)
# ソフトマックス
logits_max = np.max(logits, axis=1, keepdims=True)
exp_logits = np.exp(logits - logits_max)
gate_probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
# Top-k選択
indices = np.argsort(-gate_probs, axis=1)[:, :self.top_k]
# 選択されたエキスパートの重みを正規化
weights = np.zeros((x.shape[0], self.top_k))
for i in range(x.shape[0]):
selected_probs = gate_probs[i, indices[i]]
weights[i] = selected_probs / np.sum(selected_probs)
return indices, weights, gate_probs
def forward(self, x):
"""MoE層の前方計算。"""
batch_size = x.shape[0]
indices, weights, gate_probs = self.route(x)
output = np.zeros_like(x)
for i in range(batch_size):
for j in range(self.top_k):
expert_idx = indices[i, j]
expert_output = self.experts[expert_idx].forward(x[i:i+1])
output[i] += weights[i, j] * expert_output[0]
return output, indices, gate_probs
def load_balancing_loss(self, gate_probs, indices):
"""負荷分散損失を計算する。"""
batch_size = gate_probs.shape[0]
# ディスパッチ率 f_i
f = np.zeros(self.num_experts)
for i in range(batch_size):
for j in range(self.top_k):
f[indices[i, j]] += 1
f /= batch_size
# ルーティング確率の平均 p_i
p = np.mean(gate_probs, axis=0)
# 負荷分散損失
loss = self.num_experts * np.sum(f * p)
return loss, f, p
# MoE層の動作確認
d_model = 64
d_ff = 128
num_experts = 8
top_k = 2
batch_size = 256
moe = MoELayer(d_model, d_ff, num_experts, top_k)
x = np.random.randn(batch_size, d_model).astype(np.float64) * 0.5
output, indices, gate_probs = moe.forward(x)
lb_loss, f, p = moe.load_balancing_loss(gate_probs, indices)
print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")
print(f"負荷分散損失: {lb_loss:.4f} (理想値: {top_k:.4f})")
print(f"エキスパートごとのディスパッチ率: {f}")
ルーティングパターンの可視化
# ルーティングパターンの分析
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# 左: エキスパートごとのトークン割り当て数
expert_counts = np.zeros(num_experts)
for i in range(batch_size):
for j in range(top_k):
expert_counts[indices[i, j]] += 1
colors = plt.cm.Set3(np.linspace(0, 1, num_experts))
bars = axes[0].bar(range(num_experts), expert_counts, color=colors)
axes[0].axhline(y=batch_size * top_k / num_experts, color='red',
linestyle='--', label=f'Ideal ({batch_size*top_k/num_experts:.0f})')
axes[0].set_xlabel('Expert Index')
axes[0].set_ylabel('Number of Tokens Assigned')
axes[0].set_title('Token Assignment per Expert')
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis='y')
# 中央: ゲーティング確率のヒートマップ(上位50トークン)
im = axes[1].imshow(gate_probs[:50], aspect='auto', cmap='YlOrRd')
axes[1].set_xlabel('Expert Index')
axes[1].set_ylabel('Token Index')
axes[1].set_title('Gating Probabilities (first 50 tokens)')
plt.colorbar(im, ax=axes[1])
# 右: エキスパートペアの共起頻度
pair_counts = np.zeros((num_experts, num_experts))
for i in range(batch_size):
for j in range(top_k):
for k_idx in range(j + 1, top_k):
e1, e2 = sorted([indices[i, j], indices[i, k_idx]])
pair_counts[e1, e2] += 1
# 対称化
pair_counts = pair_counts + pair_counts.T
im2 = axes[2].imshow(pair_counts, cmap='Blues')
axes[2].set_xlabel('Expert Index')
axes[2].set_ylabel('Expert Index')
axes[2].set_title('Expert Pair Co-occurrence')
plt.colorbar(im2, ax=axes[2])
plt.tight_layout()
plt.savefig('moe_routing.png', dpi=150, bbox_inches='tight')
plt.show()
左のグラフでは、各エキスパートに割り当てられたトークン数を示しています。赤い破線が理想的な均等分配を表しており、実際の割り当てがどの程度偏っているかが確認できます。初期化直後(学習前)でも偏りが生じており、これが負荷分散損失なしで学習すると崩壊に至る原因です。
中央のヒートマップでは、各トークン(行)がどのエキスパート(列)にどの程度の確率でルーティングされるかを色で表しています。トークンによってルーティングパターンが異なっており、MoEが入力に応じてエキスパートを切り替えていることがわかります。
右のグラフは、Top-2で選択されたエキスパートのペアの共起頻度を示しています。特定のペアが頻繁に共選択される傾向があれば、それらのエキスパートが相補的な役割を果たしている可能性があります。
負荷分散損失の効果シミュレーション
def simulate_routing_with_balance(n_steps=500, alpha=0.0):
"""負荷分散損失ありなしでルーティングの偏りをシミュレーション。"""
W_gate = np.random.randn(d_model, num_experts) * 0.1
lr = 0.01
imbalance_history = []
for step in range(n_steps):
# ランダムバッチ
x_batch = np.random.randn(64, d_model) * 0.5
# ゲーティング
logits = x_batch @ W_gate
logits_max = np.max(logits, axis=1, keepdims=True)
probs = np.exp(logits - logits_max) / np.sum(np.exp(logits - logits_max), axis=1, keepdims=True)
# Top-k
top_indices = np.argsort(-probs, axis=1)[:, :top_k]
# 負荷の不均衡度
counts = np.zeros(num_experts)
for i in range(64):
for j in range(top_k):
counts[top_indices[i, j]] += 1
counts /= 64
imbalance = np.std(counts) / np.mean(counts) # 変動係数
imbalance_history.append(imbalance)
# 勾配更新(簡略化: 負荷分散損失の勾配のみ)
if alpha > 0:
f = counts
p = np.mean(probs, axis=0)
# 負荷分散損失の勾配(近似)
grad = alpha * num_experts * np.outer(np.mean(x_batch, axis=0), f * p)
W_gate -= lr * grad
return imbalance_history
# 比較
imbalance_no_balance = simulate_routing_with_balance(alpha=0.0)
imbalance_with_balance = simulate_routing_with_balance(alpha=0.1)
plt.figure(figsize=(8, 5))
plt.plot(imbalance_no_balance, label='No load balancing (α=0)', alpha=0.8)
plt.plot(imbalance_with_balance, label='With load balancing (α=0.1)', alpha=0.8)
plt.xlabel('Training Step')
plt.ylabel('Routing Imbalance (CV)')
plt.title('Effect of Load Balancing Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('moe_load_balance.png', dpi=150, bbox_inches='tight')
plt.show()
このグラフから、負荷分散損失なし(α=0)ではルーティングの不均衡度が高いままであるのに対し、負荷分散損失あり(α=0.1)では不均衡度が低く抑えられていることが確認できます。実際の学習では、この負荷分散損失がないとルーティング崩壊に至り、一部のエキスパートが全く使われなくなります。
MoEの課題と最新動向
メモリの壁
MoEモデルは計算量は少ないものの、全エキスパートのパラメータをGPUメモリに載せる必要があります。Mixtral 8x7Bは活性パラメータは13Bですが、メモリ使用量は47Bパラメータ分(FP16で約94GB)です。これは量子化で緩和できますが、同等計算量のdenseモデル(13B, 約26GB)と比べるとメモリ効率は低下します。
Expert Parallelism
MoEモデルの分散学習・推論では、Expert Parallelismが使われます。各GPUに異なるエキスパートを配置し、All-to-All通信でトークンを適切なGPUに送信します。通信コストがボトルネックになるため、GPUのインターコネクト帯域が重要です。
DeepSeek-MoE: Shared Expert
DeepSeek-V2(2024)は、一部のエキスパートを共有エキスパート(Shared Expert)として全トークンで活性化させます。共有エキスパートが汎用的な知識を処理し、ルーテッドエキスパートが専門的な知識を処理する分業構造です。
まとめ
本記事では、MoEの理論とMixtralの設計を解説しました。
- MoEは条件付き計算により、総パラメータ数を大きくしつつ1トークンあたりの計算量を抑える
- ゲーティングネットワークがトークンをTop-kのエキスパートにルーティングし、重み付き和で出力を生成する
- 負荷分散損失がエキスパート間のトークン分配を均一化し、ルーティング崩壊を防ぐ
- Mixtral 8x7Bは47Bパラメータだが活性パラメータは13B相当で、LLaMA-2 70Bに匹敵する性能を達成
- メモリ使用量が総パラメータ数に比例する点が、denseモデルと比較した主な制約
次のステップとして、以下の記事も参考にしてください。