LLMの量子化(INT8/INT4)を理論から実装まで解説

量子化(Quantization)は、ニューラルネットワークの重みや活性化を低精度(INT8、INT4など)で表現することで、メモリ使用量と計算コストを削減する技術です。大規模言語モデル(LLM)の推論において、量子化は実用的なデプロイに欠かせない技術となっています。

本記事では、LLM向けの量子化技術について、数学的な背景から実装の考え方まで解説します。

本記事の内容

  • 量子化の基礎と数学
  • 対称量子化と非対称量子化
  • 重みのみ量子化(Weight-Only Quantization)
  • 量子化手法(GPTQ、AWQ、bitsandbytes)
  • 量子化による精度への影響
  • PyTorchでの実装の考え方

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

量子化の動機

メモリ削減

LLMのパラメータは通常、float32(4バイト)またはfloat16(2バイト)で保存されます。INT8(1バイト)やINT4(0.5バイト)に量子化することで、メモリ使用量を大幅に削減できます。

データ型 バイト数 70Bモデルのサイズ
FP32 4 280 GB
FP16/BF16 2 140 GB
INT8 1 70 GB
INT4 0.5 35 GB

INT4量子化により、70BモデルをH100(80GB)1台に収めることができます。

計算速度

低精度演算は高速です。特にINT8/INT4演算に最適化されたハードウェア(Tensor Coresなど)では、FP16と比較して2-4倍の高速化が可能です。

量子化の数学

基本概念

量子化は、連続的な実数値を離散的な整数値にマッピングする操作です。

浮動小数点数 $x$ を $b$ ビット整数 $x_q$ に量子化:

$$ x_q = \text{round}\left(\frac{x}{s}\right) + z $$

ここで: – $s$: スケール(scale)— 量子化の粒度を決める正の実数 – $z$: ゼロポイント(zero-point)— 量子化後の原点 – $\text{round}$: 最近接整数への丸め

逆量子化(dequantization): $$ \hat{x} = s \cdot (x_q – z) $$

量子化範囲

$b$ ビット符号付き整数の場合: $$ x_q \in [-2^{b-1}, 2^{b-1} – 1] $$

例: INT8 なら $[-128, 127]$、INT4 なら $[-8, 7]$

対称量子化と非対称量子化

対称量子化(Symmetric Quantization):

ゼロポイント $z = 0$ とし、正負対称に量子化。

$$ s = \frac{\max(|x|)}{2^{b-1} – 1} $$

$$ x_q = \text{round}\left(\frac{x}{s}\right) $$

$$ x_q = \text{clamp}(x_q, -2^{b-1}, 2^{b-1} – 1) $$

非対称量子化(Asymmetric Quantization):

データの範囲 $[x_{\min}, x_{\max}]$ を量子化範囲全体にマッピング。

$$ s = \frac{x_{\max} – x_{\min}}{2^b – 1} $$

$$ z = \text{round}\left(-\frac{x_{\min}}{s}\right) $$

量子化誤差

量子化誤差は、量子化と逆量子化の差:

$$ e = x – \hat{x} = x – s \cdot (x_q – z) $$

誤差の最大値はスケール $s$ の半分程度: $$ |e| \leq \frac{s}{2} $$

重みのみ量子化(Weight-Only Quantization)

概要

LLMの推論では、重みのみを量子化し、活性化はFP16で計算するアプローチが一般的です。

理由: 1. 重みは固定なので、事前に最適な量子化パラメータを計算できる 2. 活性化は入力依存で範囲が予測困難 3. 推論時のボトルネックはメモリ帯域(重みの読み込み)であることが多い

計算フロー

1. 量子化された重み W_q (INT8/INT4) をメモリから読み込み
2. 逆量子化して W_dequant (FP16) を得る
3. 活性化 X (FP16) と行列積: Y = X @ W_dequant

重みのメモリフットプリントは削減しつつ、計算はFP16で行うため精度劣化が少ない。

Per-tensor vs Per-channel vs Per-group

Per-tensor量子化

テンソル全体で1つのスケールとゼロポイントを使用。

$$ s = \frac{\max(|\bm{W}|)}{2^{b-1} – 1} $$

計算が単純だが、異なるチャネルで値の範囲が大きく異なる場合に精度が低下。

Per-channel量子化

出力チャネルごとに異なるスケールを使用。

$$ s_c = \frac{\max(|\bm{W}_{:, c}|)}{2^{b-1} – 1} $$

精度が向上するが、スケールの数が増える。

Per-group量子化

重みを固定サイズのグループに分割し、グループごとにスケールを計算。

グループサイズ $g$(例: 128)で分割:

$$ s_{g,i} = \frac{\max(|\bm{W}_{g \cdot i : g \cdot (i+1)}|)}{2^{b-1} – 1} $$

INT4量子化ではper-group(グループサイズ128)が一般的。

量子化手法

GPTQ(GPT Quantization)

Frantar et al. (2022) による手法。層ごとに最適な量子化パラメータを求める。

アイデア: 1. 層の出力誤差を最小化するように量子化 2. Hessian(二次導関数)を用いて重みの重要度を推定 3. 重要度の低い重みから順に量子化し、誤差を残りの重みに補償

量子化誤差 $\bm{E}$ を最小化:

$$ \min_{\bm{W}_q} ||\bm{X}\bm{W} – \bm{X}\bm{W}_q||_F^2 $$

ここで $\bm{X}$ はキャリブレーションデータの活性化。

AWQ(Activation-aware Weight Quantization)

Lin et al. (2023) による手法。活性化の大きさに基づいて重みの重要度を判断。

観察: – 活性化が大きいチャネルに対応する重みは重要 – これらの重みの量子化誤差が出力に大きく影響

対策: – 重要な重みのスケールを調整(スケールファクター $s$ を導入) – $\bm{W}’ = \bm{W} \cdot s$, $\bm{X}’ = \bm{X} / s$ として、等価な計算を維持しつつ量子化しやすくする

bitsandbytes

Hugging Faceエコシステムで広く使われるライブラリ。LLM.int8()とQLoRAを実装。

LLM.int8(): – 異常値(outlier)を検出し、FP16で計算 – 正常値のみINT8で計算 – 混合精度で精度を維持

QLoRA: – 量子化されたベースモデル + LoRAアダプタ – メモリ効率の良いファインチューニングを実現

PyTorchでの実装の考え方

シンプルな量子化の実装

import torch
import torch.nn as nn


def symmetric_quantize(x: torch.Tensor, bits: int = 8) -> tuple:
    """
    対称量子化

    Args:
        x: 入力テンソル
        bits: 量子化ビット数

    Returns:
        x_q: 量子化されたテンソル (int)
        scale: スケール
    """
    qmax = 2 ** (bits - 1) - 1
    qmin = -2 ** (bits - 1)

    # スケール計算
    scale = x.abs().max() / qmax

    # 量子化
    x_q = torch.round(x / scale).clamp(qmin, qmax).to(torch.int8)

    return x_q, scale


def symmetric_dequantize(x_q: torch.Tensor, scale: float) -> torch.Tensor:
    """
    逆量子化

    Args:
        x_q: 量子化されたテンソル
        scale: スケール

    Returns:
        x_dequant: 逆量子化されたテンソル
    """
    return x_q.float() * scale


# 使用例
torch.manual_seed(42)
W = torch.randn(256, 512)  # 元の重み

# 量子化
W_q, scale = symmetric_quantize(W, bits=8)
print(f"元の重み: dtype={W.dtype}, メモリ={W.numel() * 4} bytes")
print(f"量子化後: dtype={W_q.dtype}, メモリ={W_q.numel() * 1} bytes")

# 逆量子化
W_dequant = symmetric_dequantize(W_q, scale)

# 量子化誤差
error = (W - W_dequant).abs().mean()
print(f"平均絶対誤差: {error:.6f}")

Per-group量子化

def group_quantize(x: torch.Tensor, group_size: int = 128, bits: int = 4) -> tuple:
    """
    グループごとの量子化

    Args:
        x: 入力テンソル (in_features, out_features)
        group_size: グループサイズ
        bits: 量子化ビット数

    Returns:
        x_q: 量子化されたテンソル
        scales: グループごとのスケール
    """
    in_features, out_features = x.shape
    assert in_features % group_size == 0, "in_features must be divisible by group_size"

    qmax = 2 ** (bits - 1) - 1
    qmin = -2 ** (bits - 1)

    # グループに分割
    x_grouped = x.view(-1, group_size, out_features)
    n_groups = x_grouped.shape[0]

    # グループごとにスケール計算
    scales = x_grouped.abs().max(dim=1, keepdim=True).values / qmax
    scales = scales.clamp(min=1e-10)  # ゼロ除算防止

    # 量子化
    x_q = torch.round(x_grouped / scales).clamp(qmin, qmax)
    x_q = x_q.view(in_features, out_features).to(torch.int8)

    scales = scales.squeeze(1)  # (n_groups, out_features)

    return x_q, scales


def group_dequantize(
    x_q: torch.Tensor,
    scales: torch.Tensor,
    group_size: int = 128,
) -> torch.Tensor:
    """
    グループごとの逆量子化
    """
    in_features, out_features = x_q.shape

    x_q_grouped = x_q.view(-1, group_size, out_features).float()
    scales_expanded = scales.unsqueeze(1)  # (n_groups, 1, out_features)

    x_dequant = x_q_grouped * scales_expanded
    return x_dequant.view(in_features, out_features)


# 使用例
W = torch.randn(512, 1024)

# INT4 グループ量子化
W_q, scales = group_quantize(W, group_size=128, bits=4)
W_dequant = group_dequantize(W_q, scales, group_size=128)

error = (W - W_dequant).abs().mean()
print(f"INT4 グループ量子化の平均絶対誤差: {error:.6f}")

# スケールのメモリ
print(f"元のメモリ: {W.numel() * 4} bytes")
print(f"量子化後: 重み {W_q.numel() * 0.5} bytes + スケール {scales.numel() * 2} bytes")

量子化Linear層

class QuantizedLinear(nn.Module):
    """量子化されたLinear層"""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bits: int = 8,
        group_size: int = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bits = bits
        self.group_size = group_size

        # 量子化された重みを登録(学習は想定しない)
        self.register_buffer('weight_q', torch.zeros(in_features, out_features, dtype=torch.int8))
        self.register_buffer('scales', torch.zeros(in_features // (group_size or in_features), out_features))
        self.register_buffer('bias', None)

    @classmethod
    def from_float(cls, linear: nn.Linear, bits: int = 8, group_size: int = None):
        """float Linear層から量子化版を作成"""
        quant_linear = cls(
            linear.in_features,
            linear.out_features,
            bits=bits,
            group_size=group_size,
        )

        W = linear.weight.data.T  # (in, out)

        if group_size:
            W_q, scales = group_quantize(W, group_size, bits)
        else:
            W_q, scale = symmetric_quantize(W, bits)
            scales = scale.expand(1, linear.out_features)

        quant_linear.weight_q.copy_(W_q)
        quant_linear.scales.copy_(scales)

        if linear.bias is not None:
            quant_linear.bias = linear.bias.data.clone()

        return quant_linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 逆量子化
        if self.group_size:
            W = group_dequantize(self.weight_q, self.scales, self.group_size)
        else:
            W = self.weight_q.float() * self.scales

        # 行列積
        out = x @ W

        if self.bias is not None:
            out = out + self.bias

        return out


# 使用例
linear = nn.Linear(512, 1024)
quant_linear = QuantizedLinear.from_float(linear, bits=4, group_size=128)

x = torch.randn(2, 10, 512)
y_original = linear(x)
y_quantized = quant_linear(x)

print(f"出力の差(平均絶対誤差): {(y_original - y_quantized).abs().mean():.6f}")

量子化の影響

精度への影響

量子化 方式 Llama 7B Perplexity
FP16 5.68
INT8 Per-channel 5.70
INT4 GPTQ (g=128) 5.85
INT4 AWQ (g=128) 5.78

適切な量子化手法を使えば、INT4でも精度低下は限定的です。

速度への影響

量子化による速度向上は、ハードウェアと実装に依存します。

  • メモリバウンド(小バッチ): メモリ帯域の削減により高速化
  • 計算バウンド(大バッチ): 専用INT8演算器があれば高速化

可視化

重みの分布と量子化

import matplotlib.pyplot as plt
import numpy as np

# 重みの分布をシミュレート
np.random.seed(42)
weights = np.random.randn(10000) * 0.1

# INT8量子化
qmax = 127
scale = np.max(np.abs(weights)) / qmax
weights_q = np.round(weights / scale).clip(-128, 127)
weights_dequant = weights_q * scale

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 元の分布
ax = axes[0]
ax.hist(weights, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
ax.set_title('Original Weights (FP32)', fontsize=12)
ax.set_xlabel('Value')
ax.set_ylabel('Count')

# 量子化後の分布
ax = axes[1]
ax.hist(weights_q, bins=50, color='coral', alpha=0.7, edgecolor='black')
ax.set_title('Quantized Weights (INT8)', fontsize=12)
ax.set_xlabel('Quantized Value')
ax.set_ylabel('Count')

# 量子化誤差
ax = axes[2]
error = weights - weights_dequant
ax.hist(error, bins=50, color='seagreen', alpha=0.7, edgecolor='black')
ax.set_title('Quantization Error', fontsize=12)
ax.set_xlabel('Error')
ax.set_ylabel('Count')
ax.axvline(x=0, color='red', linestyle='--')

plt.tight_layout()
plt.show()

まとめ

本記事では、LLMの量子化技術について解説しました。

  • 量子化の基本: 浮動小数点数を整数にマッピングし、スケールとゼロポイントで元の値を近似
  • 対称/非対称量子化: 対称量子化は実装が単純、非対称量子化は非対称な分布に有効
  • 重みのみ量子化: LLMでは重みのみを量子化し、活性化はFP16で計算するのが一般的
  • Per-group量子化: グループごとにスケールを持つことで、INT4でも精度を維持
  • GPTQ/AWQ: キャリブレーションデータを用いて最適な量子化を行う手法
  • 精度影響: 適切な手法を使えば、INT4でもperplexityの増加は数%程度

量子化は、LLMを実用的にデプロイするための必須技術です。メモリ制約のある環境でも高品質なLLMを動かすために、これらの技術を活用しましょう。

次のステップとして、以下の記事も参考にしてください。