RoPE(回転位置埋め込み)の数学的導出と実装

RoPE(Rotary Position Embedding、回転位置埋め込み)は、Transformerにおける位置情報の埋め込み手法です。Su et al. (2021) の論文 “RoFormer” で提案され、現在ではLlama、Mistral、Qwenなど多くの最新LLMで採用されています。

RoPEの特徴は、回転行列を用いて位置情報を埋め込むことで、相対位置を自然に表現できる点です。本記事では、RoPEの数学的な背景から実装まで、段階的に解説します。

本記事の内容

  • 位置エンコーディングの課題とRoPEの解決策
  • 回転行列による位置埋め込みの数学
  • 相対位置の表現
  • 長系列への外挿(Length Extrapolation)
  • PyTorchでの実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

位置エンコーディングの背景

Self-Attentionの順序不変性

Self-Attentionは本質的に順序不変(permutation invariant)です。入力トークンの順序を入れ替えても、適切に入れ替えれば出力も同じように入れ替わるだけです。

$$ \text{Attention}(\bm{Q}\bm{P}, \bm{K}\bm{P}, \bm{V}\bm{P}) = \text{Attention}(\bm{Q}, \bm{K}, \bm{V})\bm{P} $$

ここで $\bm{P}$ は置換行列です。

しかし、自然言語では語順が意味を決定します。「犬が猫を追う」と「猫が犬を追う」は全く異なる意味です。そのため、何らかの方法で位置情報を注入する必要があります。

従来の位置エンコーディング

絶対位置エンコーディング(Sinusoidal / 学習可能):

原論文のTransformerでは、sin/cos関数による絶対位置エンコーディングを入力埋め込みに加算します。

$$ \bm{x}’_i = \bm{x}_i + \bm{PE}(i) $$

問題点: – 学習時より長い系列への汎化が困難 – 絶対位置は固定で、文脈に応じた相対関係を直接表現しない

相対位置エンコーディング:

Shaw et al. (2018) は、アテンションスコアに相対位置のバイアスを加える手法を提案しました。

$$ \text{score}_{ij} = \bm{q}_i^\top \bm{k}_j + \bm{q}_i^\top \bm{r}_{j-i} $$

問題点: – 相対位置の埋め込み $\bm{r}_{j-i}$ を別途学習する必要がある – 実装が複雑

RoPEの基本アイデア

目標

RoPEは以下の性質を満たす位置埋め込みを設計します:

  1. アテンションスコアが相対位置のみに依存: $\bm{q}_i^\top \bm{k}_j = f(\bm{x}_i, \bm{x}_j, j – i)$
  2. 追加のパラメータ不要: 位置情報は純粋に幾何学的な変換で表現
  3. 実装の効率性: 行列演算で効率的に計算可能

回転による位置埋め込み

RoPEの核心的アイデアは、ベクトルを位置に応じた角度で回転させることです。

2次元平面での回転を考えます。角度 $\theta$ の回転行列は:

$$ \bm{R}(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix} $$

この回転行列には重要な性質があります:

$$ \bm{R}(\alpha)^\top \bm{R}(\beta) = \bm{R}(\beta – \alpha) $$

つまり、2つの回転されたベクトルの内積は、それらの角度の差(相対位置)にのみ依存します。

RoPEの数学的定義

2次元の場合

位置 $m$ のベクトル $\bm{x} = (x_1, x_2)$ に対して、RoPEは以下の回転を適用します:

$$ \bm{R}_m \bm{x} = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix} $$

ここで $\theta$ は基本角度(周波数)です。

位置 $m$ の Query $\bm{q}_m$ と位置 $n$ の Key $\bm{k}_n$ の内積は:

$$ \begin{align} (\bm{R}_m \bm{q})^\top (\bm{R}_n \bm{k}) &= \bm{q}^\top \bm{R}_m^\top \bm{R}_n \bm{k} \\ &= \bm{q}^\top \bm{R}_{n-m} \bm{k} \end{align} $$

これは位置 $m, n$ ではなく、相対位置 $n – m$ のみに依存します。

高次元への拡張

実際のTransformerでは、埋め込み次元 $d$ は64や128など大きな値です。RoPEは次元を2つずつのペアに分割し、各ペアに異なる周波数の回転を適用します。

$d$ 次元ベクトル $\bm{x} = (x_1, x_2, \ldots, x_d)$ に対して:

$$ \bm{R}_m^{(d)} = \begin{pmatrix} \bm{R}_m^{(1)} & & & \\ & \bm{R}_m^{(2)} & & \\ & & \ddots & \\ & & & \bm{R}_m^{(d/2)} \end{pmatrix} $$

各ブロック $\bm{R}_m^{(i)}$ は:

$$ \bm{R}_m^{(i)} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} $$

周波数の設計

原論文のsin/cos位置エンコーディングと同様に、各次元ペアに異なる周波数を割り当てます:

$$ \theta_i = 10000^{-2(i-1)/d}, \quad i = 1, 2, \ldots, d/2 $$

または、より一般的に基数 $\beta$(デフォルト: 10000)を用いて:

$$ \theta_i = \beta^{-2(i-1)/d} $$

低いインデックスの次元ペアは高周波数(細かい位置変化を捉える)、高いインデックスの次元ペアは低周波数(大きな位置変化を捉える)を持ちます。

効率的な計算形式

ブロック対角回転行列を明示的に作成する代わりに、要素ごとの演算で効率的に計算できます。

位置 $m$ のベクトル $\bm{x} = (x_1, x_2, \ldots, x_d)$ に対して:

$$ \text{RoPE}(\bm{x}, m) = \bm{x} \odot \cos(\bm{m}\bm{\theta}) + \text{rotate\_half}(\bm{x}) \odot \sin(\bm{m}\bm{\theta}) $$

ここで: – $\odot$ は要素ごとの積 – $\bm{\theta} = (\theta_1, \theta_1, \theta_2, \theta_2, \ldots, \theta_{d/2}, \theta_{d/2})$ – $\bm{m}\bm{\theta} = (m\theta_1, m\theta_1, m\theta_2, m\theta_2, \ldots)$ – $\text{rotate\_half}(\bm{x})$ は隣接する要素をペアで交換し、片方の符号を反転:

$(x_1, x_2, x_3, x_4, \ldots) \to (-x_2, x_1, -x_4, x_3, \ldots)$

相対位置の明示的な導出

RoPEがなぜ相対位置を表現できるのか、より詳しく見てみましょう。

位置 $m$ の Query: $\bm{q}_m = \bm{R}_m \bm{q}$ 位置 $n$ の Key: $\bm{k}_n = \bm{R}_n \bm{k}$

内積を計算:

$$ \begin{align} \bm{q}_m^\top \bm{k}_n &= (\bm{R}_m \bm{q})^\top (\bm{R}_n \bm{k}) \\ &= \bm{q}^\top \bm{R}_m^\top \bm{R}_n \bm{k} \end{align} $$

回転行列の性質 $\bm{R}(\alpha)^\top = \bm{R}(-\alpha)$ より:

$$ \bm{R}_m^\top \bm{R}_n = \bm{R}_{-m} \bm{R}_n = \bm{R}_{n-m} $$

したがって:

$$ \bm{q}_m^\top \bm{k}_n = \bm{q}^\top \bm{R}_{n-m} \bm{k} $$

この結果は、絶対位置 $m, n$ ではなく、相対位置 $n – m$ のみに依存します。

長系列への外挿

問題

学習時に最大長 $L$ の系列で訓練したモデルを、推論時に $L$ より長い系列で使用したい場合があります。従来の絶対位置エンコーディングでは、未見の位置に対する性能が著しく低下します。

RoPEの外挿性

RoPEは相対位置を直接エンコードするため、ある程度の外挿が可能です。ただし、訓練時に見なかった周波数成分が問題になることがあります。

位置補間(Position Interpolation)

Chen et al. (2023) は、位置インデックスをスケーリングする位置補間を提案しました。

目標長 $L’$ の系列を、訓練時の最大長 $L$ に収まるように位置をスケール:

$$ m’ = m \cdot \frac{L}{L’} $$

つまり、位置 $m$ を $m’$ として扱います。これにより、新しい位置でも訓練時に見た周波数範囲に収まります。

YaRN (Yet another RoPE extensioN)

Peng et al. (2023) は、周波数ごとに異なるスケーリングを適用するYaRNを提案しました。高周波成分は外挿しやすく、低周波成分は補間が必要という観察に基づいています。

PyTorchでの実装

基本的なRoPE実装

import torch
import torch.nn as nn
import math


class RotaryPositionEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)"""

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        """
        Args:
            dim: 埋め込み次元(偶数である必要がある)
            max_seq_len: 最大系列長
            base: 周波数の基数
        """
        super().__init__()
        assert dim % 2 == 0, "dim must be even"

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # 周波数を計算: theta_i = base^(-2i/dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # cos/sin キャッシュを事前計算
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """cos/sin値のキャッシュを構築"""
        positions = torch.arange(seq_len, dtype=torch.float32)
        # (seq_len, dim/2)
        freqs = torch.outer(positions, self.inv_freq)
        # (seq_len, dim) - 各周波数を2回繰り返す
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """
        ベクトルの半分を回転
        (x1, x2, x3, x4, ...) -> (-x2, x1, -x4, x3, ...)
        """
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat([-x2, x1], dim=-1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        QueryとKeyにRoPEを適用

        Args:
            q: Query (batch_size, n_heads, seq_len, head_dim)
            k: Key (batch_size, n_heads, seq_len, head_dim)
            position_ids: 位置インデックス (batch_size, seq_len)
                         Noneの場合は0, 1, 2, ...を使用

        Returns:
            q_embed: RoPE適用後のQuery
            k_embed: RoPE適用後のKey
        """
        seq_len = q.shape[2]

        # 位置に対応するcos/sinを取得
        if position_ids is None:
            cos = self.cos_cached[:seq_len]
            sin = self.sin_cached[:seq_len]
        else:
            cos = self.cos_cached[position_ids]
            sin = self.sin_cached[position_ids]

        # (seq_len, dim) -> (1, 1, seq_len, dim) にブロードキャスト
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        # RoPE適用
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed, k_embed


# 使用例
batch_size = 2
n_heads = 8
seq_len = 128
head_dim = 64

rope = RotaryPositionEmbedding(dim=head_dim)

# ランダムなQ, K
torch.manual_seed(42)
q = torch.randn(batch_size, n_heads, seq_len, head_dim)
k = torch.randn(batch_size, n_heads, seq_len, head_dim)

# RoPE適用
q_rope, k_rope = rope(q, k)

print(f"入力Qの形状: {q.shape}")
print(f"出力Qの形状: {q_rope.shape}")

アテンション計算への統合

class RoPEMultiHeadAttention(nn.Module):
    """RoPE付きMulti-Head Attention"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_seq_len: int = 8192,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.rope = RotaryPositionEmbedding(
            dim=self.head_dim,
            max_seq_len=max_seq_len,
            base=rope_base,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        position_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Args:
            x: 入力 (batch_size, seq_len, d_model)
            mask: アテンションマスク
            position_ids: 位置インデックス

        Returns:
            output: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape

        # Q, K, V を計算
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # マルチヘッドに分割
        Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # RoPE適用(QとKにのみ)
        Q, K = self.rope(Q, K, position_ids)

        # アテンション計算
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores + mask

        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # ヘッドを結合
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )

        output = self.W_o(context)
        return output


# 使用例
d_model = 256
n_heads = 8
batch_size = 2
seq_len = 64

model = RoPEMultiHeadAttention(d_model, n_heads)
x = torch.randn(batch_size, seq_len, d_model)

output = model(x)
print(f"入力形状: {x.shape}")
print(f"出力形状: {output.shape}")

可視化

回転パターンの可視化

import matplotlib.pyplot as plt
import numpy as np

# 異なる位置での2次元回転を可視化
positions = [0, 1, 2, 4, 8, 16]
theta = 0.1  # 基本角度

fig, ax = plt.subplots(figsize=(8, 8))

# 元のベクトル
x0 = np.array([1.0, 0.0])

for pos in positions:
    angle = pos * theta
    rotation = np.array([
        [np.cos(angle), -np.sin(angle)],
        [np.sin(angle), np.cos(angle)]
    ])
    x_rotated = rotation @ x0

    ax.arrow(0, 0, x_rotated[0], x_rotated[1],
             head_width=0.05, head_length=0.03, fc=f'C{positions.index(pos)}',
             ec=f'C{positions.index(pos)}', linewidth=2,
             label=f'position {pos}')

ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
ax.set_aspect('equal')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper left')
ax.set_title('RoPE: Vector Rotation by Position', fontsize=14)
ax.set_xlabel('Dimension 1')
ax.set_ylabel('Dimension 2')
plt.tight_layout()
plt.show()

周波数成分の可視化

import matplotlib.pyplot as plt
import numpy as np

# 異なる次元での周波数を可視化
dim = 64
base = 10000
positions = np.arange(0, 100)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 各次元ペアの周波数
dims_to_plot = [0, 15, 31, 31]  # dim/2のインデックス
labels = ['High frequency (dim 0-1)', 'Mid frequency (dim 30-31)',
          'Low frequency (dim 62-63)', 'cos vs sin']

for idx, (ax, d) in enumerate(zip(axes.flat[:3], dims_to_plot[:3])):
    theta = 1.0 / (base ** (2 * d / dim))
    angles = positions * theta

    ax.plot(positions, np.cos(angles), label='cos', linewidth=2)
    ax.plot(positions, np.sin(angles), label='sin', linewidth=2, linestyle='--')
    ax.set_xlabel('Position')
    ax.set_ylabel('Value')
    ax.set_title(f'{labels[idx]} (theta={theta:.6f})')
    ax.legend()
    ax.grid(True, alpha=0.3)

# 全次元の周波数を一覧
ax = axes[1, 1]
dim_indices = np.arange(0, dim // 2)
thetas = 1.0 / (base ** (2 * dim_indices / dim))
ax.semilogy(dim_indices * 2, thetas, 'o-', linewidth=2, markersize=4)
ax.set_xlabel('Dimension Index')
ax.set_ylabel('Frequency (theta)')
ax.set_title('Frequency vs Dimension')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

RoPEの利点と考慮点

利点

  1. 相対位置の自然な表現: 内積が相対位置のみに依存
  2. 追加パラメータ不要: 純粋に幾何学的な変換
  3. 効率的な計算: 要素ごとの演算で実装可能
  4. 長系列への外挿: 相対位置ベースのため、ある程度の外挿が可能

考慮点

  1. 複素数解釈: RoPEは実数の回転だが、複素数として解釈することもできる
  2. 次元の偶数制約: 2次元ペアで回転するため、次元数は偶数である必要がある
  3. 基数の選択: $\beta = 10000$ が一般的だが、タスクによって調整が必要な場合がある

まとめ

本記事では、RoPE(Rotary Position Embedding)の仕組みと実装について解説しました。

  • 基本アイデア: ベクトルを位置に応じた角度で回転させることで、位置情報を埋め込む
  • 相対位置の表現: 回転行列の性質により、内積が相対位置 $n – m$ のみに依存する
  • 効率的な計算: ブロック対角行列ではなく、要素ごとの cos/sin 演算で実装
  • 周波数設計: 低次元は高周波、高次元は低周波を持ち、異なるスケールの位置情報を捉える
  • 長系列への外挿: Position Interpolation や YaRN により、訓練長を超える系列に対応可能

RoPEは、現代のLLMにおいて標準的な位置埋め込み手法となっています。Llama、Mistral、Qwenなど多くのモデルで採用されており、その効果は実証済みです。

次のステップとして、以下の記事も参考にしてください。