機械学習や深層学習で分類問題を扱う際には、one-hot encodingなどの表現をすることが多いですが、その際にone-hotベクトルの最大値のインデックスを取得するためにNumPyのargmaxを利用することがあります。
しかし、頻繁に利用するargmaxやargminですが、意外とaxisの指定など微妙に理解しきれておらず、どのように指定すればいいんだっけ?と悩んでしまう人も多いのではないでしょうか。
今回は、NumPyにおいてargminやargmaxを使いこなす方法について解説します。
本記事の内容
- argmin/argmaxの基本的な使い方
- 多次元配列でのaxis指定の挙動
- 実践的な使用例
argmax/argminの基本
np.argmax は配列の最大値のインデックスを、np.argmin は最小値のインデックスを返す関数です。
import numpy as np
# 1次元配列
a = np.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3])
print(f"配列: {a}")
print(f"argmax: {np.argmax(a)} (値: {a[np.argmax(a)]})")
print(f"argmin: {np.argmin(a)} (値: {a[np.argmin(a)]})")
出力は以下のようになります。
配列: [3 1 4 1 5 9 2 6 5 3]
argmax: 5 (値: 9)
argmin: 1 (値: 1)
同じ最大値(最小値)が複数存在する場合は、最初に出現するインデックスが返されます。
import numpy as np
a = np.array([1, 5, 3, 5, 2])
print(f"argmax: {np.argmax(a)}") # 1(最初の5のインデックス)
2次元配列でのaxis指定
多次元配列の場合、axis パラメータで探索する軸を指定できます。
import numpy as np
arr = np.array([
[1, 5, 3],
[4, 2, 6],
[7, 0, 8]
])
print("配列:")
print(arr)
# axis指定なし: 全要素を平坦化して探索
print(f"\nargmax(): {np.argmax(arr)}") # 8(平坦化: index=8)
print(f"argmin(): {np.argmin(arr)}") # 7(平坦化: index=7)
# axis=0: 列方向(行を跨いで)探索
print(f"\nargmax(axis=0): {np.argmax(arr, axis=0)}") # [2, 0, 2]
print(f"argmin(axis=0): {np.argmin(arr, axis=0)}") # [0, 2, 0]
# axis=1: 行方向(列を跨いで)探索
print(f"\nargmax(axis=1): {np.argmax(arr, axis=1)}") # [1, 2, 2]
print(f"argmin(axis=1): {np.argmin(arr, axis=1)}") # [0, 1, 1]
axisの挙動を図で理解する
axisの指定は直感的でない場合があるので、図で確認しましょう。
import numpy as np
import matplotlib.pyplot as plt
arr = np.array([
[1, 5, 3],
[4, 2, 6],
[7, 0, 8]
])
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
# axis なし(平坦化)
flat = arr.flatten()
max_idx = np.argmax(flat)
axes[0].imshow([[1 if i == max_idx else 0 for i in range(9)]], cmap='Reds', aspect='auto')
for i in range(9):
axes[0].text(i, 0, f'{flat[i]}', ha='center', va='center', fontsize=14)
axes[0].set_title(f"No axis: argmax={max_idx} (value={flat[max_idx]})")
axes[0].set_xticks(range(9))
axes[0].set_yticks([])
# axis=0(列方向に探索)
result0 = np.argmax(arr, axis=0)
highlight0 = np.zeros_like(arr, dtype=float)
for j in range(3):
highlight0[result0[j], j] = 1
axes[1].imshow(highlight0, cmap='Reds', alpha=0.5, aspect='auto')
for i in range(3):
for j in range(3):
axes[1].text(j, i, f'{arr[i,j]}', ha='center', va='center', fontsize=14)
axes[1].set_title(f"axis=0: argmax={result0.tolist()}")
axes[1].set_xticks(range(3))
axes[1].set_xticklabels(['col0', 'col1', 'col2'])
axes[1].set_yticks(range(3))
axes[1].set_yticklabels(['row0', 'row1', 'row2'])
# axis=1(行方向に探索)
result1 = np.argmax(arr, axis=1)
highlight1 = np.zeros_like(arr, dtype=float)
for i in range(3):
highlight1[i, result1[i]] = 1
axes[2].imshow(highlight1, cmap='Reds', alpha=0.5, aspect='auto')
for i in range(3):
for j in range(3):
axes[2].text(j, i, f'{arr[i,j]}', ha='center', va='center', fontsize=14)
axes[2].set_title(f"axis=1: argmax={result1.tolist()}")
axes[2].set_xticks(range(3))
axes[2].set_xticklabels(['col0', 'col1', 'col2'])
axes[2].set_yticks(range(3))
axes[2].set_yticklabels(['row0', 'row1', 'row2'])
plt.tight_layout()
plt.show()
axis=0: 各列について、行方向に最大値を探す。結果の形状は列数と同じaxis=1: 各行について、列方向に最大値を探す。結果の形状は行数と同じ
3次元配列での使い方
3次元以上の配列でも同様にaxis指定が使えます。
import numpy as np
arr3d = np.random.randint(0, 100, size=(2, 3, 4))
print(f"配列の形状: {arr3d.shape}")
print(arr3d)
# 各axisでのargmax
print(f"\nargmax(axis=0).shape: {np.argmax(arr3d, axis=0).shape}") # (3, 4)
print(f"argmax(axis=1).shape: {np.argmax(arr3d, axis=1).shape}") # (2, 4)
print(f"argmax(axis=2).shape: {np.argmax(arr3d, axis=2).shape}") # (2, 3)
結果の形状は、指定したaxisが消えた形状になります。
実践例: one-hot encodingのデコード
分類問題での典型的な使い方です。
import numpy as np
# ニューラルネットワークの出力(ソフトマックス後)を想定
logits = np.array([
[0.1, 0.7, 0.2], # クラス1
[0.8, 0.1, 0.1], # クラス0
[0.2, 0.3, 0.5], # クラス2
[0.05, 0.9, 0.05], # クラス1
])
# argmaxで予測クラスを取得
predictions = np.argmax(logits, axis=1)
class_names = ['Cat', 'Dog', 'Bird']
print("=== Classification Results ===")
for i, (pred, probs) in enumerate(zip(predictions, logits)):
print(f" Sample {i}: {class_names[pred]} (confidence: {probs[pred]:.2f})")
実践例: 最近傍探索
距離行列から最も近い点を見つける例です。
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)
# 2次元の点群
points = np.random.randn(20, 2)
query = np.array([1.0, 1.0])
# 距離の計算
distances = np.sqrt(np.sum((points - query) ** 2, axis=1))
# 最近傍のインデックス
nearest_idx = np.argmin(distances)
print(f"最近傍点: index={nearest_idx}, 距離={distances[nearest_idx]:.3f}")
# 上位3つの近傍
top3_idx = np.argsort(distances)[:3]
print(f"Top 3 近傍: {top3_idx.tolist()}")
plt.figure(figsize=(8, 6))
plt.scatter(points[:, 0], points[:, 1], c='steelblue', s=50, label='Points')
plt.scatter(query[0], query[1], c='red', s=100, marker='*', label='Query')
plt.scatter(points[nearest_idx, 0], points[nearest_idx, 1],
c='none', edgecolors='red', s=200, linewidths=2, label='Nearest')
for i in top3_idx:
plt.plot([query[0], points[i, 0]], [query[1], points[i, 1]], 'r--', alpha=0.5)
plt.legend()
plt.title("Nearest Neighbor Search using argmin")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()
まとめ
本記事では、NumPyのargminとargmaxの使い方を解説しました。
argmax/argminは配列の最大値/最小値のインデックスを返す- axis指定なし: 配列を平坦化して探索
- axis=0: 列ごとに行方向に探索
- axis=1: 行ごとに列方向に探索
- 結果の形状は、指定したaxisが消えた形状になる
- one-hotデコードや最近傍探索など、機械学習の実践で頻繁に利用される