BERTの仕組みと双方向Transformerの理論と実装

BERT(Bidirectional Encoder Representations from Transformers)は、2018年にGoogleが発表した事前学習済み言語モデルです。自然言語処理の様々なタスクにおいてSOTA(State-of-the-Art)を更新し、NLP分野に革命をもたらしました。

BERTの革新性は、双方向(Bidirectional)の文脈を活用した事前学習にあります。従来の言語モデル(GPTなど)は左から右への一方向のみでしたが、BERTは文の両方向から文脈を捉えることで、より深い言語理解を実現しました。

本記事の内容

  • BERTの全体アーキテクチャ(Encoder-only構造)
  • Masked Language Model(MLM)による事前学習
  • Next Sentence Prediction(NSP)タスク
  • 入力表現(トークン埋め込み + セグメント埋め込み + 位置埋め込み)
  • ファインチューニングの方法
  • PyTorchでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

BERTの登場背景

従来のアプローチの限界

BERTが登場する以前、自然言語処理では以下のアプローチが主流でした。

1. Word2Vec / GloVe(静的な単語埋め込み)

単語を固定次元のベクトルに変換する手法です。しかし、同じ単語は常に同じベクトルを持つため、文脈による意味の違い(例:「bank」が銀行なのか川岸なのか)を区別できません。

2. ELMo(文脈を考慮した埋め込み)

双方向LSTMを用いて文脈を考慮した単語表現を生成します。しかし、LSTMの逐次処理により並列化が困難であり、また左から右・右から左の情報を独立に処理した後に連結するため、真の双方向ではありませんでした。

3. GPT(Transformer + 単方向言語モデル)

TransformerのDecoder(因果マスク付き)を用いた言語モデルです。強力な生成能力を持ちますが、左から右への一方向のみの文脈しか利用できないため、文の理解タスクには制約がありました。

BERTの革新

BERTは以下の点で革新的でした。

  1. 真の双方向性: Masked Language Model(MLM)により、マスクされたトークンを予測する際に左右両方の文脈を同時に参照できる
  2. 事前学習とファインチューニング: 大規模コーパスで事前学習し、少量のラベル付きデータでファインチューニングすることで、様々なタスクに適用可能
  3. Transformerベース: Self-Attentionにより長距離依存を効率的に捉え、並列計算が可能

BERTのアーキテクチャ

Encoder-only構造

BERTはTransformerのEncoderのみを使用します。Decoder(因果マスク付きSelf-Attention + Cross-Attention)は含まれていません。

[CLS] トークン1 トークン2 ... トークンN [SEP]
  │       │        │          │      │
  ▼       ▼        ▼          ▼      ▼
┌─────────────────────────────────────────┐
│      Token Embedding (E_token)          │
│    + Segment Embedding (E_segment)      │
│    + Position Embedding (E_pos)         │
└─────────────────────────────────────────┘
  │
  ▼
┌─────────────────────────────────────────┐
│       Transformer Encoder Layer × L     │
│  ┌─────────────────────────────────┐    │
│  │  Multi-Head Self-Attention      │    │
│  │  Add & LayerNorm                │    │
│  │  Feed-Forward Network           │    │
│  │  Add & LayerNorm                │    │
│  └─────────────────────────────────┘    │
└─────────────────────────────────────────┘
  │
  ▼
[CLS]表現  トークン1表現  ...  [SEP]表現

モデルサイズ

BERTには2つのサイズがあります。

モデル 層数 $L$ 隠れ層次元 $H$ ヘッド数 $A$ パラメータ数
BERT-Base 12 768 12 110M
BERT-Large 24 1024 16 340M

各Encoder層のFeed-Forward Networkの中間次元は $4H$ です(Base: 3072, Large: 4096)。

特殊トークン

BERTは以下の特殊トークンを使用します。

  • [CLS]: 各入力の先頭に配置。このトークンに対応する最終層の出力は、文全体の表現として分類タスクに使用される
  • [SEP]: 文の区切りを示す。2文を入力する場合、文1と文2の間、および末尾に配置される
  • [MASK]: MLM事前学習でマスクされたトークンの位置に配置される
  • [PAD]: バッチ処理でシーケンス長を揃えるためのパディング

入力表現

BERTの入力は、3種類の埋め込みの和で構成されます。

$$ \bm{E}_{\text{input}} = \bm{E}_{\text{token}} + \bm{E}_{\text{segment}} + \bm{E}_{\text{position}} $$

1. トークン埋め込み(Token Embedding)

各トークンを $H$ 次元のベクトルに変換します。WordPieceトークナイザによりサブワード単位で分割されたトークンに対して埋め込みを行います。

$$ \bm{E}_{\text{token}} = \text{Embedding}(\text{token\_id}) \in \mathbb{R}^{H} $$

2. セグメント埋め込み(Segment Embedding)

2文を入力する場合に、各トークンがどちらの文に属するかを示します。

$$ \bm{E}_{\text{segment}} = \begin{cases} \bm{E}_A & \text{if token belongs to sentence A} \\ \bm{E}_B & \text{if token belongs to sentence B} \end{cases} $$

例えば、入力が [CLS] I love cats [SEP] Dogs are great [SEP] の場合、セグメントIDは [0, 0, 0, 0, 0, 1, 1, 1, 1] となります。

3. 位置埋め込み(Position Embedding)

原論文のTransformerではsin/cos関数による固定の位置エンコーディングを使用しましたが、BERTでは学習可能な位置埋め込みを使用します。

$$ \bm{E}_{\text{position}} = \bm{P}[pos] \in \mathbb{R}^{H} $$

ここで $\bm{P} \in \mathbb{R}^{L_{\max} \times H}$ は学習可能な行列($L_{\max}$ は最大系列長、通常512)です。

事前学習タスク

BERTは2つの事前学習タスクで訓練されます。

Masked Language Model(MLM)

MLMはBERTの中核となる事前学習タスクです。入力トークン列の約15%をランダムに選択し、以下のルールで置換します。

  • 80%: [MASK] トークンに置換
  • 10%: ランダムなトークンに置換
  • 10%: そのまま(変更なし)

モデルは、マスクされた位置の元のトークンを予測します。

$$ \mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P(x_i \mid \tilde{\bm{x}}) $$

ここで $\mathcal{M}$ はマスクされた位置の集合、$\tilde{\bm{x}}$ はマスク処理後の入力列です。

なぜ100%を[MASK]にしないのか?

ファインチューニング時には[MASK]トークンは出現しないため、事前学習とファインチューニングの間にミスマッチが生じます。10%をランダムトークンに、10%をそのままにすることで、モデルは[MASK]だけでなく全てのトークンの表現を学習する必要があり、より汎用的な表現を獲得します。

Next Sentence Prediction(NSP)

NSPは、2つの文AとBが与えられたとき、BがAの直後に続く文かどうかを予測する二値分類タスクです。

学習データの構成: – 50%: 実際に連続する2文(ラベル: IsNext) – 50%: ランダムに選ばれた2文(ラベル: NotNext)

$$ \mathcal{L}_{\text{NSP}} = -\log P(y \mid \bm{h}_{[\text{CLS}]}) $$

ここで $\bm{h}_{[\text{CLS}]}$ は[CLS]トークンに対応する最終層の出力、$y \in \{\text{IsNext}, \text{NotNext}\}$ です。

NSPは文間の関係を理解するタスク(質問応答、自然言語推論など)に役立つと考えられていました。ただし、後の研究(RoBERTa等)ではNSPの有効性に疑問が呈され、NSPを除去することで性能が向上することが報告されています。

全体の損失関数

事前学習の損失関数は、MLMとNSPの和です。

$$ \mathcal{L} = \mathcal{L}_{\text{MLM}} + \mathcal{L}_{\text{NSP}} $$

MLMの数式的詳細

MLMの出力層を数式で詳しく見ていきましょう。

マスクされた位置 $i$ について、Encoder最終層の出力を $\bm{h}_i \in \mathbb{R}^H$ とします。この出力を語彙全体の確率分布に変換します。

$$ \bm{z}_i = \bm{W}_{\text{embed}}^\top \text{GELU}(\bm{W}_h \bm{h}_i + \bm{b}_h) + \bm{b}_o $$

$$ P(x_i = w \mid \tilde{\bm{x}}) = \frac{\exp(z_{i,w})}{\sum_{w’=1}^{|V|} \exp(z_{i,w’})} = \text{softmax}(\bm{z}_i)_w $$

ここで: – $\bm{W}_h \in \mathbb{R}^{H \times H}$: 全結合層の重み – $\bm{b}_h \in \mathbb{R}^H$: バイアス – GELU: 活性化関数(Gaussian Error Linear Unit) – $\bm{W}_{\text{embed}} \in \mathbb{R}^{|V| \times H}$: トークン埋め込み行列(共有) – $\bm{b}_o \in \mathbb{R}^{|V|}$: 出力バイアス – $|V|$: 語彙サイズ

注目すべき点として、出力層の重み行列にトークン埋め込み行列 $\bm{W}_{\text{embed}}$ を転置して再利用しています(Weight Tying)。これによりパラメータ数を削減しつつ、入力と出力の表現空間を一貫させることができます。

ファインチューニング

事前学習済みBERTは、様々な下流タスクに対してファインチューニングできます。

文分類(Single Sentence Classification)

例: 感情分析、スパム検出

[CLS]トークンの出力を分類器に入力します。

$$ P(y \mid \text{sentence}) = \text{softmax}(\bm{W}_c \bm{h}_{[\text{CLS}]} + \bm{b}_c) $$

文ペア分類(Sentence Pair Classification)

例: 自然言語推論(NLI)、質問-回答ペア判定

2つの文を[SEP]で区切って入力し、[CLS]トークンの出力を分類に使用します。

入力: [CLS] 前提文 [SEP] 仮説文 [SEP]
出力: entailment / contradiction / neutral

系列ラベリング(Token Classification)

例: 固有表現抽出(NER)、品詞タグ付け

各トークンの出力を分類器に入力します。

$$ P(y_i \mid \text{token}_i) = \text{softmax}(\bm{W}_t \bm{h}_i + \bm{b}_t) $$

質問応答(Extractive Question Answering)

例: SQuAD

質問と文書を入力し、回答の開始位置と終了位置を予測します。

$$ P_{\text{start}}(i) = \frac{\exp(\bm{w}_s \cdot \bm{h}_i)}{\sum_j \exp(\bm{w}_s \cdot \bm{h}_j)} $$

$$ P_{\text{end}}(i) = \frac{\exp(\bm{w}_e \cdot \bm{h}_i)}{\sum_j \exp(\bm{w}_e \cdot \bm{h}_j)} $$

PyTorchでの実装

BERTの主要コンポーネントをPyTorchで実装してみましょう。

BERTの埋め込み層

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class BERTEmbedding(nn.Module):
    """BERTの入力埋め込み(トークン + セグメント + 位置)"""
    def __init__(self, vocab_size, hidden_size, max_len=512, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.segment_embedding = nn.Embedding(2, hidden_size)
        self.position_embedding = nn.Embedding(max_len, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, segment_ids):
        """
        Args:
            input_ids: (batch_size, seq_len) トークンID
            segment_ids: (batch_size, seq_len) セグメントID(0 or 1)
        Returns:
            embeddings: (batch_size, seq_len, hidden_size)
        """
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)

        token_emb = self.token_embedding(input_ids)
        segment_emb = self.segment_embedding(segment_ids)
        position_emb = self.position_embedding(position_ids)

        embeddings = token_emb + segment_emb + position_emb
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

GELU活性化関数

class GELU(nn.Module):
    """Gaussian Error Linear Unit"""
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        ))

GELUは以下の式で定義されます。

$$ \text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right] $$

上記の実装は近似式を使用しています。

Transformer Encoder層

class TransformerEncoderLayer(nn.Module):
    """BERTのEncoder層"""
    def __init__(self, hidden_size, num_heads, intermediate_size, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadSelfAttention(hidden_size, num_heads, dropout)
        self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            GELU(),
            nn.Linear(intermediate_size, hidden_size),
            nn.Dropout(dropout)
        )
        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        # Self-Attention + 残差接続 + LayerNorm
        attn_output = self.attention(x, attention_mask)
        x = self.attention_norm(x + self.dropout(attn_output))

        # FFN + 残差接続 + LayerNorm
        ffn_output = self.ffn(x)
        x = self.ffn_norm(x + ffn_output)
        return x


class MultiHeadSelfAttention(nn.Module):
    """マルチヘッドSelf-Attention"""
    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        batch_size, seq_len, _ = x.size()

        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        output = self.out(context)
        return output

BERTモデル本体

class BERT(nn.Module):
    """BERTモデル"""
    def __init__(self, vocab_size, hidden_size=768, num_layers=12,
                 num_heads=12, intermediate_size=3072, max_len=512, dropout=0.1):
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size, hidden_size, max_len, dropout)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])
        self.hidden_size = hidden_size

    def forward(self, input_ids, segment_ids, attention_mask=None):
        """
        Args:
            input_ids: (batch_size, seq_len)
            segment_ids: (batch_size, seq_len)
            attention_mask: (batch_size, seq_len) - 1: 有効, 0: パディング
        Returns:
            sequence_output: (batch_size, seq_len, hidden_size)
            pooled_output: (batch_size, hidden_size) - [CLS]の出力
        """
        # Attention maskを適切な形式に変換
        if attention_mask is not None:
            # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # 0の位置に大きな負の値を加えてsoftmax後に0にする
            extended_mask = (1.0 - extended_mask) * -10000.0
        else:
            extended_mask = None

        x = self.embedding(input_ids, segment_ids)

        for layer in self.encoder_layers:
            x = layer(x, extended_mask)

        sequence_output = x
        pooled_output = x[:, 0]  # [CLS]トークンの出力

        return sequence_output, pooled_output

MLMヘッド

class BERTMLMHead(nn.Module):
    """Masked Language Model用の出力ヘッド"""
    def __init__(self, hidden_size, vocab_size, embedding_weight):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.gelu = GELU()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        # Weight Tying: 埋め込み行列を共有
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.decoder.weight = embedding_weight
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: (batch_size, seq_len, hidden_size)
        Returns:
            logits: (batch_size, seq_len, vocab_size)
        """
        x = self.dense(hidden_states)
        x = self.gelu(x)
        x = self.layer_norm(x)
        logits = self.decoder(x) + self.bias
        return logits

動作確認

# ハイパーパラメータ
vocab_size = 30522
hidden_size = 768
num_layers = 12
num_heads = 12
intermediate_size = 3072

# モデル作成
bert = BERT(vocab_size, hidden_size, num_layers, num_heads, intermediate_size)

# ダミー入力
batch_size = 2
seq_len = 128
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
segment_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
attention_mask = torch.ones(batch_size, seq_len)

# 順伝播
sequence_output, pooled_output = bert(input_ids, segment_ids, attention_mask)
print(f"Sequence output shape: {sequence_output.shape}")  # (2, 128, 768)
print(f"Pooled output shape: {pooled_output.shape}")      # (2, 768)

# パラメータ数
total_params = sum(p.numel() for p in bert.parameters())
print(f"Total parameters: {total_params:,}")  # 約109M

BERTの応用例

文分類の例

class BERTForSequenceClassification(nn.Module):
    """文分類用BERT"""
    def __init__(self, bert, num_classes, dropout=0.1):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(bert.hidden_size, num_classes)

    def forward(self, input_ids, segment_ids, attention_mask=None):
        _, pooled_output = self.bert(input_ids, segment_ids, attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


# 使用例
num_classes = 3  # 例: positive, negative, neutral
classifier = BERTForSequenceClassification(bert, num_classes)
logits = classifier(input_ids, segment_ids, attention_mask)
print(f"Classification logits shape: {logits.shape}")  # (2, 3)

まとめ

本記事では、BERT(Bidirectional Encoder Representations from Transformers)について解説しました。

  • 双方向Transformer: BERTはTransformer Encoderのみを使用し、マスク言語モデル(MLM)により真の双方向文脈を学習する
  • 入力表現: トークン埋め込み + セグメント埋め込み + 位置埋め込みの3つの和で構成される
  • 事前学習タスク: MLM(15%のトークンをマスクして予測)とNSP(文の連続性を予測)
  • ファインチューニング: 事前学習済みモデルに小さな出力層を追加し、下流タスク向けに追加学習
  • 多様なタスクへの適用: 文分類、系列ラベリング、質問応答など、様々なNLPタスクに適用可能

BERTは、事前学習済み言語モデルの有効性を実証し、「事前学習 + ファインチューニング」パラダイムをNLP分野に定着させました。その後、RoBERTa、ALBERT、DistilBERTなど多くの改良モデルが提案されています。

次のステップとして、以下の記事も参考にしてください。