BERTは自然言語処理に革命をもたらしましたが、その事前学習には根本的な非効率が潜んでいます。BERTが採用するMLM(Masked Language Model)は、入力トークンのうちわずか15%だけをマスクし、そのマスク位置だけで損失を計算します。残りの85%のトークンからは学習信号がまったく得られません。つまり、1回のフォワードパスで全トークンを処理する計算コストを払っているにもかかわらず、そのうちの一部しか学習に活かしていないのです。
「入力の全トークンから学習信号を得ることはできないか?」—これがELECTRA(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)の出発点です。Clark et al.(2020, Google / Stanford)は、GAN(敵対的生成ネットワーク)に着想を得た2つのネットワーク構成を提案しました。小さなGeneratorがマスク位置にもっともらしいトークンを生成し、大きなDiscriminatorが全トークンに対して「本物か偽物か」を判別します。このReplaced Token Detection(RTD)というタスクにより、Discriminatorは入力の100%から学習信号を受け取ります。
ELECTRAの意義は、特に計算資源が限られた環境で際立ちます。
- 小規模モデルでの高精度: ELECTRA-Smallは1つのGPUで4日間学習するだけで、同じ計算量のBERT-Smallを大幅に上回る性能を達成しました。研究室や個人開発者にとって、限られたリソースで高品質な言語モデルを得る現実的な手段です
- 学習効率の改善: ELECTRA-Baseは、BERT-Baseの1/4の計算量で同等の性能に到達します。大規模な事前学習の電力コストや環境負荷が議論される中、学習効率の向上は技術的にも社会的にも重要です
- エッジデバイスへの展開: 小規模モデルでも高い精度が出るため、スマートフォンやIoTデバイスでの推論に適した事前学習済みモデルを効率的に作れます
本記事の内容
- BERTのMLMが抱える非効率性(15%問題と[MASK]トークンの不整合)
- ELECTRAのGenerator + Discriminatorアーキテクチャ
- Replaced Token Detectionの数学的定式化(Generator損失、Discriminator損失、結合損失)
- GANとの類似点と本質的な相違点
- Weight Sharing(埋め込み共有)の効果と設計判断
- PyTorchでのGenerator、Discriminator、学習ループのスクラッチ実装
- ELECTRA vs BERTの性能比較と計算効率の分析
前提知識
この記事を読む前に、以下の記事を読んでおくと理解が深まります。
MLMの非効率性 — BERTが抱える2つの問題
ELECTRAの設計意図を理解するには、まずBERTのMLMが持つ構造的な弱点を正確に把握する必要があります。MLMには大きく分けて2つの問題があります。
問題1: 15%しか学習信号がない
BERTの事前学習では、入力トークン列のうちランダムに15%を選び、それらをマスク(あるいは置換・維持)します。損失関数は、このマスクされた15%の位置に対してのみ計算されます。
長さ $T$ のトークン列 $\bm{x} = (x_1, x_2, \ldots, x_T)$ に対して、マスク位置の集合を $\mathcal{M}$ とすると、MLMの損失は次のように定義されます。
$$ \mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P(x_i \mid \bm{x}_{\backslash \mathcal{M}}) $$
ここで $|\mathcal{M}| \approx 0.15T$ です。モデルは全トークンを処理するために $T$ 個のトークン分の計算を行いますが、学習信号を得るのは $0.15T$ 個分だけです。
この非効率さを日常のたとえで考えてみましょう。試験勉強で100問の問題集を解いたのに、答え合わせは15問分しか行わないようなものです。残りの85問については正解だったのか間違っていたのかわからないまま、次の問題集に進んでしまいます。同じ時間を使うなら、100問すべての答え合わせをした方が学習効率は遥かに高いはずです。
数値で確認してみましょう。典型的なBERTの入力は512トークンです。15%をマスクすると、1サンプルあたり約77トークンからしか学習信号が得られません。残りの435トークン分の計算は、文脈を提供する役割は果たしますが、損失に直接的な勾配を生じません。
問題2: [MASK]トークンの事前学習・ファインチューニング不整合
MLMの2つ目の問題は、事前学習とファインチューニングの間にある入力分布の不整合です。事前学習時にはマスクされたトークンが [MASK] という特殊トークンに置き換わりますが、ファインチューニング時や実際のテキストには [MASK] はどこにも出現しません。
BERTはこの問題を軽減するために、マスク対象トークンの扱いを3通りに分けています。
- 80%:
[MASK]トークンに置換 - 10%: ランダムなトークンに置換
- 10%: 元のトークンのまま維持
この 80/10/10 ルールはヒューリスティクスとして有効に機能しますが、根本的な解決策とは言えません。依然としてマスク位置の80%は [MASK] に置換されており、入力分布のズレは残ります。
これら2つの問題は独立ではありません。学習効率が低いことは学習に必要な計算量を増大させ、入力分布の不整合はファインチューニング時の性能ロスにつながります。では、MLMの枠組みそのものを離れて、これらの問題を同時に解決するアプローチは考えられないでしょうか。ELECTRAは、まさにその問いに対する回答です。
ELECTRAの基本アイデア — 穴埋めから真贋判定へ
着想: テスト問題の作成者と解答者
ELECTRAのアイデアを直感的に理解するために、学校の試験のアナロジーで考えてみましょう。
BERTのMLMは「穴埋めテスト」です。文章の一部が空欄([MASK])になっており、そこに正しい単語を埋めるよう求められます。しかしこの方式では、空欄以外の部分は「最初から答えが見えている問題」なので、そこから学ぶことはありません。
ELECTRAは発想を変えます。まず出題者(Generator)が空欄にもっともらしい単語を埋めて「完成品」を作ります。次に審査官(Discriminator)が、その完成品の各単語を1つずつ吟味して、「この単語はオリジナルか、出題者が差し替えたものか?」を判定します。審査官は全ての単語を検査する必要があるため、文のあらゆる位置から学習信号を得ることになります。
この発想により、ELECTRAは2つの問題を同時に解決します。
- 全トークンからの学習信号: Discriminatorは全位置で二値分類を行うため、入力の100%から勾配が流れます
- [MASK]の不整合を解消: Discriminatorの入力には
[MASK]トークンが一切含まれません。Generatorが生成した実際の単語か、元の単語かのどちらかです
GAN(敵対的生成ネットワーク)との類似性
ELECTRAのGenerator + Discriminatorという構成は、GANを連想させます。画像生成のGANでは、Generatorが偽の画像を生成し、Discriminatorが本物と偽物を区別します。ELECTRAでも同様に、Generatorが偽のトークンを生成し、Discriminatorが真贋を判定します。
しかし後で詳しく解説しますが、ELECTRAとGANには本質的な違いがあります。GANのGeneratorは「Discriminatorを騙す」ように学習しますが、ELECTRAのGeneratorは単にMLMの損失を最小化するだけで、Discriminatorからの勾配は受け取りません。この違いは、学習の安定性に大きく影響します。
ELECTRAのアーキテクチャの全体像を把握したところで、次はGeneratorとDiscriminatorそれぞれの役割と動作を詳しく見ていきましょう。
Generatorの役割 — マスク位置にトークンを生成する小さなMLMモデル
Generatorの構造
Generatorは、BERTと同じMLMタスクを行う小さなTransformer Encoderです。ここで「小さい」というのが重要なポイントです。ELECTRAの論文では、GeneratorのサイズをDiscriminatorの1/4〜1/3程度にすることが推奨されています。
具体的には、Discriminatorが12層・隠れ次元768(BERT-Baseと同じ)だとすると、Generatorは4層・隠れ次元256程度に設定します。Generatorを小さくする理由は2つあります。
理由1: 計算効率。ELECTRAの計算コストの大部分はDiscriminatorが占めます。Generatorを大きくしても全体の学習効率はあまり改善しません。むしろ、Generatorの計算コストを抑えてDiscriminatorに計算予算を回す方が効率的です。
理由2: 適切な難易度のタスク。Generatorが強すぎると、ほぼ正解のトークンを生成してしまい、Discriminatorにとって「本物か偽物か」の判別が簡単になりすぎます。逆にGeneratorが弱すぎると、明らかに不自然なトークンが生成され、やはり判別が簡単になります。適度なサイズのGeneratorが、Discriminatorにとって「ちょうどいい難易度」の問題を生成するのです。
Generatorの動作
Generatorの処理フローを具体的に追ってみましょう。入力トークン列 $\bm{x} = (x_1, x_2, \ldots, x_T)$ が与えられたとき、以下の手順で動作します。
ステップ1: マスク位置の選定
BERTと同様に、入力トークンのうちランダムに15%を選びます。選ばれた位置の集合を $\mathcal{M}$ とします。選ばれた位置のトークンは [MASK] に置換され、マスク済みの入力 $\bm{x}^{\text{masked}}$ が作られます。
ステップ2: Generator Encoderでの処理
$\bm{x}^{\text{masked}}$ をGenerator(小さなTransformer Encoder)に通し、各位置の隠れ表現 $\bm{h}_G(t)$ を得ます。
ステップ3: マスク位置のトークン予測
マスク位置 $t \in \mathcal{M}$ について、隠れ表現を語彙全体の確率分布に変換します。
$$ P_G(x_t \mid \bm{x}^{\text{masked}}) = \frac{\exp(\bm{e}(x_t)^{\top} \bm{h}_G(t))}{\sum_{x’ \in \mathcal{V}} \exp(\bm{e}(x’)^{\top} \bm{h}_G(t))} $$
ここで $\bm{e}(x_t)$ はトークン $x_t$ の埋め込みベクトル、$\mathcal{V}$ は語彙全体の集合です。この式はMLMヘッドと同じ softmax 分類です。
ステップ4: トークンのサンプリング
各マスク位置 $t \in \mathcal{M}$ で、上記の確率分布からトークンをサンプリングします。
$$ \hat{x}_t \sim P_G(x_t \mid \bm{x}^{\text{masked}}), \quad t \in \mathcal{M} $$
サンプリングされたトークン $\hat{x}_t$ がオリジナルのトークン $x_t$ と一致する場合もあります。この場合、その位置は「置換されていない」として扱います。
ステップ5: 「汚染された」入力の生成
マスク位置をサンプリングされたトークンで埋めて、「汚染された」入力 $\tilde{\bm{x}}$ を作ります。
$$ \tilde{x}_t = \begin{cases} \hat{x}_t & \text{if } t \in \mathcal{M} \\ x_t & \text{if } t \notin \mathcal{M} \end{cases} $$
この $\tilde{\bm{x}}$ には [MASK] トークンが一切含まれない点が重要です。全ての位置に「本物っぽい」トークンが入っているため、Discriminatorが受け取る入力は自然なテキストに近い形になります。
Generatorが生成した「汚染された」入力を受け取るのがDiscriminatorです。次に、Discriminatorがこの入力をどのように処理するかを見ていきましょう。
Discriminatorの役割 — 全トークンで真贋を見分ける
Discriminatorの構造
Discriminatorは、ELECTRAの事前学習で本当に学ばせたい「本体」のモデルです。BERTと同じ規模のTransformer Encoderであり、ファインチューニング時にはこのDiscriminatorが下流タスクに使用されます。
Discriminatorが行うタスクは非常にシンプルです。入力 $\tilde{\bm{x}}$ の各位置 $t$ について、「そのトークンはオリジナル(本物)か、Generatorが置換した偽物か?」を二値分類します。
Discriminatorの動作
ステップ1: 隠れ表現の計算
「汚染された」入力 $\tilde{\bm{x}}$ をDiscriminator(大きなTransformer Encoder)に通し、各位置の隠れ表現 $\bm{h}_D(t)$ を得ます。
ステップ2: 各トークンの真贋判定
各位置 $t$ について、隠れ表現を線形層 + sigmoid関数に通して、「本物である確率」を予測します。
$$ D(\tilde{\bm{x}}, t) = \sigma(\bm{w}^{\top} \bm{h}_D(t)) $$
ここで $\bm{w}$ はDiscriminatorの判別ヘッドの重みベクトル、$\sigma$ はsigmoid関数です。この判定は全位置 $t = 1, 2, \ldots, T$ に対して行われます。これがMLMの15%とは対照的な、ELECTRAの最大の特徴です。
なぜ全トークンからの学習が効果的なのか
直感的に考えてみましょう。あるトークンが「本物かどうか」を判定するためには、Discriminatorは以下のことを理解している必要があります。
- 文法的な正しさ: 品詞の整合性、活用の正しさ、語順の妥当性
- 意味的な一貫性: 文脈に合った単語が使われているか
- 世界知識: 事実として正しい組み合わせか(「東京は日本の首都」vs「東京はフランスの首都」)
BERTのMLMでも同様の知識は学習されますが、MLMでは「空欄に何が入るか」を予測する形です。ELECTRAのRTDでは、「この単語がこの位置にあるのは自然か」を判定する形です。後者の方がより繊細な言語理解を必要とします。なぜなら、Generatorが生成するトークンは「もっともらしいが微妙に間違っている」ことが多く、Discriminatorはその微妙な違和感を検出しなければならないからです。
さらに、MLMでは空欄が [MASK] であることが明らかなので、モデルは「ここが問題のある位置だ」と事前にわかっています。一方RTDでは、どの位置が置換されているか事前にはわかりません。Discriminatorは全位置を等しく注意深く検査する必要があり、これが全トークンからの効率的な学習につながるのです。
学習信号の比率を定量的に比較すると、以下のようになります。
| 方式 | 1サンプルあたりの学習信号 | 512トークン入力の場合 |
|---|---|---|
| MLM(BERT) | $0.15T$ 位置 | 約77位置 |
| RTD(ELECTRA) | $T$ 位置 | 512位置 |
| 効率比 | RTD / MLM $\approx 6.7$ 倍 | — |
もちろん、RTDの二値分類(本物/偽物)とMLMの語彙分類(30,000語超の中から正解を選ぶ)では1位置あたりの情報量が異なるため、単純に6.7倍の学習効率とは言えません。しかし実験的には、ELECTRAの方が同じ計算量でBERTを大幅に上回ることが確認されています。
GeneratorとDiscriminatorの役割がわかったところで、次はこれらを数学的に定式化し、ELECTRAの損失関数がどのように設計されているかを見ていきましょう。
数学的定式化 — ELECTRAの損失関数
ここでは、ELECTRAの学習で使われる損失関数を数式で厳密に定義します。ELECTRAの損失は3つの要素から構成されます。
記法の整理
まず、記法を整理しておきましょう。
- $\bm{x} = (x_1, x_2, \ldots, x_T)$: オリジナルの入力トークン列
- $\mathcal{M}$: マスク位置の集合($|\mathcal{M}| \approx 0.15T$)
- $\bm{x}^{\text{masked}}$: マスク位置を
[MASK]に置換した入力 - $\tilde{\bm{x}} = (\tilde{x}_1, \tilde{x}_2, \ldots, \tilde{x}_T)$: Generatorがマスク位置を埋めた「汚染された」入力
- $\theta_G$: Generatorのパラメータ
- $\theta_D$: Discriminatorのパラメータ
Generator損失: MLM損失
Generatorの損失は、BERTのMLMと同じ交差エントロピー損失です。マスクされた位置のトークンを正しく予測することを目指します。
$$ \mathcal{L}_{\text{Gen}}(\theta_G) = -\sum_{t \in \mathcal{M}} \log P_G(x_t \mid \bm{x}^{\text{masked}}; \theta_G) $$
$P_G(x_t \mid \bm{x}^{\text{masked}}; \theta_G)$ は、Generatorが位置 $t$ でオリジナルのトークン $x_t$ を予測する確率です。前のセクションで示した softmax 分布に対応します。
Generatorの損失は、標準的なMLMの損失とまったく同じ形をしていることに注意してください。ELECTRAのGeneratorは、独立した小さなMLMモデルに過ぎません。
Discriminator損失: バイナリクロスエントロピー
Discriminatorの損失は、各位置での二値分類に対するバイナリクロスエントロピーです。各位置 $t$ について、$\tilde{x}_t$ がオリジナルか置換されたものかを判定します。
まず、位置 $t$ の正解ラベルを定義します。
$$ y_t = \begin{cases} 1 & \text{if } \tilde{x}_t = x_t \quad (\text{本物}) \\ 0 & \text{if } \tilde{x}_t \neq x_t \quad (\text{偽物}) \end{cases} $$
ここで重要なのは、$t \in \mathcal{M}$(マスク位置)であっても、Generatorがたまたまオリジナルと同じトークンを生成した場合は $y_t = 1$(本物)になるということです。ラベルは「マスクされたかどうか」ではなく「実際にトークンが変わったかどうか」で決まります。
Discriminatorが位置 $t$ でトークンが本物である確率を $D(\tilde{\bm{x}}, t; \theta_D)$ と書くと、損失は次のようになります。
$$ \mathcal{L}_{\text{Disc}}(\theta_D) = -\sum_{t=1}^{T} \left[ y_t \log D(\tilde{\bm{x}}, t; \theta_D) + (1 – y_t) \log (1 – D(\tilde{\bm{x}}, t; \theta_D)) \right] $$
この式を項ごとに分解して確認しましょう。$y_t = 1$(本物)の位置では、第1項 $\log D(\tilde{\bm{x}}, t; \theta_D)$ が有効になります。つまり、Discriminatorが「本物」と正しく判定する確率を高めます。$y_t = 0$(偽物)の位置では、第2項 $\log(1 – D(\tilde{\bm{x}}, t; \theta_D))$ が有効になります。つまり、Discriminatorが「偽物」と正しく判定する確率を高めます。
和が全位置 $t = 1, 2, \ldots, T$ にわたることが、MLMの $t \in \mathcal{M}$ のみとは決定的に異なります。
結合損失: GeneratorとDiscriminatorの同時学習
ELECTRAの全体損失は、Generator損失とDiscriminator損失の重み付き和です。
$$ \mathcal{L}_{\text{ELECTRA}}(\theta_G, \theta_D) = \mathcal{L}_{\text{Gen}}(\theta_G) + \lambda \, \mathcal{L}_{\text{Disc}}(\theta_D) $$
ここで $\lambda$ はDiscriminator損失の重みを制御するハイパーパラメータです。論文では $\lambda = 50$ が使用されています。
$\lambda = 50$ という一見大きな値は、Generator損失とDiscriminator損失のスケールの違いを補正するためです。Generator損失はマスク位置(約15%のトークン)に対する語彙サイズ(30,000語超)の交差エントロピーであり、1位置あたりの損失値が大きくなります。一方、Discriminator損失は全位置に対する二値分類のエントロピーであり、1位置あたりの損失値は比較的小さくなります。$\lambda = 50$ は、両者の勾配の大きさを揃えて学習のバランスを保つ役割を果たします。
最適化の流れ
学習の各ステップは以下の手順で進みます。
- 入力 $\bm{x}$ からマスク位置 $\mathcal{M}$ をサンプリング
- Generator で $\bm{x}^{\text{masked}}$ を処理し、マスク位置のトークンをサンプリングして $\tilde{\bm{x}}$ を生成
- $\mathcal{L}_{\text{Gen}}$ を計算し、$\theta_G$ に対する勾配を求める
- Discriminator で $\tilde{\bm{x}}$ を処理し、各位置の真贋を判定
- $\mathcal{L}_{\text{Disc}}$ を計算し、$\theta_D$ に対する勾配を求める
- $\theta_G$ と $\theta_D$ をそれぞれ更新
ここで極めて重要なのは、ステップ3の勾配は $\theta_G$ のみに対して計算され、ステップ5の勾配は $\theta_D$ のみに対して計算されるということです。つまり、$\mathcal{L}_{\text{Disc}}$ の勾配がGeneratorに逆伝播することはありません。この点がGANとの最も重要な違いであり、次のセクションで詳しく解説します。
なぜGANの敵対的学習ではないのか
ELECTRAの構造はGANに似ていますが、学習方法は根本的に異なります。この違いは、ELECTRAの学習安定性を理解する上で極めて重要です。
GANの学習: 敵対的ゲーム
画像生成のGANでは、GeneratorとDiscriminatorがミニマックスゲームを行います。Generatorは「Discriminatorを騙す」ことを目指し、Discriminatorは「GeneratorとRealを区別する」ことを目指します。
$$ \min_G \max_D \; \mathbb{E}_{\bm{x} \sim p_{\text{data}}}[\log D(\bm{x})] + \mathbb{E}_{\bm{z} \sim p_z}[\log(1 – D(G(\bm{z})))] $$
このとき、Discriminatorの判定結果がGeneratorに逆伝播し、Generatorはその勾配を使って「よりDiscriminatorを騙しやすい」出力を生成するように学習します。
ELECTRAの学習: 最尤推定 + 独立した判別
ELECTRAでは、GeneratorはMLMの最尤推定で独立に学習します。Discriminatorの損失はGeneratorのパラメータ更新に影響を与えません。
なぜ敵対的学習を使わないのでしょうか? その理由は、テキストの離散性にあります。
画像の場合、Generatorの出力は連続値(ピクセル値)なので、Discriminatorからの勾配を直接Generatorに逆伝播できます。しかしテキストの場合、Generatorの出力は離散的なトークンです。確率分布からトークンをサンプリングする操作は微分不可能であるため、Discriminatorの損失から勾配をGeneratorに流すことができません。
$$ \hat{x}_t \sim P_G(x_t \mid \bm{x}^{\text{masked}}) $$
このサンプリング操作 $\sim$ は微分不可能です。画像GANでは出力がそのまま連続的にDiscriminatorに渡されますが、ELECTRAではサンプリングというボトルネックが存在するため、勾配が途切れます。
Clark et al. は論文中で、REINFORCE(方策勾配法)を用いてDiscriminatorの報酬をGeneratorに伝える方法も試していますが、結果はMLMによる最尤推定の方が安定して高性能だったと報告しています。
まとめ: GANとELECTRAの比較
| 項目 | GAN | ELECTRA |
|---|---|---|
| Generator出力 | 連続値(画像ピクセル) | 離散値(トークン) |
| Generatorの目的 | Discriminatorを騙す | マスク位置の正解トークンを予測(MLM) |
| 勾配の流れ | D → G(逆伝播) | D → G の逆伝播なし |
| Generatorの損失 | 敵対的損失 | 交差エントロピー(最尤推定) |
| 学習の安定性 | mode collapse等のリスクあり | MLMベースで安定 |
ELECTRAのGeneratorはGANのGeneratorと異なり、Discriminatorを「騙す」インセンティブを持ちません。単に「良い穴埋めモデル」として学習するだけです。これにより、GANで頻繁に問題となるモード崩壊(mode collapse)や学習の不安定性を回避しています。
GANとの違いが明確になったところで、次はELECTRAのもう一つの重要な設計判断であるWeight Sharing(重み共有)について見ていきましょう。
Weight Sharing — 埋め込み層の共有戦略
なぜ重みを共有するのか
ELECTRAでは、GeneratorとDiscriminatorのトークン埋め込み層を共有するという設計判断がなされています。つまり、同じ単語は両方のネットワークで同じ初期ベクトルにマッピングされます。
直感的に考えてみましょう。Generatorは「マスク位置に正しいトークンを予測する」ために、各トークンの意味を理解する必要があります。Discriminatorは「各トークンが文脈に合っているかを判定する」ために、やはり各トークンの意味を理解する必要があります。両者が必要とするトークン理解は共通する部分が多いため、埋め込み層を共有することで効率的に学習できます。
共有の方法
GeneratorとDiscriminatorの隠れ次元が異なる場合(通常はDiscriminatorの方が大きい)、直接的な共有には工夫が必要です。ELECTRAでは以下のアプローチを取ります。
語彙サイズを $V$、共有埋め込みの次元を $d_e$ とすると、共有される埋め込み行列 $\bm{E} \in \mathbb{R}^{V \times d_e}$ は、GeneratorとDiscriminatorの両方で使用されます。
Generatorの隠れ次元 $d_G$ が $d_e$ と異なる場合は、線形射影で次元を合わせます。
$$ \bm{h}_G^{(0)}(t) = \bm{W}_G \, \bm{E}(x_t) + \bm{b}_G $$
同様に、Discriminatorの隠れ次元 $d_D$ が $d_e$ と異なる場合も射影を行います。
$$ \bm{h}_D^{(0)}(t) = \bm{W}_D \, \bm{E}(\tilde{x}_t) + \bm{b}_D $$
ここで $\bm{W}_G \in \mathbb{R}^{d_G \times d_e}$ と $\bm{W}_D \in \mathbb{R}^{d_D \times d_e}$ はそれぞれの射影行列です。
共有の効果
Clark et al. の論文では、重み共有の効果を消去実験で検証しています。
| 設定 | GLUE平均スコア |
|---|---|
| 埋め込み共有なし | 83.5 |
| トークン埋め込みのみ共有 | 84.3 |
| 全重み共有(Generator = Discriminatorサイズ) | 84.4 |
トークン埋め込みの共有だけで +0.8 ポイントの改善が見られます。一方、Transformer層の重みまですべて共有すると、性能はほぼ同等ですが、GeneratorとDiscriminatorのサイズが同じになるため計算コストが増大します。そのため、実用的にはトークン埋め込みのみの共有が最適なバランスです。
重み共有のもう一つの利点は、GeneratorとDiscriminatorが共通のトークン表現空間で学習することです。Generatorが「この文脈ではトークン A が適切だ」と学習した知識は、埋め込みを通じてDiscriminatorにも伝わります。これにより、特に学習初期の段階でDiscriminatorがより効率的に言語構造を理解できるようになります。
理論的な設計をすべて把握したところで、次はPyTorchを使ってELECTRAを実装し、実際に動かしてみましょう。
PyTorchでの実装
ここでは、ELECTRAのGenerator、Discriminator、および学習ループをPyTorchでスクラッチ実装します。まず個々のコンポーネントを構築し、最後にそれらを組み合わせた学習ループを作ります。
Generator(小さなMLMモデル)
Generatorは、小さなTransformer EncoderにMLMヘッドを付けた構造です。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TransformerBlock(nn.Module):
"""Transformer Encoderの1ブロック"""
def __init__(self, hidden_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.ff = nn.Sequential(
nn.Linear(hidden_dim, ff_dim),
nn.GELU(),
nn.Linear(ff_dim, hidden_dim),
nn.Dropout(dropout),
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-Attention + 残差接続
attn_out, _ = self.attention(x, x, x, key_padding_mask=mask)
x = self.norm1(x + self.dropout(attn_out))
# Feed-Forward + 残差接続
x = self.norm2(x + self.ff(x))
return x
このTransformerBlockは、GeneratorとDiscriminatorの両方で共有するビルディングブロックです。Multi-Head Self-Attention、残差接続、Layer Normalization、Feed-Forward Networkの標準的な構成を実装しています。
次に、このブロックを使ってGeneratorを構築します。
class Generator(nn.Module):
"""ELECTRA Generator: 小さなMLMモデル"""
def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers,
num_heads, ff_dim, max_len=512, dropout=0.1):
super().__init__()
# 埋め込み次元と隠れ次元の射影
self.embed_proj = nn.Linear(embed_dim, hidden_dim)
self.pos_embed = nn.Embedding(max_len, hidden_dim)
self.layers = nn.ModuleList([
TransformerBlock(hidden_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(hidden_dim)
# MLMヘッド: 隠れ表現 → 語彙サイズの確率分布
self.mlm_head = nn.Sequential(
nn.Linear(hidden_dim, embed_dim),
nn.GELU(),
nn.LayerNorm(embed_dim),
)
self.vocab_proj = None # 埋め込み行列と重み共有するため後で設定
def forward(self, token_embeds, padding_mask=None):
"""
Args:
token_embeds: 共有埋め込みを通した後のテンソル [B, T, embed_dim]
padding_mask: パディングマスク [B, T](True=パディング)
Returns:
logits: [B, T, vocab_size]
"""
B, T, _ = token_embeds.shape
positions = torch.arange(T, device=token_embeds.device).unsqueeze(0)
x = self.embed_proj(token_embeds) + self.pos_embed(positions)
for layer in self.layers:
x = layer(x, mask=padding_mask)
x = self.norm(x)
x = self.mlm_head(x) # [B, T, embed_dim]
# 埋め込み行列との内積で語彙分布を得る
logits = F.linear(x, self.vocab_proj) # [B, T, vocab_size]
return logits
Generatorは embed_dim のトークン埋め込みを受け取り、より小さな hidden_dim に射影してからTransformerに通します。出力は再び embed_dim に戻し、共有埋め込み行列との内積で語彙分布を計算します。この構造により、Generatorの隠れ次元をDiscriminatorより小さくしつつ、埋め込み層は共有できます。
Discriminator(全トークンの二値分類)
class Discriminator(nn.Module):
"""ELECTRA Discriminator: 全トークンで本物/偽物を判別"""
def __init__(self, embed_dim, hidden_dim, num_layers,
num_heads, ff_dim, max_len=512, dropout=0.1):
super().__init__()
self.embed_proj = nn.Linear(embed_dim, hidden_dim)
self.pos_embed = nn.Embedding(max_len, hidden_dim)
self.layers = nn.ModuleList([
TransformerBlock(hidden_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(hidden_dim)
# 判別ヘッド: 各位置で本物/偽物の二値分類
self.disc_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 1),
)
def forward(self, token_embeds, padding_mask=None):
"""
Args:
token_embeds: 共有埋め込みを通した後のテンソル [B, T, embed_dim]
padding_mask: パディングマスク [B, T]
Returns:
logits: [B, T] 各位置の判別ロジット
"""
B, T, _ = token_embeds.shape
positions = torch.arange(T, device=token_embeds.device).unsqueeze(0)
x = self.embed_proj(token_embeds) + self.pos_embed(positions)
for layer in self.layers:
x = layer(x, mask=padding_mask)
x = self.norm(x)
logits = self.disc_head(x).squeeze(-1) # [B, T]
return logits
DiscriminatorもGeneratorと同じTransformerBlock構造を使いますが、出力ヘッドが大きく異なります。Generatorは語彙全体に対する分類を行うのに対し、Discriminatorは各位置で1つのスカラー値を出力し、sigmoid関数を通して「本物である確率」を算出します。
ELECTRAモデルの統合と学習ループ
GeneratorとDiscriminatorを統合し、学習ループを構築します。
class ELECTRA(nn.Module):
"""Generator + Discriminator + 共有埋め込みの統合モデル"""
def __init__(self, vocab_size, embed_dim=128,
gen_hidden=64, gen_layers=2, gen_heads=2, gen_ff=256,
disc_hidden=256, disc_layers=6, disc_heads=4, disc_ff=1024,
max_len=128, mask_prob=0.15, mask_token_id=4,
lambda_disc=50.0):
super().__init__()
self.mask_prob = mask_prob
self.mask_token_id = mask_token_id
self.lambda_disc = lambda_disc
# 共有トークン埋め込み
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
# Generator(小さい)
self.generator = Generator(
vocab_size, embed_dim, gen_hidden, gen_layers,
gen_heads, gen_ff, max_len
)
# 埋め込み行列の重み共有
self.generator.vocab_proj = self.token_embedding.weight
# Discriminator(大きい)
self.discriminator = Discriminator(
embed_dim, disc_hidden, disc_layers,
disc_heads, disc_ff, max_len
)
def _mask_tokens(self, input_ids, padding_mask):
"""マスク位置を選択し、[MASK]トークンに置換"""
# パディングでない位置のみマスク対象
can_mask = ~padding_mask # [B, T]
# ランダムにmask_prob分を選択
rand = torch.rand_like(input_ids, dtype=torch.float)
rand[~can_mask] = 1.0 # パディング位置は選ばれないようにする
mask_positions = rand < self.mask_prob # [B, T]
# マスク済み入力を生成
masked_ids = input_ids.clone()
masked_ids[mask_positions] = self.mask_token_id
return masked_ids, mask_positions
def forward(self, input_ids, padding_mask=None):
"""
Args:
input_ids: [B, T] 入力トークンID
padding_mask: [B, T] パディングマスク(True=パディング)
Returns:
gen_loss, disc_loss, total_loss
"""
if padding_mask is None:
padding_mask = torch.zeros_like(input_ids, dtype=torch.bool)
B, T = input_ids.shape
# ステップ1: マスク位置の選択
masked_ids, mask_positions = self._mask_tokens(input_ids, padding_mask)
# ステップ2: Generatorでマスク位置のトークンを予測
masked_embeds = self.token_embedding(masked_ids)
gen_logits = self.generator(masked_embeds, padding_mask) # [B, T, V]
# Generator損失(マスク位置のみ)
gen_loss = F.cross_entropy(
gen_logits[mask_positions], # マスク位置のlogits
input_ids[mask_positions], # 正解トークン
reduction='mean'
)
# ステップ3: Generatorの出力からトークンをサンプリング
with torch.no_grad():
gen_probs = F.softmax(gen_logits, dim=-1) # [B, T, V]
# マスク位置でサンプリング
sampled_ids = input_ids.clone()
if mask_positions.any():
mask_probs = gen_probs[mask_positions] # [num_masked, V]
sampled_tokens = torch.multinomial(mask_probs, 1).squeeze(-1)
sampled_ids[mask_positions] = sampled_tokens
# ステップ4: Discriminatorで全位置の真贋を判定
corrupted_embeds = self.token_embedding(sampled_ids)
disc_logits = self.discriminator(corrupted_embeds, padding_mask)
# Discriminator損失(全位置)
# ラベル: トークンがオリジナルと一致→1(本物)、不一致→0(偽物)
is_replaced = (sampled_ids != input_ids).float() # 1=偽物, 0=本物
active = ~padding_mask # パディングでない位置のみ計算
disc_loss = F.binary_cross_entropy_with_logits(
disc_logits[active],
1.0 - is_replaced[active], # 本物=1, 偽物=0
reduction='mean'
)
# ステップ5: 結合損失
total_loss = gen_loss + self.lambda_disc * disc_loss
return gen_loss, disc_loss, total_loss
このコードのポイントをいくつか解説します。
_mask_tokens メソッドでは、パディング位置をマスク対象から除外しています。乱数を生成し、パディング位置の乱数を1.0に設定することで、rand < self.mask_prob の条件を満たさないようにしています。
forward メソッドのステップ3で torch.no_grad() を使っているのは、サンプリング操作がDiscriminatorへの入力を生成するだけで、Generatorの勾配計算には関与しないためです。これにより、Discriminatorの損失がGeneratorに逆伝播しないことを保証しています。
is_replaced の計算では、サンプリングされたトークンとオリジナルが異なる位置を「偽物」としてラベル付けしています。マスク位置であっても、Generatorがたまたま正解を生成した場合は「本物」扱いになります。
学習ループの実行
実際に小規模なデータでELECTRAの学習を動かしてみましょう。
import numpy as np
# 設定
vocab_size = 1000
max_len = 32
batch_size = 16
num_epochs = 50
# 特殊トークン
PAD_ID = 0
MASK_ID = 4
# 合成データの生成(簡易的なパターンのあるデータ)
def generate_synthetic_data(num_samples, max_len, vocab_size):
"""パターンのある合成データを生成"""
data = []
for _ in range(num_samples):
seq_len = np.random.randint(max_len // 2, max_len)
# 基本パターン: 連続するトークンのペア
seq = []
for i in range(seq_len):
base = np.random.randint(5, vocab_size // 2)
seq.append(base + (i % 3))
# パディング
seq = seq[:max_len]
seq += [PAD_ID] * (max_len - len(seq))
data.append(seq)
return torch.tensor(data, dtype=torch.long)
# データ生成
train_data = generate_synthetic_data(500, max_len, vocab_size)
print(f"学習データ: {train_data.shape}")
# モデル初期化
model = ELECTRA(
vocab_size=vocab_size,
embed_dim=64,
gen_hidden=32, gen_layers=2, gen_heads=2, gen_ff=128,
disc_hidden=128, disc_layers=4, disc_heads=4, disc_ff=512,
max_len=max_len,
mask_token_id=MASK_ID,
lambda_disc=50.0,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
# 学習ループ
gen_losses, disc_losses, total_losses = [], [], []
for epoch in range(num_epochs):
model.train()
epoch_gen, epoch_disc, epoch_total = 0.0, 0.0, 0.0
num_batches = 0
indices = torch.randperm(len(train_data))
for start in range(0, len(train_data), batch_size):
batch_idx = indices[start:start + batch_size]
input_ids = train_data[batch_idx]
padding_mask = (input_ids == PAD_ID)
gen_loss, disc_loss, total_loss = model(input_ids, padding_mask)
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
epoch_gen += gen_loss.item()
epoch_disc += disc_loss.item()
epoch_total += total_loss.item()
num_batches += 1
gen_losses.append(epoch_gen / num_batches)
disc_losses.append(epoch_disc / num_batches)
total_losses.append(epoch_total / num_batches)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1:3d} | "
f"Gen Loss: {gen_losses[-1]:.4f} | "
f"Disc Loss: {disc_losses[-1]:.4f} | "
f"Total Loss: {total_losses[-1]:.4f}")
学習が進むにつれて、Generator損失は低下していき(マスク位置の予測が改善される)、Discriminator損失も低下していきます(真贋判定の精度が向上する)。Generator損失の低下は、Generatorがより「もっともらしい」トークンを生成することを意味するため、Discriminatorにとってタスクが難しくなります。しかしDiscriminator自身も学習が進むため、損失は全体として減少していくのが理想的な挙動です。
学習曲線の可視化
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(gen_losses, color='#00bcd4', linewidth=2)
axes[0].set_title('Generator Loss (MLM)', fontsize=12)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].grid(True, alpha=0.3)
axes[1].plot(disc_losses, color='#ff9800', linewidth=2)
axes[1].set_title('Discriminator Loss (BCE)', fontsize=12)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].grid(True, alpha=0.3)
axes[2].plot(total_losses, color='#4caf50', linewidth=2)
axes[2].set_title('Total Loss (Gen + λ·Disc)', fontsize=12)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Loss')
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('electra_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
上のグラフから、ELECTRAの学習ダイナミクスの特徴が読み取れます。Generator損失(左)はMLM損失であり、Generatorがマスク位置のトークンをより正確に予測できるようになるにつれて単調に減少します。Discriminator損失(中央)はバイナリクロスエントロピーであり、真贋判定の精度が向上するにつれて低下します。総合損失(右)は $\lambda = 50$ のスケーリングにより、Discriminator損失が支配的であることが確認できます。Discriminatorの学習が全体の損失低下を牽引しているのがわかります。
Discriminatorの判定精度の確認
学習後のDiscriminatorが実際にどの程度正確に真贋を判定できるかを確認しましょう。
model.eval()
with torch.no_grad():
# テストデータで評価
test_data = generate_synthetic_data(100, max_len, vocab_size)
test_padding = (test_data == PAD_ID)
# マスクとGeneratorによる置換を実行
masked_ids, mask_pos = model._mask_tokens(test_data, test_padding)
masked_embeds = model.token_embedding(masked_ids)
gen_logits = model.generator(masked_embeds, test_padding)
gen_probs = F.softmax(gen_logits, dim=-1)
# マスク位置でサンプリング
sampled_ids = test_data.clone()
if mask_pos.any():
sampled_tokens = torch.multinomial(gen_probs[mask_pos], 1).squeeze(-1)
sampled_ids[mask_pos] = sampled_tokens
# Discriminatorで判定
corrupted_embeds = model.token_embedding(sampled_ids)
disc_logits = model.discriminator(corrupted_embeds, test_padding)
disc_preds = (torch.sigmoid(disc_logits) > 0.5).long()
# 正解ラベル
is_original = (sampled_ids == test_data).long()
active = ~test_padding
# 精度計算
correct = (disc_preds[active] == is_original[active]).float()
accuracy = correct.mean().item()
# 置換された位置のみの精度
is_replaced_mask = (sampled_ids != test_data) & active
if is_replaced_mask.any():
replaced_acc = (disc_preds[is_replaced_mask] == 0).float().mean().item()
else:
replaced_acc = 0.0
# 本物の位置の精度
is_real_mask = (sampled_ids == test_data) & active
real_acc = (disc_preds[is_real_mask] == 1).float().mean().item()
print(f"Discriminator 全体精度: {accuracy:.4f}")
print(f" 偽物の検出率 (replaced → 0): {replaced_acc:.4f}")
print(f" 本物の正答率 (original → 1): {real_acc:.4f}")
print(f" 置換率: {is_replaced_mask.sum().item() / active.sum().item():.4f}")
この評価から、Discriminatorの判定能力の内訳がわかります。全体精度に加えて、偽物の検出率(置換されたトークンを正しく偽物と判定する能力)と、本物の正答率(オリジナルのトークンを正しく本物と判定する能力)を分けて見ることが重要です。置換率は、マスク率15%に対してGeneratorが実際にオリジナルと異なるトークンを生成した割合を示します。Generatorの精度が上がると置換率は下がる(正解を生成してしまう確率が上がるため)ことも確認できます。この「Generatorが上手くなるとDiscriminatorの学習データが変化する」ダイナミクスが、ELECTRAの学習をGANと類似させつつも安定させている要因の一つです。
性能比較 — ELECTRAの学習効率とスケーリング
ELECTRAの真価は、同じ計算量でBERTをどれだけ上回るかにあります。ここでは、Clark et al.(2020)の論文で報告された主要な実験結果を整理し、ELECTRAの学習効率の優位性を分析します。
ELECTRA-Small vs BERT-Small: 小規模モデルでの比較
まず、計算資源が限られた状況での比較を見てみましょう。ELECTRA-SmallとBERT-Smallは、同じ計算量(1つのGPUで4日間相当)で事前学習されたモデルです。
| モデル | パラメータ数 | GLUE平均 | SQuAD 2.0 (EM) |
|---|---|---|---|
| BERT-Small | 14M | 75.1 | — |
| ELECTRA-Small | 14M | 79.9 | 65.1 |
| 改善幅 | 同じ | +4.8 | — |
同じパラメータ数、同じ計算量で、GLUE平均スコアが4.8ポイントも向上しています。これは、MLMの15%学習効率問題がいかに深刻だったかを示しています。
ELECTRA-Base vs BERT-Base: 標準モデルでの比較
次に、より大きなモデルでの比較です。
| モデル | パラメータ数 | 計算量(FLOPs) | GLUE平均 |
|---|---|---|---|
| BERT-Base | 110M | $6.4 \times 10^{18}$ | 82.2 |
| ELECTRA-Base | 110M | $6.4 \times 10^{18}$ | 85.1 |
| BERT-Large | 340M | $5.1 \times 10^{19}$ | 84.4 |
| ELECTRA-Base | 110M | $6.4 \times 10^{18}$ | 85.1 |
注目すべきは、ELECTRA-Base(110M パラメータ)が BERT-Large(340M パラメータ)を上回っている点です。BERT-Largeの計算コストは ELECTRA-Baseの約8倍ですが、ELECTRA-Baseの方が高いGLUEスコアを達成しています。
計算量と性能の関係
ELECTRAの学習効率を視覚化するために、計算量(FLOPs)を横軸、GLUE平均スコアを縦軸にプロットしてみましょう。
import matplotlib.pyplot as plt
import numpy as np
# 論文の実験データ(概算値)
models = {
'BERT-Small': {'flops': 1.4e17, 'glue': 75.1, 'color': '#2196f3', 'marker': 's'},
'ELECTRA-Small': {'flops': 1.4e17, 'glue': 79.9, 'color': '#ff5722', 'marker': 'o'},
'BERT-Base': {'flops': 6.4e18, 'glue': 82.2, 'color': '#2196f3', 'marker': 's'},
'ELECTRA-Base': {'flops': 6.4e18, 'glue': 85.1, 'color': '#ff5722', 'marker': 'o'},
'BERT-Large': {'flops': 5.1e19, 'glue': 84.4, 'color': '#2196f3', 'marker': 's'},
'ELECTRA-Large': {'flops': 5.1e19, 'glue': 88.5, 'color': '#ff5722', 'marker': 'o'},
'RoBERTa': {'flops': 3.2e20, 'glue': 88.5, 'color': '#9c27b0', 'marker': '^'},
'XLNet': {'flops': 3.2e20, 'glue': 88.4, 'color': '#4caf50', 'marker': 'D'},
}
fig, ax = plt.subplots(figsize=(10, 6))
for name, info in models.items():
ax.scatter(info['flops'], info['glue'],
c=info['color'], marker=info['marker'],
s=120, zorder=5, edgecolors='white', linewidths=0.5)
offset_x = 1.2
offset_y = 0.3 if 'ELECTRA' in name else -0.5
ax.annotate(name, (info['flops'], info['glue']),
xytext=(info['flops'] * offset_x, info['glue'] + offset_y),
fontsize=9, ha='left')
# BERTとELECTRAをそれぞれ線で結ぶ
bert_flops = [1.4e17, 6.4e18, 5.1e19]
bert_glue = [75.1, 82.2, 84.4]
electra_flops = [1.4e17, 6.4e18, 5.1e19]
electra_glue = [79.9, 85.1, 88.5]
ax.plot(bert_flops, bert_glue, '--', color='#2196f3', alpha=0.5, label='BERT')
ax.plot(electra_flops, electra_glue, '-', color='#ff5722', alpha=0.5, label='ELECTRA')
ax.set_xscale('log')
ax.set_xlabel('Pre-training FLOPs (log scale)', fontsize=12)
ax.set_ylabel('GLUE Average Score', fontsize=12)
ax.set_title('Compute Efficiency: ELECTRA vs BERT', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_ylim(73, 90)
plt.tight_layout()
plt.savefig('electra_compute_efficiency.png', dpi=150, bbox_inches='tight')
plt.show()
上のグラフから、ELECTRAの計算効率の優位性が明確に読み取れます。同じFLOPs(横軸の同じ位置)で比較すると、ELECTRAのGLUEスコア(縦軸)は常にBERTを上回っています。特に小規模モデル(左端)での差が顕著であり、計算リソースが限られた環境でこそELECTRAの恩恵が大きいことがわかります。さらに注目すべきは、ELECTRA-BaseがBERT-Largeを上回る性能を示す点です。これは「同じ計算予算があるなら、大きなBERTを学習するよりも、適切なサイズのELECTRAを学習する方が高性能」ということを意味します。
なぜELECTRAの効率がこれほど高いのか — 定量的分析
ELECTRAの効率が高い理由を、学習信号の観点から定量的に分析してみましょう。
import matplotlib.pyplot as plt
import numpy as np
# 1エポックあたりの学習信号量を比較
seq_len = 512
mask_ratio = 0.15
num_samples = 10000
# MLM: マスク位置のみ(語彙分類, V=30522)
mlm_signals_per_sample = int(seq_len * mask_ratio)
mlm_bits_per_signal = np.log2(30522) # 約14.9ビット
mlm_total_bits = mlm_signals_per_sample * mlm_bits_per_signal
# RTD: 全位置(二値分類)
rtd_signals_per_sample = seq_len
rtd_bits_per_signal = 1.0 # 1ビット(本物/偽物)
rtd_total_bits = rtd_signals_per_sample * rtd_bits_per_signal
print("=== 1サンプルあたりの学習信号分析 ===")
print(f"MLM: {mlm_signals_per_sample}位置 × {mlm_bits_per_signal:.1f}ビット"
f" = {mlm_total_bits:.0f}ビット")
print(f"RTD: {rtd_signals_per_sample}位置 × {rtd_bits_per_signal:.1f}ビット"
f" = {rtd_total_bits:.0f}ビット")
print(f"位置数の比: RTD/MLM = {rtd_signals_per_sample/mlm_signals_per_sample:.1f}x")
print(f"総ビットの比: MLM/RTD = {mlm_total_bits/rtd_total_bits:.1f}x")
print()
print("→ MLMは1位置あたりの情報量は大きいが、位置数が少ない")
print("→ RTDは1位置あたりの情報量は小さいが、全位置をカバー")
print("→ 実験的にはRTDの方が効率的(位置数の多さが勝る)")
# 可視化: 入力トークンのうち学習に使われる割合
fig, axes = plt.subplots(1, 2, figsize=(12, 3))
# MLM
tokens_mlm = np.zeros(20)
mask_indices = np.random.choice(20, size=3, replace=False)
tokens_mlm[mask_indices] = 1
colors_mlm = ['#ff5722' if t == 1 else '#e0e0e0' for t in tokens_mlm]
axes[0].bar(range(20), [1]*20, color=colors_mlm, edgecolor='white')
axes[0].set_title(f'MLM: {int(mask_ratio*100)}% of tokens used (red)', fontsize=11)
axes[0].set_xlabel('Token position')
axes[0].set_yticks([])
# RTD
colors_rtd = ['#ff5722'] * 20
axes[1].bar(range(20), [1]*20, color=colors_rtd, edgecolor='white')
axes[1].set_title('RTD: 100% of tokens used (red)', fontsize=11)
axes[1].set_xlabel('Token position')
axes[1].set_yticks([])
plt.tight_layout()
plt.savefig('mlm_vs_rtd_signals.png', dpi=150, bbox_inches='tight')
plt.show()
上の分析と可視化から、MLMとRTDの学習信号の性質の違いが明確になります。MLMは1位置あたりの情報量が大きい(語彙全体から正解を選ぶため約14.9ビット)のに対し、RTDは1位置あたり1ビット(本物/偽物)しかありません。単純にビット数を合計するとMLMの方が多くなります。しかし実際には、RTDの「全位置からの学習」が実効的な学習効率ではMLMを上回ります。これは、各位置での判定がトークンの局所的な特徴だけでなく文全体の文脈理解を必要とするため、1ビットの判定に含まれる「暗黙の情報量」が非常に大きいためと解釈できます。
Generatorサイズの影響
最後に、Generatorのサイズがモデル全体の性能にどう影響するかを確認しましょう。
import matplotlib.pyplot as plt
# Clark et al. の消去実験結果(概算値)
gen_size_ratio = [0.125, 0.25, 0.33, 0.5, 1.0]
glue_scores = [83.2, 84.3, 84.9, 84.7, 83.5]
labels = ['1/8', '1/4', '1/3', '1/2', '1/1\n(same size)']
fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(range(len(gen_size_ratio)), glue_scores,
color=['#90a4ae', '#00bcd4', '#ff5722', '#00bcd4', '#90a4ae'],
edgecolor='white', width=0.6)
ax.set_xticks(range(len(gen_size_ratio)))
ax.set_xticklabels(labels, fontsize=11)
ax.set_xlabel('Generator size / Discriminator size', fontsize=12)
ax.set_ylabel('GLUE Average Score', fontsize=12)
ax.set_title('Effect of Generator Size on ELECTRA Performance', fontsize=13)
ax.set_ylim(82.5, 85.5)
ax.grid(True, axis='y', alpha=0.3)
# 最適値をハイライト
ax.annotate('Optimal', xy=(2, 84.9), xytext=(2, 85.3),
fontsize=11, ha='center', color='#ff5722',
arrowprops=dict(arrowstyle='->', color='#ff5722'))
for i, (score, bar) in enumerate(zip(glue_scores, bars)):
ax.text(i, score + 0.05, f'{score}', ha='center', va='bottom', fontsize=10)
plt.tight_layout()
plt.savefig('generator_size_effect.png', dpi=150, bbox_inches='tight')
plt.show()
上のグラフから、Generatorのサイズには明確な最適点があることがわかります。Generatorが小さすぎる(1/8)と、生成されるトークンが明らかに不自然になり、Discriminatorの学習タスクが簡単すぎます。逆にGeneratorが大きすぎる(1/1、Discriminatorと同サイズ)と、ほぼ正解のトークンを生成してしまい、やはりDiscriminatorの学習が難しくなります。さらに、大きなGeneratorは計算コストも増大させ、Discriminatorに回せる計算予算を圧迫します。1/3〜1/4のサイズがちょうど良いバランスを提供しており、「適度に間違える出題者」がDiscriminatorの学習を最も効率化するという直感と一致します。
ELECTRAの設計判断のまとめと限界
ここまでの議論を踏まえて、ELECTRAの主要な設計判断とその根拠を整理します。
設計判断の一覧
| 設計判断 | 選択肢 | 採用された選択 | 根拠 |
|---|---|---|---|
| 事前学習タスク | MLM / RTD | RTD | 全トークンからの学習信号 |
| Generatorの学習 | 敵対的 / 最尤推定 | 最尤推定 | 離散トークンのため勾配が流れない、安定性 |
| Generatorのサイズ | 同サイズ / 小さく | 1/4〜1/3 | 適度な難易度、計算コスト |
| 重み共有 | なし / 埋め込みのみ / 全て | 埋め込みのみ | 性能改善と計算コストのバランス |
| $\lambda$(Disc重み) | 1 / 50 | 50 | 損失スケールの均衡化 |
| ファインチューニング対象 | Generator / Discriminator | Discriminator | 全トークンから学習した豊かな表現 |
ELECTRAの限界
ELECTRAにも限界はあります。
1. 下流タスクでの微調整の不整合: 事前学習ではDiscriminatorは二値分類(本物/偽物)を行いますが、下流タスクではトークン分類や系列分類など多様なタスクに使われます。RTDヘッドと下流タスクのヘッドの間には若干のギャップがあります。MLMの場合、穴埋め予測と言語理解タスクの間の類似性が高い側面もあります。
2. Generatorの学習コスト: ELECTRAは実質的に2つのモデル(Generator + Discriminator)を同時に学習します。Generatorは小さいとはいえ、BERTの単体学習と比べると追加の計算コストが発生します。ただし、この追加コストを大きく上回る学習効率の改善が得られるため、トータルでは有利です。
3. ハイパーパラメータの増加: Generatorのサイズ比率、$\lambda$ の値、埋め込み共有の方法など、BERTにはないハイパーパラメータが追加されます。これらの最適値はモデルのスケールやデータセットによって変わる可能性があり、チューニングの手間が増えます。
まとめ
本記事では、ELECTRAのReplaced Token Detectionによる効率的な事前学習について解説しました。
- BERTのMLMの非効率性: 入力の15%しか学習信号が得られないこと、
[MASK]トークンの不整合問題という2つの構造的な弱点を確認しました - ELECTRAのアーキテクチャ: 小さなGenerator(MLMモデル)がマスク位置にトークンを生成し、大きなDiscriminatorが全トークンで真贋を判定する構成により、入力の100%から学習信号を得られることを示しました
- 数学的定式化: Generator損失(MLM交差エントロピー)、Discriminator損失(バイナリクロスエントロピー)、および $\lambda = 50$ で重み付けした結合損失を定式化しました
- GANとの本質的な違い: テキストの離散性によりDiscriminatorの勾配がGeneratorに流れないため、GeneratorはMLMの最尤推定で独立に学習すること、これがGANと異なり学習の安定性につながることを解説しました
- Weight Sharing: トークン埋め込み層の共有が、計算コストを増やさずに性能を改善する効果的な設計判断であることを確認しました
- PyTorch実装: Generator、Discriminator、学習ループをスクラッチで実装し、学習曲線とDiscriminatorの判定精度を可視化しました
- 性能比較: ELECTRA-SmallはBERT-Smallを4.8ポイント上回り、ELECTRA-BaseはBERT-Largeを超える性能を約1/8の計算量で達成することを確認しました
ELECTRAの「全トークンから学習する」というアイデアは、事前学習の効率化という文脈で非常に重要な設計原理を提示しました。特に計算リソースが限られた環境では、BERTではなくELECTRAを選択することで、大幅な性能向上が期待できます。
ELECTRAの理解を深めるために、以下の記事も参考にしてください。