データを扱う人や論文を書くような研究者であれば、頻繁にmatplotlibを使っていると思います。
実はmatplotlibは簡単に使えるものの、実は図形がオブジェクトとして定義されおり、図形に対して様々な操作をAPIを介して行うことで、ユーザーが自由自在にグラフや図形を編集できるようになっており、使い方によっては、かなり柔軟に、そして綺麗な絵や図を書くことができます。
しかし、自由度が高く設計されている分、matplotlibの機能はかなり豊富であり、全部を理解するとかなり大変だと思います。
今回は、matplotlibを自由自在に使いこなせるような、最低限必要な内部構造についての説明をします。
また理解した内部構造を元に、matplotlibで綺麗なグラフを作るためにサンプルをいくつか紹介します。
ちなみに、matplotlibのグラフの使い方として、グローバルオブジェクトとしてpltを操作する場合と、オブジェクトを作る方法があります。
この記事では、後者のオブジェクトを作る方法についての解説記事になるのでご注意ください。
- matplotlibにおける内部構造を理解
- matplotlibでいい感じのグラフを作る方法
matplotlibの内部構造
matplotlibの図形は、実はオブジェクト指向のような内部構造を有しており、内部構造に対してはAPIが用意されており、APIを扱うことでグラフを操作するような構造になっています。
この内部構造は一度理解しておくと、かなり自由にグラフを触ることができ、今後自分が作りたいようにグラフを扱うことができるので、大変便利です。
matplotlibの内部構造の全体像 FigureとAxes
matplotlibは大きく、Figureオブジェクトを作り、その中にAxesと呼ばれるグラフの描写エリアを作るのが基本です。Figureはmatplotlibで操作できるような図形領域で、ウィンドウに相当する概念です。
一方、AxesはFigureに内包されており、実際のグラフ1つ1つに相当します。
ざっくり、このようになっています。
上の説明図では、Figureに図形領域であるAxesが2つ描写されていますが、1つしかグラフを作らない場合でもこのような構造になっています。
下の図は、matplotlibの公式サイトにある、内部構造を説明する図です。
基本的にmatplotlibの図形の構成はこのようになっているので、自分がどこを操作しているか意識してみましょう。さらに詳しい内部構造はこのようになっています。
matplotlibで簡単にグラフを作ってみる
まずはライブラリをインポートします。numpyはグラフに表示するデータを生成するのに色々便利なので、入れています。
import numpy as np
import matplotlib.pyplot as plt
続いて、FigureとAxesを作ります。plt.subplots( ) を使うことで、FigureオブジェクトとAxesオブジェクトを作ることができます。
実際に図形を描写するときは、Axesオブジェクトのplot()メソッドを使うことで、描写することができます。
fig, ax = plt.subplots()
x = np.linspace(1, 100, 10)
y = 2 * x
ax.plot(x,y)
複数のグラフが縦横に並ぶグラフを作る
AxesオブジェクトをFigureオブジェクトに複数配置することで、複数のグラフが横に並んだ図形も作ることができます。このようなグラフを作ります。
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
x = np.linspace(-10, 10, 1000)
fig = plt.figure(figsize=(20, 6), dpi=200)
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
ax4 = fig.add_subplot(2, 2, 4)
y1 = stats.norm.pdf(x, 4, 1) # normal distribution
y2 = stats.beta.pdf(x, 2, 3) # beta distribution
y3 = stats.t.pdf(x, 5) # t distribution
y4 = stats.lognorm.pdf(x, 5, 0.1) # lognorm distribution
ax1.plot(x, y1, color="blue", label="normal distribution")
ax2.plot(x, y2, color="red", label="normal distribution")
ax3.plot(x, y3, color="green", label="t disribution")
ax4.plot(x, y4, color="brown", label="lognorm disribution")
ax1.legend()
ax2.legend()
ax3.legend()
ax4.legend()
matplotlibでいい感じのグラフを作る方法
matplotlibでいい感じのグラフを作る方法ですが、好みの問題もあると思うので、ここではいくつかのサンプルを掲載することにします。
Matpliotlibの公式サイトにあるUser Galleryには、いい感じのグラフとコードのセットがあるので、自分の好みのグラフを見つけて、自分好みにカスタマイズしてみると良いと思います。
カラフルな多変量時系列の表示
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cbook import get_sample_data
fname = get_sample_data('percent_bachelors_degrees_women_usa.csv',
asfileobj=False)
gender_degree_data = np.genfromtxt(fname, delimiter=',', names=True)
fig, ax = plt.subplots(1, 1, figsize=(12, 14))
ax.set_prop_cycle(color=[
'#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a',
'#d62728', '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94',
'#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7', '#bcbd22', '#dbdb8d',
'#17becf', '#9edae5'])
ax.spines[:].set_visible(False)
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()
fig.subplots_adjust(left=.06, right=.75, bottom=.02, top=.94)
ax.set_xlim(1969.5, 2011.1)
ax.set_ylim(-0.25, 90)
ax.set_xticks(range(1970, 2011, 10))
ax.set_yticks(range(0, 91, 10))
ax.xaxis.set_major_formatter('{x:.0f}')
ax.yaxis.set_major_formatter('{x:.0f}%')
ax.grid(True, 'major', 'y', ls='--', lw=.5, c='k', alpha=.3)
ax.tick_params(axis='both', which='both', labelsize=14,
bottom=False, top=False, labelbottom=True,
left=False, right=False, labelleft=True)
majors = ['Health Professions', 'Public Administration', 'Education',
'Psychology', 'Foreign Languages', 'English',
'Communications\nand Journalism', 'Art and Performance', 'Biology',
'Agriculture', 'Social Sciences and History', 'Business',
'Math and Statistics', 'Architecture', 'Physical Sciences',
'Computer Science', 'Engineering']
y_offsets = {'Foreign Languages': 0.5, 'English': -0.5,
'Communications\nand Journalism': 0.75,
'Art and Performance': -0.25, 'Agriculture': 1.25,
'Social Sciences and History': 0.25, 'Business': -0.75,
'Math and Statistics': 0.75, 'Architecture': -0.75,
'Computer Science': 0.75, 'Engineering': -0.25}
for column in majors:
column_rec_name = column.replace('\n', '_').replace(' ', '_')
line, = ax.plot('Year', column_rec_name, data=gender_degree_data,
lw=2.5)
y_pos = gender_degree_data[column_rec_name][-1] - 0.5
if column in y_offsets:
y_pos += y_offsets[column]
ax.text(2011.5, y_pos, column, fontsize=14, color=line.get_color())
fig.suptitle("Percentage of Bachelor's degrees conferred to women in "
"the U.S.A. by major (1970-2011)", fontsize=18, ha="center")
plt.show()