LLaVAアーキテクチャの理論と実装

LLaVA(Large Language and Vision Assistant)は、2023年にLiu et al.によって発表されたオープンソースのマルチモーダルLLMです。論文 “Visual Instruction Tuning” で提案され、シンプルなアーキテクチャながら高い性能を達成し、マルチモーダルLLMの研究を大きく加速させました。

LLaVAの成功の鍵は、シンプルな設計効率的な学習戦略にあります。CLIP画像エンコーダとLLaMAを線形プロジェクションで接続するだけという単純な構造でありながら、GPT-4Vに迫る性能を実現しています。

本記事の内容

  • LLaVAの設計思想
  • アーキテクチャの詳細
  • ビジュアルインストラクションチューニング
  • 学習データの生成方法
  • PyTorchによる実装

前提知識

この記事を読む前に、以下の記事を読んでおくと理解が深まります。

LLaVAの設計思想

シンプルさの追求

LLaVAの設計哲学は「シンプルに、しかし効果的に」です。

複雑な融合モジュール(Q-FormerやPerceiver Resamplerなど)を使わず、単純な線形プロジェクションで画像特徴をLLMの埋め込み空間に写像します。

この設計の利点: 1. 実装が容易: 既存のLLMに簡単に追加可能 2. 学習が安定: 学習パラメータが少ない 3. デバッグが容易: 各コンポーネントの役割が明確 4. 拡張が容易: 新しいLLMや画像エンコーダに置き換え可能

ビジュアルインストラクションチューニング

LLaVAのもう一つの重要な貢献は、GPT-4を使って学習データを生成するというアプローチです。

画像とそのキャプションをGPT-4に入力し、画像についての対話データを生成させます。これにより、高品質な指示追従データを効率的に収集できます。

アーキテクチャ

全体構造

[入力画像 (224×224)]
    ↓
[CLIP ViT-L/14] (凍結)
    ↓
[画像特徴 (256 tokens × 1024)]
    ↓
[線形プロジェクション] (学習)
    ↓
[画像トークン (256 tokens × 4096)]
    ↓
[画像トークン] + [テキストトークン]
    ↓
[LLaMA / Vicuna] (LoRAまたはフル学習)
    ↓
[出力テキスト]

1. 画像エンコーダ

CLIP ViT-L/14を使用します。

  • 入力: 224×224のRGB画像
  • 出力: 256個のパッチトークン(最終層のCLSトークンを除く)
  • 次元: 各トークンは1024次元
# CLIP ViT-L/14の出力
# image: (B, 3, 224, 224)
# features: (B, 257, 1024)  # 256パッチ + 1 CLSトークン
# LLaVAでは256パッチトークンのみ使用: (B, 256, 1024)

画像エンコーダは事前学習済みのCLIPを使用し、学習中は凍結します。

2. プロジェクション層

シンプルな線形プロジェクション(または2層MLP)で、画像特徴をLLMの埋め込み空間に写像します。

$$ \bm{H}_v = \bm{W} \cdot \bm{Z}_v $$

ここで: – $\bm{Z}_v \in \mathbb{R}^{N \times d_v}$: 画像特徴($N=256$, $d_v=1024$) – $\bm{W} \in \mathbb{R}^{d_v \times d_l}$: 線形変換($d_l=4096$ for LLaMA-7B) – $\bm{H}_v \in \mathbb{R}^{N \times d_l}$: LLM空間の画像トークン

LLaVA-1.5では、2層MLPを使用:

$$ \bm{H}_v = \bm{W}_2 \cdot \text{GELU}(\bm{W}_1 \cdot \bm{Z}_v) $$

3. LLMバックボーン

LLaMA(または指示チューニング済みのVicuna)を使用します。

  • LLaVA: Vicuna-7B/13B
  • LLaVA-1.5: Vicuna-7B/13B(より強力な設定)

画像トークンは、特殊トークン <image> の位置に挿入されます。

<s> USER: <image>
この画像について説明してください。
ASSISTANT: この画像には...
</s>

<image> は256個の画像トークンに置き換えられます。

入力フォーマット

プロンプトテンプレート

LLaVAは以下のような会話形式のプロンプトを使用します:

A chat between a curious user and an artificial intelligence assistant.
The assistant gives helpful, detailed, and polite answers to the user's questions.

USER: <image>
{ユーザーの質問}
ASSISTANT: {モデルの応答}

マルチターン対話

複数回の質問応答も可能です:

USER: <image>
この画像には何が写っていますか?
ASSISTANT: この画像には赤いリンゴがテーブルの上に置かれています。

USER: そのリンゴはどのような状態ですか?
ASSISTANT: リンゴは新鮮で、艶があり、傷のない状態です。

学習戦略

Stage 1: 事前学習(Feature Alignment)

目的: 画像特徴とテキスト埋め込み空間の整合性を取る

設定: – 画像エンコーダ: 凍結 – プロジェクション: 学習 – LLM: 凍結

データ: CC3M(Conceptual Captions 3M)の595Kサブセット

タスク: 画像キャプション生成

入力: <image> この画像を簡潔に説明してください。
出力: {キャプション}

この段階では、プロジェクション層のみを学習し、画像とテキストの対応を学習します。

Stage 2: 指示チューニング(Instruction Tuning)

目的: 対話的なタスクに適応

設定: – 画像エンコーダ: 凍結 – プロジェクション: 学習 – LLM: 学習(フルまたはLoRA)

データ: LLaVA-Instruct-150K(GPT-4で生成)

タスク: – 詳細な画像説明 – 複雑な推論 – 対話

この段階で、モデルは指示に従って適切な応答を生成する能力を獲得します。

学習データの生成

GPT-4を使ったデータ生成

LLaVAの学習データ(LLaVA-Instruct)は、GPT-4を使って生成されました。

入力: 画像のキャプションとバウンディングボックス

出力: 3種類の会話データ

  1. 会話(Conversation): 画像についての自然な対話
  2. 詳細説明(Detailed Description): 画像の詳細な説明
  3. 複雑な推論(Complex Reasoning): 画像についての推論を要する質問

データ生成のプロンプト例

You are an AI assistant helping to create training data for a visual AI model.

Given the following image description:
Caption: "A person riding a bicycle on a sunny day"
Objects: [bicycle (person, wheels, handlebar), sun, trees]

Generate a conversation between a user and an assistant about this image.
The conversation should be natural and informative.

データセットの構成

  • LLaVA-Instruct-150K: 150K件の画像-対話ペア
  • LLaVA-1.5: 665K件(より多くの学術データセットを含む)

LLaVAのバリエーション

LLaVA-1.5

オリジナルLLaVAの改良版:

  1. 高解像度: 336×336入力(576トークン)
  2. MLPプロジェクション: 2層MLP
  3. より多くの学習データ: 665K件
  4. 追加の学術データ: VQA, GQA, OCRなど

LLaVA-NeXT(LLaVA-1.6)

さらなる改良:

  1. 動的高解像度: 様々な解像度・アスペクト比に対応
  2. より大きなLLM: Mistral-7B, Yi-34Bなど
  3. 動画理解: 動画入力のサポート

PyTorchによる実装

LLaVAモデル

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer


class LLaVA(nn.Module):
    """LLaVA: Large Language and Vision Assistant"""
    def __init__(
        self,
        vision_model_name="openai/clip-vit-large-patch14",
        llm_model_name="meta-llama/Llama-2-7b-hf",
        freeze_vision=True,
        freeze_llm=False,
    ):
        super().__init__()

        # 画像エンコーダ(CLIP ViT)
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
        self.vision_hidden_size = self.vision_encoder.config.hidden_size  # 1024

        # 凍結設定
        if freeze_vision:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False

        # LLM
        self.llm = LlamaForCausalLM.from_pretrained(llm_model_name)
        self.llm_hidden_size = self.llm.config.hidden_size  # 4096

        if freeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False

        # プロジェクション(2層MLP)
        self.vision_projection = nn.Sequential(
            nn.Linear(self.vision_hidden_size, self.llm_hidden_size),
            nn.GELU(),
            nn.Linear(self.llm_hidden_size, self.llm_hidden_size),
        )

        # 特殊トークンのIDを記録
        self.image_token_id = None  # <image>トークン

    def encode_image(self, images):
        """
        画像をLLMトークン空間にエンコード

        Args:
            images: (B, 3, 224, 224)
        Returns:
            (B, num_patches, llm_hidden_size)
        """
        # CLIP ViTで特徴抽出
        vision_outputs = self.vision_encoder(images)
        # last_hidden_state: (B, num_patches+1, hidden_size)
        # CLSトークンを除く
        image_features = vision_outputs.last_hidden_state[:, 1:, :]

        # プロジェクション
        image_tokens = self.vision_projection(image_features)

        return image_tokens

    def prepare_inputs(self, images, input_ids, attention_mask, image_positions):
        """
        画像トークンをテキストに挿入

        Args:
            images: (B, 3, H, W)
            input_ids: (B, seq_len)
            attention_mask: (B, seq_len)
            image_positions: (B,) 各サンプルでの<image>トークンの位置
        """
        batch_size = images.shape[0]

        # 画像エンコード
        image_tokens = self.encode_image(images)  # (B, num_img_tokens, hidden)
        num_img_tokens = image_tokens.shape[1]

        # テキスト埋め込み
        text_embeds = self.llm.get_input_embeddings()(input_ids)

        # 画像トークンを<image>位置に挿入
        new_embeds_list = []
        new_attention_mask_list = []

        for b in range(batch_size):
            pos = image_positions[b]

            # <image>トークンの前後で分割
            before = text_embeds[b, :pos]
            after = text_embeds[b, pos+1:]  # <image>トークンをスキップ

            # 結合: [before] + [image_tokens] + [after]
            new_embed = torch.cat([before, image_tokens[b], after], dim=0)
            new_embeds_list.append(new_embed)

            # アテンションマスクも同様に処理
            before_mask = attention_mask[b, :pos]
            after_mask = attention_mask[b, pos+1:]
            img_mask = torch.ones(num_img_tokens, device=attention_mask.device)
            new_mask = torch.cat([before_mask, img_mask, after_mask], dim=0)
            new_attention_mask_list.append(new_mask)

        # パディングして揃える
        max_len = max(e.shape[0] for e in new_embeds_list)
        new_embeds = torch.zeros(batch_size, max_len, self.llm_hidden_size, device=images.device)
        new_attention_mask = torch.zeros(batch_size, max_len, device=images.device)

        for b in range(batch_size):
            length = new_embeds_list[b].shape[0]
            new_embeds[b, :length] = new_embeds_list[b]
            new_attention_mask[b, :length] = new_attention_mask_list[b]

        return new_embeds, new_attention_mask

    def forward(
        self,
        images,
        input_ids,
        attention_mask,
        image_positions,
        labels=None,
    ):
        """
        順伝播

        Args:
            images: (B, 3, H, W)
            input_ids: (B, seq_len) <image>トークンを含む
            attention_mask: (B, seq_len)
            image_positions: (B,) <image>トークンの位置
            labels: (B, seq_len) 学習ターゲット
        """
        # 入力を準備
        inputs_embeds, new_attention_mask = self.prepare_inputs(
            images, input_ids, attention_mask, image_positions
        )

        # ラベルも調整(必要に応じて)
        if labels is not None:
            num_img_tokens = (inputs_embeds.shape[1] - input_ids.shape[1] + 1)
            # 画像トークン部分は損失計算から除外
            new_labels = self._adjust_labels(labels, image_positions, num_img_tokens)
        else:
            new_labels = None

        # LLMに入力
        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=new_attention_mask,
            labels=new_labels,
        )

        return outputs

    def _adjust_labels(self, labels, image_positions, num_img_tokens):
        """ラベルを画像トークンに合わせて調整"""
        batch_size = labels.shape[0]
        new_labels_list = []

        for b in range(batch_size):
            pos = image_positions[b]
            before = labels[b, :pos]
            after = labels[b, pos+1:]
            # 画像トークン部分は-100(無視)
            img_labels = torch.full((num_img_tokens,), -100, device=labels.device)
            new_label = torch.cat([before, img_labels, after], dim=0)
            new_labels_list.append(new_label)

        # パディング
        max_len = max(l.shape[0] for l in new_labels_list)
        new_labels = torch.full((batch_size, max_len), -100, device=labels.device)
        for b in range(batch_size):
            length = new_labels_list[b].shape[0]
            new_labels[b, :length] = new_labels_list[b]

        return new_labels

    @torch.no_grad()
    def generate(
        self,
        images,
        input_ids,
        attention_mask,
        image_positions,
        max_new_tokens=256,
        **generate_kwargs,
    ):
        """画像とプロンプトからテキストを生成"""
        inputs_embeds, new_attention_mask = self.prepare_inputs(
            images, input_ids, attention_mask, image_positions
        )

        outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=new_attention_mask,
            max_new_tokens=max_new_tokens,
            **generate_kwargs,
        )

        return outputs

学習ループ

def train_llava(model, train_dataloader, optimizer, scheduler, num_epochs):
    """LLaVAの学習"""
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0

        for batch in train_dataloader:
            images = batch['images'].cuda()
            input_ids = batch['input_ids'].cuda()
            attention_mask = batch['attention_mask'].cuda()
            image_positions = batch['image_positions'].cuda()
            labels = batch['labels'].cuda()

            optimizer.zero_grad()

            outputs = model(
                images=images,
                input_ids=input_ids,
                attention_mask=attention_mask,
                image_positions=image_positions,
                labels=labels,
            )

            loss = outputs.loss
            loss.backward()

            # 勾配クリッピング
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

推論例

def inference_example(model, tokenizer, image_processor, image_path, prompt):
    """推論の例"""
    from PIL import Image

    # 画像の読み込みと前処理
    image = Image.open(image_path).convert('RGB')
    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values

    # プロンプトの準備
    full_prompt = f"USER: <image>\n{prompt}\nASSISTANT:"
    inputs = tokenizer(full_prompt, return_tensors='pt')

    # <image>トークンの位置を特定
    image_token_id = tokenizer.convert_tokens_to_ids('<image>')
    image_positions = (inputs.input_ids == image_token_id).nonzero(as_tuple=True)[1]

    # 生成
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            images=pixel_values.cuda(),
            input_ids=inputs.input_ids.cuda(),
            attention_mask=inputs.attention_mask.cuda(),
            image_positions=image_positions.cuda(),
            max_new_tokens=256,
            temperature=0.2,
            do_sample=True,
        )

    # デコード
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return response

LLaVAの性能

ベンチマーク結果

モデル VQAv2 GQA TextVQA MM-Bench
LLaVA-7B 78.5 62.0 58.2 64.3
LLaVA-13B 80.0 63.3 61.3 67.7
LLaVA-1.5-7B 78.5 62.0 58.2 64.3
LLaVA-1.5-13B 80.0 63.3 61.3 67.7

定性的な強み

  • 詳細な画像説明
  • 複雑な推論
  • 対話的なやり取り
  • OCR(画像内テキストの読解)

限界

  • 幻覚(存在しないものを説明)
  • 細かいカウント
  • 空間関係の正確な把握

まとめ

本記事では、LLaVAアーキテクチャの仕組みを解説しました。

  • シンプルな設計: CLIP ViT + 線形プロジェクション + LLM
  • 2段階学習: 事前学習(アライメント)→ 指示チューニング
  • GPT-4によるデータ生成: 高品質な学習データを効率的に収集
  • 入力形式: <image>トークンを画像トークン列に置き換え
  • 効率的な学習: プロジェクション層のみの学習から始める

LLaVAは、シンプルながら効果的なマルチモーダルLLMの設計を示し、その後の多くの研究に影響を与えました。オープンソースで公開されていることも、研究コミュニティへの大きな貢献です。

次のステップとして、以下の記事も参考にしてください。