【PyTorch】自作Datasetの作り方を完全解説

PyTorchで自作のDatasetを作る方法について解説していきます。PyTorchのデータパイプラインは DatasetDataLoader の2つのクラスが中心的な役割を果たしており、これらを理解することは深層学習の実装において非常に重要です。

本記事の内容

  • torch.utils.data.Dataset クラスの仕組み
  • 自作Datasetの実装方法
  • DataLoaderとの連携
  • 画像データ・時系列データへの応用例

Datasetクラスの基本

PyTorchのDatasetクラスを継承して自作のDatasetを作る場合、以下の2つのメソッドを実装する必要があります。

import torch
from torch.utils.data import Dataset
メソッド 説明
__getitem__(self, index) インデックスを指定してデータを返す
__len__(self) データセットの総数を返す

__getitem__ は、インデックスを指定したときにデータセットからサンプルを返す関数です。__len__ は、データセットの総数を返す関数です。これらを実装することで、自作のDatasetクラスを実装することができます。

最もシンプルなDataset

まずは最も基本的な例から始めましょう。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class SimpleDataset(Dataset):
    """シンプルな自作Dataset"""
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.FloatTensor(y)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return len(self.X)

# データの準備
np.random.seed(42)
X = np.random.randn(100, 5)  # 100サンプル, 5特徴量
y = np.random.randn(100, 1)  # 100サンプル, 1出力

# Datasetの作成
dataset = SimpleDataset(X, y)

# 動作確認
print(f"データセットのサイズ: {len(dataset)}")
print(f"1つ目のサンプル:")
sample_x, sample_y = dataset[0]
print(f"  X: {sample_x}")
print(f"  y: {sample_y}")

DataLoaderとの連携

Datasetを作成したら、DataLoaderを使ってバッチ処理やシャッフルを行います。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.FloatTensor(y)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return len(self.X)

np.random.seed(42)
X = np.random.randn(100, 5)
y = np.random.randn(100, 1)
dataset = SimpleDataset(X, y)

# DataLoaderの作成
dataloader = DataLoader(
    dataset,
    batch_size=16,     # バッチサイズ
    shuffle=True,      # データをシャッフルする
    num_workers=0,     # データ読み込みのワーカー数
    drop_last=False,   # 最後の不完全なバッチを捨てるか
)

# DataLoaderの動作確認
for batch_idx, (batch_x, batch_y) in enumerate(dataloader):
    print(f"Batch {batch_idx}: X.shape={batch_x.shape}, y.shape={batch_y.shape}")
    if batch_idx >= 2:
        break

DataLoaderの主要なパラメータは以下のとおりです。

パラメータ 説明
batch_size バッチサイズ
shuffle エポックごとにデータをシャッフルするか
num_workers 並列データ読み込みのプロセス数
drop_last 最後のバッチがbatch_size未満の場合に捨てるか
pin_memory GPU転送を高速化(CUDA使用時)

分類問題のDataset

分類問題では、ラベルを整数として扱います。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ClassificationDataset(Dataset):
    """分類問題用のDataset"""
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)  # 分類ラベルはLongTensor

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return len(self.X)

# 3クラス分類のダミーデータ
np.random.seed(42)
n_samples = 300
X = np.random.randn(n_samples, 10)
y = np.random.randint(0, 3, n_samples)

dataset = ClassificationDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 1バッチ取り出し
batch_x, batch_y = next(iter(dataloader))
print(f"X shape: {batch_x.shape}")
print(f"y shape: {batch_y.shape}")
print(f"y dtype: {batch_y.dtype}")  # torch.int64
print(f"Unique labels: {torch.unique(batch_y).tolist()}")

前処理を組み込んだDataset

データの正規化や前処理をDatasetクラスに組み込むことができます。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class PreprocessedDataset(Dataset):
    """前処理付きのDataset"""
    def __init__(self, X, y, normalize=True):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.float32)

        if normalize:
            # 訓練データの統計量で正規化
            self.mean = self.X.mean(axis=0)
            self.std = self.X.std(axis=0) + 1e-8
            self.X = (self.X - self.mean) / self.std

    def __getitem__(self, index):
        x = torch.FloatTensor(self.X[index])
        y = torch.FloatTensor([self.y[index]])

        # ノイズ付加(データ拡張)
        if self.training:
            x = x + torch.randn_like(x) * 0.01
        return x, y

    def __len__(self):
        return len(self.X)

    @property
    def training(self):
        return getattr(self, '_training', False)

    def train(self):
        self._training = True

    def eval(self):
        self._training = False

# 使用例
np.random.seed(42)
X = np.random.randn(200, 8) * 10 + 5
y = np.random.randn(200)

dataset = PreprocessedDataset(X, y, normalize=True)
print(f"正規化後の平均: {dataset.X.mean(axis=0).round(4)}")
print(f"正規化後の標準偏差: {dataset.X.std(axis=0).round(4)}")

時系列データのDataset

時系列データでは、スライディングウィンドウでサンプルを作成するのが一般的です。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

class TimeSeriesDataset(Dataset):
    """時系列データ用のDataset(スライディングウィンドウ)"""
    def __init__(self, data, window_size, prediction_horizon=1):
        self.data = torch.FloatTensor(data)
        self.window_size = window_size
        self.prediction_horizon = prediction_horizon

    def __getitem__(self, index):
        # 入力: window_size分の過去データ
        x = self.data[index:index + self.window_size]
        # ターゲット: prediction_horizon先の値
        y = self.data[index + self.window_size:
                      index + self.window_size + self.prediction_horizon]
        return x, y

    def __len__(self):
        return len(self.data) - self.window_size - self.prediction_horizon + 1

# サイン波の時系列データ
np.random.seed(42)
t = np.linspace(0, 10 * np.pi, 500)
data = np.sin(t) + np.random.normal(0, 0.1, len(t))

# Dataset作成
window_size = 50
dataset = TimeSeriesDataset(data, window_size=window_size, prediction_horizon=1)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

print(f"データ点数: {len(data)}")
print(f"サンプル数: {len(dataset)}")

# 1サンプルの可視化
x, y = dataset[0]
plt.figure(figsize=(10, 4))
plt.plot(range(window_size), x.numpy(), 'b-', linewidth=2, label='Input window')
plt.plot(window_size, y.numpy(), 'ro', markersize=10, label='Target')
plt.axvline(x=window_size - 0.5, color='gray', linestyle='--', alpha=0.5)
plt.title(f"Time Series Dataset: window_size={window_size}")
plt.xlabel("Time step")
plt.ylabel("Value")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

訓練ループとの統合

最後に、自作Datasetを訓練ループに組み込む例を示します。

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np

class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.FloatTensor(y)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return len(self.X)

# データ生成
np.random.seed(42)
X = np.random.randn(500, 10)
w_true = np.random.randn(10, 1)
y = X @ w_true + np.random.randn(500, 1) * 0.1

# Dataset作成と分割
dataset = SimpleDataset(X, y)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# シンプルなモデル
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 訓練ループ
n_epochs = 20
for epoch in range(n_epochs):
    # 訓練
    model.train()
    train_loss = 0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # 検証
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            pred = model(batch_x)
            loss = criterion(pred, batch_y)
            val_loss += loss.item()

    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d}: "
              f"Train Loss={train_loss/len(train_loader):.6f}, "
              f"Val Loss={val_loss/len(val_loader):.6f}")

まとめ

本記事では、PyTorchで自作Datasetを作成する方法を解説しました。

  • torch.utils.data.Dataset を継承し、__getitem____len__ を実装する
  • DataLoader と組み合わせてバッチ処理・シャッフル・並列読み込みを行う
  • 前処理(正規化、データ拡張)はDatasetクラス内に組み込める
  • 時系列データはスライディングウィンドウ方式でサンプルを作成する
  • random_split で訓練・検証データの分割が可能