Figure

In [1]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
In [2]:
matplotlib.interactive(False)
In [3]:
x = np.linspace(1,10,50)
y1 = x**2
y2 = 1/x
y3 = np.cos(x)
y4 = x
plt.plot(x, y1, x, y2, x, y3, x, y4)
plt.ylim(-1, 5)
plt.show()

Figure and axes

A figure is a canvas on which you can place several axes, that is, individual graphs.

pyplot works with one current figure at a time: the plt functions draw on it, and plt.gcf() returns it.

Create a new figure (it becomes the current figure):

In [17]:
fig = plt.figure()
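Several figures can exist at the same time; pyplot only tracks which one is current. The sketch below illustrates this with plt.gcf() and plt.get_fignums(). It assumes the non-interactive "Agg" backend so it runs without a display, and calls plt.close("all") first so the figure numbers start at 1.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

plt.close("all")               # start with no open figures
fig_a = plt.figure()
fig_b = plt.figure()           # the newest figure becomes the current one
print(plt.gcf() is fig_b)      # True
plt.figure(fig_a.number)       # re-activate the first figure by its number
print(plt.gcf() is fig_a)      # True
print(plt.get_fignums())       # [1, 2]
```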

Now add axes to the figure. You need to specify their location and size as a list: [left, bottom, width, height]. The bottom-left corner of the figure is (0, 0), and each value is a fraction of the figure's width or height. Values can be greater than 1 (for example, 1.5 means 150%), in which case the axes extend beyond the nominal figure area.

Leave some room between the axes (here a gap of 0.1, between 0.5 and 0.6). It doesn't matter if the heights plus the offset add up to more than 1 (here 0.6 + 0.5 = 1.1):

In [18]:
ax1 = fig.add_axes([0, 0.6, 1, 0.5]) # note that we start at 0.6 here
ax2 = fig.add_axes([0,   0, 1, 0.5])
plt.show()
In [38]:
fig = plt.figure()
ax1 = fig.add_axes([0, 0.6, 1.2, 0.5]) # width 1.2: wider than the figure
ax2 = fig.add_axes([0,   0,   1, 0.5])
plt.show()
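To check where axes actually ended up, Axes.get_position() returns their rectangle in the same figure-fraction coordinates. This minimal sketch (assuming the "Agg" backend) measures the 0.1 gap and the overshoot above the top edge from the first example; values are rounded to hide floating-point noise.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_axes([0, 0.6, 1, 0.5])   # [left, bottom, width, height]
ax2 = fig.add_axes([0, 0, 1, 0.5])

top = ax1.get_position()      # Bbox in figure-fraction coordinates
bottom = ax2.get_position()
gap = top.y0 - bottom.y1      # space between the two axes
print(round(gap, 2))          # 0.1
print(round(top.y1, 2))       # 1.1: the upper axes extends past the top edge
```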

Then you call ax.plot() just as you called plt.plot() in the introduction.

Some method names differ, however: for example, ax.set_xlabel() instead of plt.xlabel().
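The two styles end up in the same place: a pyplot function such as plt.xlabel() acts on the current axes, plt.gca(). The sketch below (assuming the "Agg" backend) sets a label both ways and reads it back with ax.get_xlabel().

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
plt.xlabel("set through pyplot")        # acts on the current axes, plt.gca()
print(plt.gca() is ax)                  # True
print(ax.get_xlabel())                  # set through pyplot
ax.set_xlabel("set through the Axes")   # object-oriented equivalent
print(ax.get_xlabel())                  # set through the Axes
```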

In [46]:
fig = plt.figure()

ax1 = fig.add_axes([0, 0.6, 1, 0.5])
ax2 = fig.add_axes([0,   0, 1, 0.5])

ax1.plot(x, y1)
ax1.plot(x, y2)
ax1.legend(labels=['a function', 'another function'])
ax1.set_xlabel('x values')
ax1.set_ylabel('y values')

ax2.plot(x, y3)
ax2.set_xlabel('x values')
ax2.set_ylabel('y values')

plt.show()

Because you can put the axes where you want on the figure, you can put axes inside other axes:

In [54]:
fig = plt.figure()

ax1 = fig.add_axes([0, 0, 1, 1])
ax2 = fig.add_axes([0.15, 0.55, 0.5, 0.4])

ax1.plot(x, y1)
ax1.plot(x, y2)
ax1.legend(labels=['a function', 'another function'], loc='center right')
ax1.set_xlabel('x values')
ax1.set_ylabel('y values')

ax2.plot(x, y3)
ax2.set_xlabel('x values')
ax2.set_ylabel('y values')

plt.show()
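As an alternative to computing figure fractions by hand, recent matplotlib versions (3.0 and later, assumed here) provide Axes.inset_axes(), whose bounds are fractions of the parent axes rather than of the figure. This is a swapped-in technique, not the one used above; the sketch assumes the "Agg" backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(1, 10, 50)

fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.1, 0.8, 0.8])   # figure fractions
ax1.plot(x, x**2)

# [left, bottom, width, height] as fractions of ax1, not of the figure
ax2 = ax1.inset_axes([0.1, 0.5, 0.4, 0.4])
ax2.plot(x, np.cos(x))
```

The inset then moves with its parent if ax1 is repositioned, which add_axes does not do.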

Figure size

You can set the figure size, in inches, with the figsize argument. This lets you change the default aspect ratio.

In [62]:
fig = plt.figure(figsize=(6,3))

ax1 = fig.add_axes([0, 0, 1, 1])
ax2 = fig.add_axes([0.15, 0.55, 0.5, 0.4])

ax1.plot(x, y1)
ax1.plot(x, y2)
ax1.legend(labels=['a function', 'another function'], loc='center right')
ax1.set_xlabel('x values')
ax1.set_ylabel('y values')

ax2.plot(x, y3)
ax2.set_xlabel('x values')
ax2.set_ylabel('y values')

plt.show()

You can also resize an existing figure:

fig.set_figwidth(6)
fig.set_figheight(6)
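The current size can be read back with fig.get_size_inches(), and fig.set_size_inches() sets both dimensions at once. Multiplying by fig.dpi gives the size in pixels. This sketch assumes the "Agg" backend and an explicit dpi=100 so the pixel numbers are predictable.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 3), dpi=100)
print(fig.get_size_inches())       # [6. 3.]
fig.set_figwidth(6)
fig.set_figheight(6)
print(fig.get_size_inches())       # [6. 6.]
fig.set_size_inches(8, 4)          # width and height in one call
w, h = fig.get_size_inches() * fig.dpi
print(int(w), int(h))              # 800 400 (pixels)
```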

Figure and subplots

See the automatic method below.

The grid by hand

Axes can be freely arranged on the figure, but if you want a grid-like organization, you should use subplots:

fig.add_subplot(121)

where 121 is:

  • 1: the number of rows
  • 2: the number of columns
  • 1: the index of the current subplot, starting at 1

Subplots are numbered row by row, from left to right, for example:

1 2
3 4

1 2 3
4 5 6

etc.
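The numbering rule can be written as a formula: for a grid with ncols columns, index i sits at row (i - 1) // ncols and column (i - 1) % ncols, counting from 0. The sketch below checks that formula against matplotlib's own grid placement via ax.get_subplotspec(). subplot_position is a hypothetical helper, and the rowspan/colspan attributes assume a recent matplotlib (3.4 or later); the "Agg" backend is also assumed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt


def subplot_position(ncols, index):
    """Hypothetical helper: map a 1-based subplot index to a 0-based (row, col)."""
    return (index - 1) // ncols, (index - 1) % ncols


fig = plt.figure()
for index in range(1, 7):
    # add_subplot(234) is shorthand for add_subplot(2, 3, 4);
    # the three-digit form only works while every number is below 10.
    ax = fig.add_subplot(2, 3, index)
    spec = ax.get_subplotspec()
    # rowspan/colspan are the ranges of grid rows/columns the subplot occupies
    assert (spec.rowspan.start, spec.colspan.start) == subplot_position(3, index)

print(subplot_position(3, 4))  # (1, 0): second row, first column
```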

The plt.subplots() function, described below, creates the whole grid at once.

In [69]:
fig = plt.figure()

ax1 = fig.add_subplot(121)
ax1.plot(x, y1)

ax2 = fig.add_subplot(122)
ax2.plot(x, y2)

plt.show()
In [70]:
fig = plt.figure()

ax1 = fig.add_subplot(221)
ax1.plot(x, y1)

ax2 = fig.add_subplot(222)
ax2.plot(x, y2)

ax3 = fig.add_subplot(223)
ax3.plot(x, y3)

ax4 = fig.add_subplot(224)
ax4.plot(x, y4)

plt.show()

Get everything at once

plt.subplots(nrows, ncols) returns a tuple (fig, axes), where axes is a NumPy array of Axes objects.

The actual output depends on the squeeze argument (True by default): if True, extra dimensions are removed (a single Axes object instead of an array for (1, 1), and a 1-D array if one of the dimensions is 1). If False, a 2-D array is always returned.

Access individual axes by indexing:
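The effect of squeeze can be checked directly on the returned object. This sketch (assuming the "Agg" backend) prints the type or shape of axes for several grid sizes.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

_, single = plt.subplots(1, 1)
print(isinstance(single, Axes))   # True: a bare Axes, no array
_, row = plt.subplots(1, 3)
print(row.shape)                  # (3,): 1-D because nrows is 1
_, grid = plt.subplots(2, 3)
print(grid.shape)                 # (2, 3)
_, unsqueezed = plt.subplots(1, 1, squeeze=False)
print(unsqueezed.shape)           # (1, 1): always 2-D with squeeze=False
```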

In [96]:
fig, ax = plt.subplots(2, 2)
print(ax[0,1])
ax
AxesSubplot(0.547727,0.53;0.352273x0.35)
Out[96]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7f30e6842390>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x7f30e6456860>],
       [<matplotlib.axes._subplots.AxesSubplot object at 0x7f30e6921cc0>,
        <matplotlib.axes._subplots.AxesSubplot object at 0x7f30e6ba3278>]],
      dtype=object)

or by unpacking the array directly:

In [102]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.plot(x, y4)

plt.show()
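Besides indexing and unpacking, the array can be walked with NumPy's axes.flat, which is convenient when the same operation applies to every subplot. A minimal sketch, assuming the "Agg" backend and reusing the four functions from the introduction:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(1, 10, 50)
curves = [x**2, 1 / x, np.cos(x), x]    # the y1..y4 of the introduction

fig, axes = plt.subplots(2, 2)
# axes.flat walks the 2-D array row by row: [0, 0], [0, 1], [1, 0], [1, 1]
for ax, y in zip(axes.flat, curves):
    ax.plot(x, y)

print(axes.flat[1] is axes[0, 1])        # True
```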

You can share axes. With sharex=True, all subplots use the same x-axis limits and the x tick labels appear only on the bottom row. With sharey=True, the y-axis limits are shared and the y tick labels appear only on the left column.

In [4]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=True)

ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.plot(x, y4)

plt.show()
In [5]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharey=True)

ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.plot(x, y4)

plt.show()
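Sharing links the limits, not only the tick labels: changing the x limits of one subplot changes them on its siblings. sharex also accepts "col" or "row" to link only the subplots in the same column or row. The sketch assumes the "Agg" backend; limits are converted to plain floats before printing.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt


def xlim(ax):
    """Return the x limits as plain Python floats."""
    return tuple(float(v) for v in ax.get_xlim())


fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex=True)
ax1.set_xlim(0, 5)                      # change the limits on one subplot...
print([xlim(ax) for ax in (ax2, ax3, ax4)])
# [(0.0, 5.0), (0.0, 5.0), (0.0, 5.0)]: ...and all the others follow

fig2, ((bx1, bx2), (bx3, bx4)) = plt.subplots(2, 2, sharex="col")
bx1.set_xlim(0, 5)
print(xlim(bx3))                        # (0.0, 5.0): same column
print(xlim(bx2) == (0.0, 5.0))          # False: other column is independent
```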

If your subplots overlap (for example, tick labels running into the neighboring axes), call tight_layout() to adjust the spacing automatically:

In [4]:
fig, axes = plt.subplots(2,3)
plt.show()
In [5]:
fig, axes = plt.subplots(2,3)
plt.tight_layout()
plt.show()
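The spacing can also be tuned. tight_layout() accepts extra padding (h_pad and w_pad, in units of the font size), while fig.subplots_adjust() sets the gaps by hand (hspace and wspace, as fractions of the average axes height and width). A sketch assuming the "Agg" backend; the two options are shown on separate figures because a later tight_layout() call would override manual settings.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

# Option 1: automatic spacing, with extra vertical padding between rows
fig1, axes1 = plt.subplots(2, 3)
fig1.tight_layout(h_pad=2.0)

# Option 2: manual spacing
fig2, axes2 = plt.subplots(2, 3)
fig2.subplots_adjust(hspace=0.4, wspace=0.3)
print(fig2.subplotpars.hspace, fig2.subplotpars.wspace)  # 0.4 0.3
```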

What if you want only three subplots? Use the Axes.remove() method:

In [8]:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

ax1.plot(x, y1)
ax2.plot(x, y2)
ax3.plot(x, y3)
ax4.remove()

plt.show()
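remove() deletes the axes from the figure, whereas ax.axis("off") only hides the frame, ticks and labels while the axes stays in the figure. The difference shows up in fig.axes, as in this sketch (assuming the "Agg" backend):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)
axes[1, 1].remove()            # deleted from the figure
print(len(fig.axes))           # 3

fig2, axes2 = plt.subplots(2, 2)
axes2[1, 1].axis("off")        # hidden decorations, but still part of the figure
print(len(fig2.axes))          # 4
```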

Row and col spans

Use subplot2grid() if you want a subplot to span several rows or columns:

plt.subplot2grid((num_rows, num_cols), (row, col), rowspan=1, colspan=1)

where rowspan and colspan are optional and default to 1.

There are some differences from add_subplot():

  • it's a function of the plt module, not a method of Figure;
  • the location is a (row, col) pair, and indices start at 0, not 1.
In [76]:
ax1 = plt.subplot2grid((2,2), (0,0))
ax1.plot(x, y1)

ax2 = plt.subplot2grid((2,2), (0,1))
ax2.plot(x, y2)

ax3 = plt.subplot2grid((2,2), (1,0), colspan=2)
ax3.plot(x, y3)

plt.show()
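To confirm which grid cells a subplot2grid() axes occupies, its SubplotSpec exposes rowspan and colspan as ranges. This sketch (assuming the "Agg" backend and a recent matplotlib, 3.4 or later) inspects the bottom subplot of the example above, which spans both columns.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

fig = plt.figure()
ax3 = plt.subplot2grid((2, 2), (1, 0), colspan=2)   # second row, both columns
spec = ax3.get_subplotspec()
print(spec.rowspan, spec.colspan)   # range(1, 2) range(0, 2)
```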
In [86]:
ax1 = plt.subplot2grid((2,2), (0,0))
ax1.plot(x, y1)

ax2 = plt.subplot2grid((2,2), (1,0))
ax2.plot(x, y2)

ax3 = plt.subplot2grid((2,2), (0,1), rowspan=2)
ax3.plot(x, y3)
#ax3.tick_params(left=False, right=True, labelleft=False, labelright=True)
ax3.yaxis.tick_right()

plt.show()

More complicated arrangements

See plt.GridSpec.
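A GridSpec describes a grid once and lets you slice it like a NumPy array, so spans are written as ranges of rows and columns. This sketch uses fig.add_gridspec(), which builds the same kind of GridSpec object as plt.GridSpec; the 3 x 3 layout is an illustrative choice, and the "Agg" backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(1, 10, 50)

fig = plt.figure()
gs = fig.add_gridspec(3, 3)              # a 3 x 3 grid of cells
ax_top = fig.add_subplot(gs[0, :])       # whole first row
ax_left = fig.add_subplot(gs[1:, 0])     # rows 1-2 of column 0
ax_main = fig.add_subplot(gs[1:, 1:])    # bottom-right 2 x 2 block

ax_top.plot(x, x**2)
ax_left.plot(x, 1 / x)
ax_main.plot(x, np.cos(x))
```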

Difference between ax and plt

Usually, the Axes method is the pyplot function name prefixed with set_:

ax.set_xlabel("foo") # instead of plt.xlabel

Sometimes, it's a little different:

ax.yaxis.tick_right() # but you can use ax.tick_params(...)
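The two spellings in the last line can be compared. ax.yaxis.tick_right() moves both major and minor ticks and their labels to the right, so the tick_params() version needs which="both" to match it. The sketch assumes the "Agg" backend and reads the result back with yaxis.get_ticks_position().

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.yaxis.tick_right()
ax2.tick_params(axis="y", which="both",
                left=False, right=True, labelleft=False, labelright=True)

print(ax1.yaxis.get_ticks_position())   # right
print(ax2.yaxis.get_ticks_position())   # right
```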