This post is written for my Machine Learning for Physics and Astronomy Students, Spring 2022.

It’s a shame that no one ever explains how to use matplotlib to students because it’s “just a tool” and therefore less important than the content of classes. But I find that students usually grasp math, physics, and statistics concepts really quickly yet struggle with implementation (and visualization) because the tools were never explained properly to them.

We’re going to do a lot of plotting; physicists always do a lot of plotting. So here I’m offering a small explanation of how “making a figure” with matplotlib works. Once you understand how a tool works, you’ll be able to debug when things go wrong, leading to a much more pleasant user experience.

Matplotlib’s Object-Oriented User-Side Approach

Matplotlib creates your figure by dealing with three items in the back:

  1. matplotlib.backend_bases.FigureCanvasBase is the area onto which a figure is drawn.
  2. matplotlib.backend_bases.RendererBase is the object that knows how to actually draw onto the canvas, for a variety of different backends.
  3. matplotlib.artist.Artist is the object that knows how to use the renderer to paint onto the canvas.

For us, the most important part to understand is the further structure of the Artist objects. The typical user will spend 95% of their time working only with Artist objects. There are two broad types of Artist objects: primitives and containers. Primitives are the actual graphics that get drawn, like lines (Line2D), bars in bar graphs (Rectangle), text (Text), etc. The containers are the things you place them into (Axes, Figure).
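To make the primitive/container split concrete, here is a small sketch that builds a figure and then looks inside the containers. The data values and the title are arbitrary placeholders I made up for illustration; the attributes used (fig.axes, ax.lines, ax.patches, ax.title) are regular matplotlib attributes.

```python
import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.text import Text

fig = plt.figure()
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

ax.plot([0, 1], [0, 1])        # adds one Line2D primitive to the Axes
ax.bar([0, 1, 2], [3, 1, 2])   # adds three Rectangle primitives (one per bar)
ax.set_title("Example")        # the title is a Text primitive

# Containers and primitives alike are Artists.
print(isinstance(fig, Artist), isinstance(ax, Artist))

# Containers keep track of what was placed inside them.
print(fig.axes == [ax])
print(isinstance(ax.lines[0], Line2D))
print(len(ax.patches), isinstance(ax.patches[0], Rectangle))
print(isinstance(ax.title, Text))
```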

Now, this is where it gets confusing. Figure and Axes, for example, come directly from their own modules, i.e. matplotlib.figure.Figure and matplotlib.axes.Axes. Although you always instantiate a Figure object and one or more Axes objects, you never actually call these modules directly. Instead, the official documentation always directs you to call matplotlib.pyplot, which it describes as “a collection of command style functions.” That’s only part of what it does. The other thing it does is provide easy-to-read wrappers, à la MATLAB, for interacting with the other matplotlib modules. I recommend getting used to NOT using the MATLAB-style interface for real plotting (this is when you just call plt.plot(x,y), for example) because more difficult plots are only possible with the object-oriented approach.
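As a rough side-by-side of the two styles, the sketch below makes the same plot twice. The data are arbitrary. plt.gcf() and plt.gca() ("get current figure/axes") are real pyplot functions that expose the objects the MATLAB-style call created behind your back.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

x = np.linspace(0, 1, 15)
y = x**2

# MATLAB-style: pyplot quietly creates a Figure and an Axes
# and remembers them as the "current" ones.
plt.plot(x, y)
implicit_fig = plt.gcf()
implicit_ax = plt.gca()

# Object-oriented style: we create and hold the Figure and Axes ourselves.
fig = plt.figure()
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.plot(x, y)

# Either way, the same kinds of objects exist in the end.
print(isinstance(implicit_fig, Figure), isinstance(fig, Figure))
print(isinstance(implicit_ax, Axes), isinstance(ax, Axes))
```

The difference is only in who keeps the handles: with the object-oriented style you always know which Axes a command acts on.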

So now let’s introduce a basic example of matplotlib usage via the object-oriented approach (though this might look fairly different from what you usually see as basic usage).

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0,1,15)
y = np.linspace(0,1,15)

fig = plt.figure()
ax = fig.add_axes([0.2, 0.4, 0.5, 0.3])

ax.plot(x,y)
ax.set_title('A Good Title')
ax.set_xlabel('Property 1')
ax.set_ylabel('Property 2')
plt.show()

This produces the following image as output:

Let’s break this down line by line.

The first two lines import the aforementioned wrappers and numpy. The next two lines create something for us to plot.

plt.figure() instantiates a Figure object, which is the outermost container for a matplotlib graphic. It can contain one or more Axes objects. An Axes is what the average human would refer to as “the plot itself”: a set of axes that are tied together. It does NOT refer to the individual x-axis and y-axis, as you might expect. To see the difference, consider this example:
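A minimal sketch of the Axes-versus-axis distinction, with arbitrary placement numbers: each Axes owns exactly one XAxis and one YAxis object (reachable as ax.xaxis and ax.yaxis), and a Figure can hold several Axes.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.axis import XAxis, YAxis

fig = plt.figure()
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])

# The Axes is the whole plotting region ...
print(isinstance(ax, Axes))
# ... and it owns one XAxis and one YAxis, which handle ticks and labels.
print(isinstance(ax.xaxis, XAxis), isinstance(ax.yaxis, YAxis))

# A second, smaller Axes in the same Figure gets its own pair of axis objects.
ax2 = fig.add_axes([0.6, 0.6, 0.25, 0.25])
print(len(fig.axes))
```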

Every figure needs at least one Axes object to actually have something to plot onto. We need to add these Axes objects ourselves. That’s what’s done in the next line: we call Figure.add_axes(), which takes as input the dimensions [xmin, ymin, dx, dy] (the matplotlib documentation calls these [left, bottom, width, height]). These numbers are fractions of the Figure width and height that specify the following:

  • xmin: Horizontal coordinate of the lower left corner.
  • ymin: Vertical coordinate of the lower left corner.
  • dx: Width of the subplot.
  • dy: Height of the subplot.

To see this in action, consider what happens when I set these values to [0.2, 0.4, 0.5, 0.3]:

Note that all the white space around the Axes is the Figure, and the Axes object is now smaller because I’ve set dx and dy to be smaller. To bring it back to the earlier discussion, so far we’ve created Figure and Axes objects, which are container Artists.
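If you want to check the placement numerically, here is a small sketch. The figsize of 6 x 4 inches is my own arbitrary choice; ax.get_position() returns the bounding box of the Axes in figure fractions, and fig.get_size_inches() returns the Figure size.

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 4))           # 6 x 4 inches (arbitrary)
ax = fig.add_axes([0.2, 0.4, 0.5, 0.3])    # [xmin, ymin, dx, dy] as figure fractions

# Read the placement back as (xmin, ymin, dx, dy).
xmin, ymin, dx, dy = ax.get_position().bounds
print(xmin, ymin, dx, dy)

# Convert the fractional width and height to inches using the Figure size.
fig_width, fig_height = fig.get_size_inches()
print(dx * fig_width, dy * fig_height)
```

So with these numbers the Axes covers half the Figure width and a bit less than a third of its height; everything else is empty Figure.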

The last lines are where we add items that are based on primitive Artists. ax.plot() draws our data as a matplotlib.lines.Line2D, and the title and axis labels are matplotlib.text.Text objects. Note that at no point did we create these primitives ourselves; they were generated by methods of our matplotlib.axes.Axes object, which we got from the Figure that the pyplot wrapper created for us. Indeed, the methods of matplotlib.axes.Axes are where most of the plotting takes place, and hence the documentation for that class should be the first place to search when you want to plot something new to you, like a pie chart (Axes.pie) or a contour plot (Axes.contour).
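The Axes methods hand the primitives back to you if you ask for them, which is handy when you want to tweak something afterwards. A short sketch (the styling values are arbitrary):

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from matplotlib.text import Text

x = np.linspace(0, 1, 15)

fig = plt.figure()
ax = fig.add_axes([0.2, 0.4, 0.5, 0.3])

lines = ax.plot(x, x)                  # a list with one Line2D per plotted curve
title = ax.set_title('A Good Title')   # the Text object holding the title
xlabel = ax.set_xlabel('Property 1')   # the Text object holding the x label

line = lines[0]
print(isinstance(line, Line2D), isinstance(title, Text), isinstance(xlabel, Text))

# The primitives can be modified later through their own methods.
line.set_linewidth(3)
line.set_linestyle('--')
title.set_fontsize(14)
```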

Sometimes you have to interact with the matplotlib primitives directly, e.g. when you want to add a box around something (then you need Rectangle), but 90% of the time you don’t, and you should really only be looking at the methods of matplotlib.axes.Axes.
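For the box case, a sketch could look like the following. The box position, size, and color are made up for illustration; Rectangle takes the lower-left corner, a width, and a height (in data coordinates by default), and Axes.add_patch attaches it to the Axes.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

x = np.linspace(0, 1, 15)

fig = plt.figure()
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
ax.plot(x, x)

# Rectangle((x, y) of the lower-left corner, width, height)
box = Rectangle((0.4, 0.4), 0.2, 0.2, fill=False, edgecolor='red', linewidth=2)
ax.add_patch(box)  # a primitive only gets drawn once a container holds it
```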

Lastly, plt.show() ties it all back together with the renderer and actually displays all open Figure objects. You don’t need to call it in Jupyter notebooks, because with %matplotlib inline the figures are displayed automatically at the end of every cell (for better or for worse).

Actual Plotting Examples

Now that we know how things work in the back, let’s take a look at what your plotting scripts will probably actually look like.

First, we usually combine the two lines that create a figure and axes into one line with plt.subplots, which also lets us create… subplots. So instead of

fig = plt.figure()
ax = fig.add_axes([0.2, 0.4, 0.5, 0.3])

we usually see

fig, ax = plt.subplots(1,1)
# or simply
fig, ax = plt.subplots()
# or more fully
fig, ax = plt.subplots(nrows=2,ncols=2)

This is the usage I recommend for everyday plotting since you’re likely to be making subplots. All plt.subplots() does is create the Figure object and the Axes objects (“subplots”) simultaneously. If you call it with no arguments as in the second example, you just get one plot. If you call it as in the last option, you get a 2x2 grid of subplots, which looks like this:

Then, in that case, what is ax? A quick investigation then:

>>> type(ax)
<class 'numpy.ndarray'>
>>> ax.shape
(2, 2)

It seems that ax is nothing more than a 2x2 numpy array storing each of the four Axes objects. If we want to plot something in each of them, we need to just index them as usual. Let’s plot a red line in the upper left, a pink line in the upper right, a blue line in the lower left and a green line in the lower right:
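One thing to watch out for is that the type of the second return value depends on the grid you ask for. The sketch below compares a few calls; squeeze is a real keyword of plt.subplots, and setting squeeze=False asks for a 2D array no matter what.

```python
import matplotlib.pyplot as plt
import numpy as np

fig1, ax1 = plt.subplots()                     # a single Axes, not an array
fig2, ax2 = plt.subplots(1, 3)                 # 1D array of Axes
fig3, ax3 = plt.subplots(2, 2)                 # 2D array of Axes
fig4, ax4 = plt.subplots(1, 3, squeeze=False)  # always a 2D array

print(isinstance(ax1, np.ndarray))   # False
print(ax2.shape)                     # (3,)
print(ax3.shape)                     # (2, 2)
print(ax4.shape)                     # (1, 3)
```

This is why code that indexes ax[0][0] breaks when you reduce the grid to a single row: the array is then one-dimensional unless you pass squeeze=False.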

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0,1,15)
y = np.linspace(0,1,15)

fig, ax = plt.subplots(2,2)
ax[0][0].plot(x,y, color='red')
ax[0][1].plot(x,y, color='pink')
ax[1][0].plot(x,y, color='blue')
ax[1][1].plot(x,y, color='green')
plt.show()

You could, of course, do this much more nicely in a loop, perhaps by using np.ravel(), but my point stands.
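For completeness, one way the loop version might look. This is a sketch, not the only way to do it; the colors list is just the four colors from above collected in reading order.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, 15)
y = np.linspace(0, 1, 15)
colors = ['red', 'pink', 'blue', 'green']

fig, axes = plt.subplots(2, 2)

# np.ravel flattens the 2x2 array row by row:
# upper left, upper right, lower left, lower right.
for ax, color in zip(np.ravel(axes), colors):
    ax.plot(x, y, color=color)
```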

Now this should all be enough to create whatever fancy plots you want. So let me leave you with a final snippet and output to get you started on whatever plotting you may need to do in the future.

import matplotlib.pyplot as plt
import numpy as np

scalars = [1, 0.5, 0.25]
models = [s * np.sin(np.linspace(0, 2 * np.pi, 100)) for s in scalars]
data = [model + 0.1 * np.random.randn(100) for model in models]

fig, axes = plt.subplots(2, 3, figsize=(8, 4), dpi=200, sharex='all', sharey='row')
for i, ax in enumerate(axes.T):  # iterate over columns instead of rows
    color = f'C{i}'  # i-th color of the default color cycle
    # in top plot, show data and model
    ax[0].plot(data[i], 'o', color=color, alpha=0.5)
    ax[0].plot(models[i], color=color)
    # in bottom plot, show residuals
    ax[1].plot(data[i] - models[i], '.', color=color)

    if i == 0:
        ax[0].set_ylabel('Amplitude')
        ax[1].set_ylabel('Residuals')

    if i == 1:
        ax[1].set_xlabel('Time (s)')

plt.tight_layout() # magic line that makes labels not overlap

plt.show()