from IPython.core.display import display, HTML
toc = !nbtoc `pwd` #/NOTEBOOK_NAME.ipynb
display(HTML(toc[0]))
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
matplotlib.interactive(False)
x = np.linspace(1,10,50)
y1 = x**2
y2 = 1/x
y3 = np.cos(x)
y4 = x
plt.plot(x, y1, x, y2, x, y3, x, y4)
plt.ylim(-1, 5)
plt.show()
A figure is a canvas on which you can place several axes, that is, graphs.
The plt module keeps track of one current figure at a time.
Create a figure (it becomes the current figure of the plt module):
fig = plt.figure()
Now add axes to the figure. You need to specify the location and size as a list: [left, bottom, width, height]. The bottom-left corner of the figure is at (0, 0). Each value is a fraction of the figure size and can be greater than 1 (for example, 1.5 means 150%).
Leave some room between the axes (0.1 here). It doesn't matter if a height plus its offset adds up to more than 1:
ax1 = fig.add_axes([0, 0.6, 1, 0.5]) # note that we start at 0.6 here
ax2 = fig.add_axes([0, 0, 1, 0.5])
plt.show()
fig = plt.figure()
ax1 = fig.add_axes([0, 0.6, 1.2, 0.5]) # note that we start at 0.6 here
ax2 = fig.add_axes([0, 0, 1, 0.5])
plt.show()
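To check how the [left, bottom, width, height] list is interpreted, the position of each axes can be read back with get_position(). This is only a minimal sketch: the Agg backend line is an assumption made so that it runs without a display, and the names ax_top and ax_bottom are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig = plt.figure()
ax_top = fig.add_axes([0, 0.6, 1, 0.5])   # starts at 60% of the figure height
ax_bottom = fig.add_axes([0, 0, 1, 0.5])  # starts at the bottom-left corner

# bounds gives back (left, bottom, width, height) as fractions of the figure
for name, ax in (("top", ax_top), ("bottom", ax_bottom)):
    left, bottom, width, height = ax.get_position().bounds
    print(f"{name}: left={left:.2f} bottom={bottom:.2f} "
          f"width={width:.2f} height={height:.2f}")

# The top axes end at 0.6 + 0.5 = 1.1, past the nominal edge of the figure
top_edge = ax_top.get_position().y1
plt.close(fig)
```

The value of top_edge being larger than 1 illustrates that matplotlib accepts axes that extend past the nominal figure area.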
Then call ax.plot() just as you used plt.plot() in the introduction.
Some functions have different names, for example ax.set_xlabel() instead of plt.xlabel().
fig = plt.figure()
ax1 = fig.add_axes([0, 0.6, 1, 0.5])
ax2 = fig.add_axes([0, 0, 1, 0.5])
ax1.plot(x, y1)
ax1.plot(x, y2)
ax1.legend(labels=['a function', 'another function'])
ax1.set_xlabel('x values')
ax1.set_ylabel('y values')
ax2.plot(x, y3)
ax2.set_xlabel('x values')
ax2.set_ylabel('y values')
plt.show()
Because you can put the axes where you want on the figure, you can put axes inside other axes:
fig = plt.figure()
ax1 = fig.add_axes([0, 0, 1, 1])
ax2 = fig.add_axes([0.15, 0.55, 0.5, 0.4])
ax1.plot(x, y1)
ax1.plot(x, y2)
ax1.legend(labels=['a function', 'another function'], loc='center right')
ax1.set_xlabel('x values')
ax1.set_ylabel('y values')
ax2.plot(x, y3)
ax2.set_xlabel('x values')
ax2.set_ylabel('y values')
plt.show()
You can specify the figure size, in inches. This lets you change the default aspect ratio.
fig = plt.figure(figsize=(6,3))
ax1 = fig.add_axes([0, 0, 1, 1])
ax2 = fig.add_axes([0.15, 0.55, 0.5, 0.4])
ax1.plot(x, y1)
ax1.plot(x, y2)
ax1.legend(labels=['a function', 'another function'], loc='center right')
ax1.set_xlabel('x values')
ax1.set_ylabel('y values')
ax2.plot(x, y3)
ax2.set_xlabel('x values')
ax2.set_ylabel('y values')
plt.show()
You can also set the figure size afterwards:
fig.set_figwidth(6)
fig.set_figheight(6)
See the automatic method below.
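A quick way to confirm what figsize and the setters do is to read the size back with get_size_inches(). The sketch below assumes the Agg backend so that it runs without a display; the variable names initial and resized are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 3))  # width, height in inches
initial = tuple(float(v) for v in fig.get_size_inches())

# Change the size after creation
fig.set_figwidth(6)
fig.set_figheight(6)
resized = tuple(float(v) for v in fig.get_size_inches())

print("initial size:", initial)
print("resized to:", resized)
plt.close(fig)
```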
Axes can be freely arranged on the figure, but if you want a grid-like organization, you should use subplots:
fig.add_subplot(121)
where 121 is:
1: the number of rows
2: the number of columns
1: the position of the current subplot

The location of each subplot follows the pattern:

1 2
3 4

1 2 3
4 5 6

etc.
See the plt.subplots() function below, which returns a whole grid at once.
fig = plt.figure()
ax1 = fig.add_subplot(121)
ax1.plot(x, y1)
ax2 = fig.add_subplot(122)
ax2.plot(x, y2)
plt.show()
fig = plt.figure()
ax1 = fig.add_subplot(221)
ax1.plot(x, y1)
ax2 = fig.add_subplot(222)
ax2.plot(x, y2)
ax3 = fig.add_subplot(223)
ax3.plot(x, y3)
ax4 = fig.add_subplot(224)
ax4.plot(x, y4)
plt.show()
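The numbering pattern can be checked by asking each subplot where it sits in the grid. This is a sketch under a few assumptions: it uses the three-argument form add_subplot(nrows, ncols, position), which is equivalent to the three-digit shorthand, it reads the cell back through get_subplotspec(), and it selects the Agg backend so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig = plt.figure()
nrows, ncols = 2, 3
cells = {}  # position -> (zero-based row, zero-based column)

for position in range(1, nrows * ncols + 1):
    ax = fig.add_subplot(nrows, ncols, position)  # same as add_subplot(231), ...
    spec = ax.get_subplotspec()
    cells[position] = (spec.rowspan.start, spec.colspan.start)
    ax.set_title(f"position {position}")

for position, (row, col) in cells.items():
    print(f"position {position} -> row {row}, column {col}")
plt.close(fig)
```

Positions are counted from 1, row by row, so position 4 in a 2x3 grid is the first cell of the second row.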
plt.subplots(nrows, ncols) returns a tuple (fig, axes), where axes is a numpy array.
The actual output depends on the squeeze argument: if True (the default), extra dimensions are removed (no array at all for a (1, 1) grid, a 1-d array if one dimension is 1). If False, a 2-d array is always returned.
To access the axes:
fig, ax = plt.subplots(2, 2)
print(ax[0,1])
ax
or:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.plot(x, y4)
plt.show()
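The effect of squeeze is easiest to see by comparing the shapes returned for different grids. The following sketch assumes the Agg backend so that it runs without a display; the variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt
import numpy as np

fig_a, grid = plt.subplots(2, 2)                      # 2-d array of axes
fig_b, row = plt.subplots(1, 3)                       # squeezed to a 1-d array
fig_c, single = plt.subplots(1, 1)                    # squeezed to a single axes object
fig_d, always_2d = plt.subplots(1, 1, squeeze=False)  # kept as a 2-d array

print("2x2 grid shape:", grid.shape)
print("1x3 grid shape:", row.shape)
print("1x1 grid is an array:", isinstance(single, np.ndarray))
print("1x1 grid with squeeze=False shape:", always_2d.shape)
plt.close("all")
```

Passing squeeze=False is convenient when the grid size is computed at run time, because the same two-index access then works for every shape.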
You can share axes between subplots. When you share x, the x tick labels appear only on the bottom row. When you share y, the y tick labels appear only in the left column.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=True)
ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.plot(x, y4)
plt.show()
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharey=True)
ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.plot(x, y4)
plt.show()
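Sharing also links the axis limits, not just the tick labels. As a minimal sketch (assuming the Agg backend so that it runs without a display), changing the x limits on one subplot changes them on all the others, while the y limits stay independent:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=True)

ax1.set_xlim(0, 5)    # propagates to every subplot sharing x
ax1.set_ylim(0, 100)  # y is not shared here, so only ax1 changes

x_limits = [tuple(float(v) for v in ax.get_xlim()) for ax in (ax1, ax2, ax3, ax4)]
y_limits_ax2 = tuple(float(v) for v in ax2.get_ylim())

print("x limits:", x_limits)
print("y limits of ax2:", y_limits_ax2)
plt.close(fig)
```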
If your subplots are too tight, you can also use tight_layout():
fig, axes = plt.subplots(2,3)
plt.show()
fig, axes = plt.subplots(2,3)
plt.tight_layout()
plt.show()
What if you want only 3 subplots? Use the Axes.remove() method:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.remove()
plt.show()
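To confirm that remove() really detaches the axes from the figure, one can count the entries of fig.axes before and after. A small sketch, assuming the Agg backend so that it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)
before = len(fig.axes)

axes[1, 1].remove()  # drop the bottom-right subplot
after = len(fig.axes)

print(f"axes before: {before}, after: {after}")
plt.close(fig)
```

Note that the numpy array returned by plt.subplots() still holds a reference to the removed axes; only the figure forgets it.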
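Use subplot2grid() if you want row and column spans:
subplot2grid((num_rows, num_cols), (row, col), [rowspan=2], [colspan=2])
(the bracketed arguments are optional).
There are some differences from add_subplot():
- it is not a method of Figure, but a function of the plt module,
- the (row, col) location is zero-based, as the examples below show.
ax1 = plt.subplot2grid((2,2), (0,0))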
ax1.plot(x, y1)
ax2 = plt.subplot2grid((2,2), (0,1))
ax2.plot(x, y2)
ax3 = plt.subplot2grid((2,2), (1,0), colspan=2)
ax3.plot(x, y3)
plt.show()
ax1 = plt.subplot2grid((2,2), (0,0))
ax1.plot(x, y1)
ax2 = plt.subplot2grid((2,2), (1,0))
ax2.plot(x, y2)
ax3 = plt.subplot2grid((2,2), (0,1), rowspan=2)
ax3.plot(x, y3)
#ax3.tick_params(left=False, right=True, labelleft=False, labelright=True)
ax3.yaxis.tick_right()
plt.show()
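The cells covered by each axes can be inspected through get_subplotspec(), which makes the effect of colspan visible without drawing anything. This sketch assumes the Agg backend so that it runs without a display, and passes the optional fig argument only to make the target figure explicit; the names ax_left, ax_right and ax_wide are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig = plt.figure()
ax_left = plt.subplot2grid((2, 2), (0, 0), fig=fig)
ax_right = plt.subplot2grid((2, 2), (0, 1), fig=fig)
ax_wide = plt.subplot2grid((2, 2), (1, 0), colspan=2, fig=fig)  # spans both columns

spans = {}  # name -> (rows covered, columns covered), all zero-based
for name, ax in (("left", ax_left), ("right", ax_right), ("wide", ax_wide)):
    spec = ax.get_subplotspec()
    spans[name] = (list(spec.rowspan), list(spec.colspan))
    print(f"{name}: rows {spans[name][0]}, columns {spans[name][1]}")
plt.close(fig)
```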
ax and plt
Usually, you add set_ before the method name:
ax.set_xlabel("foo") # instead of plt.xlabel
Sometimes, it's a little different:
ax.yaxis.tick_right() # but you can use ax.tick_params(...)
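The two styles end up acting on the same object: the plt functions operate on the current axes, which is what plt.gca() returns. A minimal sketch, assuming the Agg backend so that it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

ax.set_xlabel("foo")  # method of the axes object
plt.ylabel("bar")     # pyplot function, applied to the current axes

same_axes = plt.gca() is ax
labels = (ax.get_xlabel(), ax.get_ylabel())

ax.yaxis.tick_right()  # no set_ prefix here: the method lives on ax.yaxis

print("plt acts on ax:", same_axes)
print("labels:", labels)
plt.close(fig)
```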