【NumPy】argminとargmaxの使い方を完全解説(axisの指定も)

機械学習や深層学習で分類問題を扱う際には、one-hot encodingなどの表現をすることが多いですが、その際にone-hotベクトルの最大値のインデックスを取得するためにNumPyのargmaxを利用することがあります。

しかし、頻繁に利用するargmaxやargminですが、意外とaxisの指定など微妙に理解しきれておらず、どのように指定すればいいんだっけ?と悩んでしまう人も多いのではないでしょうか。

今回は、NumPyにおいてargminやargmaxを使いこなす方法について解説します。

本記事の内容

  • argmin/argmaxの基本的な使い方
  • 多次元配列でのaxis指定の挙動
  • 実践的な使用例

argmax/argminの基本

np.argmax は配列の最大値のインデックスを、np.argmin は最小値のインデックスを返す関数です。

import numpy as np

# 1次元配列
a = np.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3])

print(f"配列: {a}")
print(f"argmax: {np.argmax(a)}  (値: {a[np.argmax(a)]})")
print(f"argmin: {np.argmin(a)}  (値: {a[np.argmin(a)]})")

出力は以下のようになります。

配列: [3 1 4 1 5 9 2 6 5 3]
argmax: 5  (値: 9)
argmin: 1  (値: 1)

同じ最大値(最小値)が複数存在する場合は、最初に出現するインデックスが返されます。

import numpy as np

a = np.array([1, 5, 3, 5, 2])
print(f"argmax: {np.argmax(a)}")  # 1(最初の5のインデックス)

2次元配列でのaxis指定

多次元配列の場合、axis パラメータで探索する軸を指定できます。

import numpy as np

arr = np.array([
    [1, 5, 3],
    [4, 2, 6],
    [7, 0, 8]
])
print("配列:")
print(arr)

# axis指定なし: 全要素を平坦化して探索
print(f"\nargmax(): {np.argmax(arr)}")  # 8(平坦化: index=8)
print(f"argmin(): {np.argmin(arr)}")  # 7(平坦化: index=7)

# axis=0: 列方向(行を跨いで)探索
print(f"\nargmax(axis=0): {np.argmax(arr, axis=0)}")  # [2, 0, 2]
print(f"argmin(axis=0): {np.argmin(arr, axis=0)}")  # [0, 2, 0]

# axis=1: 行方向(列を跨いで)探索
print(f"\nargmax(axis=1): {np.argmax(arr, axis=1)}")  # [1, 2, 2]
print(f"argmin(axis=1): {np.argmin(arr, axis=1)}")  # [0, 1, 1]

axisの挙動を図で理解する

axisの指定は直感的でない場合があるので、図で確認しましょう。

import numpy as np
import matplotlib.pyplot as plt

arr = np.array([
    [1, 5, 3],
    [4, 2, 6],
    [7, 0, 8]
])

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# axis なし(平坦化)
flat = arr.flatten()
max_idx = np.argmax(flat)
axes[0].imshow([[1 if i == max_idx else 0 for i in range(9)]], cmap='Reds', aspect='auto')
for i in range(9):
    axes[0].text(i, 0, f'{flat[i]}', ha='center', va='center', fontsize=14)
axes[0].set_title(f"No axis: argmax={max_idx} (value={flat[max_idx]})")
axes[0].set_xticks(range(9))
axes[0].set_yticks([])

# axis=0(列方向に探索)
result0 = np.argmax(arr, axis=0)
highlight0 = np.zeros_like(arr, dtype=float)
for j in range(3):
    highlight0[result0[j], j] = 1
axes[1].imshow(highlight0, cmap='Reds', alpha=0.5, aspect='auto')
for i in range(3):
    for j in range(3):
        axes[1].text(j, i, f'{arr[i,j]}', ha='center', va='center', fontsize=14)
axes[1].set_title(f"axis=0: argmax={result0.tolist()}")
axes[1].set_xticks(range(3))
axes[1].set_xticklabels(['col0', 'col1', 'col2'])
axes[1].set_yticks(range(3))
axes[1].set_yticklabels(['row0', 'row1', 'row2'])

# axis=1(行方向に探索)
result1 = np.argmax(arr, axis=1)
highlight1 = np.zeros_like(arr, dtype=float)
for i in range(3):
    highlight1[i, result1[i]] = 1
axes[2].imshow(highlight1, cmap='Reds', alpha=0.5, aspect='auto')
for i in range(3):
    for j in range(3):
        axes[2].text(j, i, f'{arr[i,j]}', ha='center', va='center', fontsize=14)
axes[2].set_title(f"axis=1: argmax={result1.tolist()}")
axes[2].set_xticks(range(3))
axes[2].set_xticklabels(['col0', 'col1', 'col2'])
axes[2].set_yticks(range(3))
axes[2].set_yticklabels(['row0', 'row1', 'row2'])

plt.tight_layout()
plt.show()
  • axis=0: 各列について、行方向に最大値を探す。結果の形状は列数と同じ
  • axis=1: 各行について、列方向に最大値を探す。結果の形状は行数と同じ

3次元配列での使い方

3次元以上の配列でも同様にaxis指定が使えます。

import numpy as np

arr3d = np.random.randint(0, 100, size=(2, 3, 4))
print(f"配列の形状: {arr3d.shape}")
print(arr3d)

# 各axisでのargmax
print(f"\nargmax(axis=0).shape: {np.argmax(arr3d, axis=0).shape}")  # (3, 4)
print(f"argmax(axis=1).shape: {np.argmax(arr3d, axis=1).shape}")  # (2, 4)
print(f"argmax(axis=2).shape: {np.argmax(arr3d, axis=2).shape}")  # (2, 3)

結果の形状は、指定したaxisが消えた形状になります。

実践例: one-hot encodingのデコード

分類問題での典型的な使い方です。

import numpy as np

# ニューラルネットワークの出力(ソフトマックス後)を想定
logits = np.array([
    [0.1, 0.7, 0.2],  # クラス1
    [0.8, 0.1, 0.1],  # クラス0
    [0.2, 0.3, 0.5],  # クラス2
    [0.05, 0.9, 0.05], # クラス1
])

# argmaxで予測クラスを取得
predictions = np.argmax(logits, axis=1)
class_names = ['Cat', 'Dog', 'Bird']

print("=== Classification Results ===")
for i, (pred, probs) in enumerate(zip(predictions, logits)):
    print(f"  Sample {i}: {class_names[pred]} (confidence: {probs[pred]:.2f})")

実践例: 最近傍探索

距離行列から最も近い点を見つける例です。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

# 2次元の点群
points = np.random.randn(20, 2)
query = np.array([1.0, 1.0])

# 距離の計算
distances = np.sqrt(np.sum((points - query) ** 2, axis=1))

# 最近傍のインデックス
nearest_idx = np.argmin(distances)
print(f"最近傍点: index={nearest_idx}, 距離={distances[nearest_idx]:.3f}")

# 上位3つの近傍
top3_idx = np.argsort(distances)[:3]
print(f"Top 3 近傍: {top3_idx.tolist()}")

plt.figure(figsize=(8, 6))
plt.scatter(points[:, 0], points[:, 1], c='steelblue', s=50, label='Points')
plt.scatter(query[0], query[1], c='red', s=100, marker='*', label='Query')
plt.scatter(points[nearest_idx, 0], points[nearest_idx, 1],
            c='none', edgecolors='red', s=200, linewidths=2, label='Nearest')
for i in top3_idx:
    plt.plot([query[0], points[i, 0]], [query[1], points[i, 1]], 'r--', alpha=0.5)
plt.legend()
plt.title("Nearest Neighbor Search using argmin")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

まとめ

本記事では、NumPyのargminとargmaxの使い方を解説しました。

  • argmax / argmin は配列の最大値/最小値のインデックスを返す
  • axis指定なし: 配列を平坦化して探索
  • axis=0: 列ごとに行方向に探索
  • axis=1: 行ごとに列方向に探索
  • 結果の形状は、指定したaxisが消えた形状になる
  • one-hotデコードや最近傍探索など、機械学習の実践で頻繁に利用される
関連タグ: argmax argmin 配列操作