In [3]:
%matplotlib inline
In [9]:
import matplotlib.pyplot as plt
import numpy as np
def f(x):
    """Evaluate the cubic x^3 + 3x^2 - x + 1 at ``x`` (scalar or NumPy array)."""
    cubic_term = x**3
    quadratic_term = 3 * x**2
    return cubic_term + quadratic_term - x + 1
# Baseline figure: plot f over [-3, 3.2] without touching the axes setup.
fig, ax = plt.subplots()
x = np.linspace(-3, 3.2, 100)  # 100 evenly spaced sample points
ax.plot(x, f(x))
plt.show()
I will try to collect here some methods that I have found useful. Almost all of these are taken from SO posts and other sites, and where I remember the source I will be sure to give credit where it is due. I am simply collecting the information here to avoid repeated Googling!
The most distinct difference that we notice is the placement of the axes: on the bounding box rather than at $(0, 0)$. While this makes practical sense when we are plotting data, for a function alone it is out of style. There are many ways to approach this; here are those that I know:
This is the most direct method, in which we set the spines by hand:
In [25]:
def CenterAxes(ax):
    """Make the x- and y-axis of ``ax`` cross at the origin.

    The left and bottom spines are positioned at data coordinate zero, the
    right and top spines are hidden, ticks are kept only on the bottom and
    left, and black lines of width 1.5 are added via ``axhline``/``axvline``.

    Parameters
    ----------
    ax : axes object
        Must provide ``spines``, ``xaxis``, ``yaxis``, ``axhline`` and
        ``axvline`` (e.g. a matplotlib Axes).

    Returns
    -------
    The same ``ax`` object, modified in place.
    """
    for shown, hidden in (('left', 'right'), ('bottom', 'top')):
        ax.spines[shown].set_position('zero')
        ax.spines[hidden].set_color('none')

    # Spine.set_smart_bounds was deprecated in Matplotlib 3.2 and is absent
    # from later releases; calling it unconditionally raises AttributeError
    # there. Use it only when the installed version still provides it.
    for side in ('left', 'bottom'):
        set_smart_bounds = getattr(ax.spines[side], 'set_smart_bounds', None)
        if set_smart_bounds is not None:
            set_smart_bounds(True)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    ax.axhline(linewidth=1.5, color='k')
    ax.axvline(linewidth=1.5, color='k')
    return ax
# Plot f over [-3, 3.2] and then pass the axes through CenterAxes.
fig, ax = plt.subplots()
x = np.linspace(-3, 3.2, 100)
ax.plot(x, f(x))
ax = CenterAxes(ax)  # CenterAxes returns the axes object it was given
plt.show()
Here is a second, somewhat simpler method which also allows you to easily add arrows. The downside is that we are using the SubplotZero axes class which is not as well built as the matplotlib axes class. For example, I cannot for the life of me figure out how to rotate the yticks!
In [48]:
# The axislines module lives in mpl_toolkits.axisartist; the old
# mpl_toolkits.axes_grid package was deprecated and is no longer shipped with
# current Matplotlib. Fall back to the old path only for legacy installs.
try:
    from mpl_toolkits.axisartist.axislines import SubplotZero
except ImportError:
    from mpl_toolkits.axes_grid.axislines import SubplotZero

fig = plt.figure()
ax = SubplotZero(fig, 111)
fig.add_subplot(ax)
ax.plot(x, f(x))  # reuses x from the preceding plotting cell

# Remove the old axes
for direction in ["left", "right", "bottom", "top"]:
    ax.axis[direction].set_visible(False)

# Add the new axes with arrows
for direction in ["xzero", "yzero"]:
    ax.axis[direction].set_axisline_style("-|>")
    ax.axis[direction].set_visible(True)

plt.show()
Often we don't want to plot numerical labels, but instead some algebraic quantity such as $\pi$. There are two methods to do this that I am aware of.
This is the first, and most obvious way to get the desired effect. Unfortunately it is often discouraged by matplotlib developers and SO posts. Personally I see no harm, and often it is significantly easier to remember!
In [63]:
x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# Manual setup of the ticks and labels.
# linspace yields exactly five ticks; np.arange with a float step can gain or
# lose an element through rounding, which would no longer match the five labels.
xticks = np.linspace(-np.pi, np.pi, 5)
# Raw strings: "\p" is not a valid Python escape and warns on newer interpreters.
xticklabels = [r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"]
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, size=16)
plt.show()
If you're not using the object-oriented ax methods then you can shorten this by
simply passing both the xticks and xticklabels to plt.xticks() as follows:
In [62]:
fig = plt.figure()
plt.plot(x, np.sin(x))  # reuses x from the preceding cell

# Manual setup of the ticks and labels.
# linspace yields exactly five ticks; np.arange with a float step can gain or
# lose an element through rounding, which would no longer match the five labels.
xticks = np.linspace(-np.pi, np.pi, 5)
# Raw strings: "\p" is not a valid Python escape and warns on newer interpreters.
xticklabels = [r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"]
plt.xticks(xticks, xticklabels)
plt.show()
In [89]:
def is_multiple(x, y, tol=0.1):
    """Return the integer ``k`` when ``x`` is within ``tol`` of ``k * y``.

    Parameters
    ----------
    x : float
        Value to test.
    y : float
        Base whose multiples are looked for (expected to be positive).
    tol : float, optional
        Absolute tolerance on the remainder of ``x`` modulo ``y``.

    Returns
    -------
    int or None
        The nearest integer multiple, or ``None`` if ``x`` is not close
        enough to any multiple of ``y``.
    """
    remainder = np.mod(x, y)
    # For positive y the remainder lies in [0, y): a value slightly below a
    # multiple shows up as a remainder close to y (not close to 1).
    if remainder < tol or y - remainder < tol:
        # Round rather than truncate so that e.g. 1.96 maps to 2, not 1.
        return int(round(float(x) / y))
    return None
def format_val(val, tol=0.1):
    """Format a coefficient for display in front of a symbol.

    A coefficient within ``tol`` of +1 is rendered as an empty string and one
    within ``tol`` of -1 as a bare minus sign, so the symbol reads naturally.
    Any other value is rendered rounded to zero decimal places.
    """
    if abs(abs(val) - 1) < tol:
        # Keep the sign: returning "" for -1 would drop the minus entirely.
        return "-" if val < 0 else ""
    return "{:1.0f}".format(val)
def GetPiLabel(val):
    """Return a LaTeX tick label expressing ``val`` as a multiple of pi.

    Values close to zero give "$0$"; values close to an integer or
    half-integer multiple of pi give the corresponding symbolic label (as
    decided by ``is_multiple`` and formatted by ``format_val``). Any other
    value falls back to a two-decimal coefficient of pi instead of
    implicitly returning ``None``.
    """
    coef = val / np.pi
    if abs(coef) < 0.1:
        return "$0$"

    # Compare against None explicitly: is_multiple may legitimately return 0,
    # which a plain truthiness test would mistake for "not a multiple".
    if is_multiple(coef, 1.0) is not None:
        return r"${}\pi$".format(format_val(coef))

    if is_multiple(coef, 0.5) is not None:
        return r"${}\pi/{{2}}$".format(format_val(coef / 0.5))

    return r"${:.2f}\pi$".format(coef)
x = np.linspace(0, 2*np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# Manual setup of the ticks and labels.
# linspace pins the tick count at five (0, pi/2, ..., 2*pi); np.arange with a
# float step may gain or lose the end point through rounding.
xticks = np.linspace(0, 2*np.pi, 5)
xticklabels = [GetPiLabel(xtick) for xtick in xticks]
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, size=14)
plt.show()
Of course, a more powerful function that could create generic fractions would be better.
The advised way to set the labels is through the matplotlib.ticker module. This method is lifted pretty much directly from this example in the docs. We can also use the function defined earlier on, but we need to modify it slightly to accept a position which is subsequently ignored.
In [98]:
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, FuncFormatter
def GetPiLabel(val, position=None):
    """Return a LaTeX tick label expressing ``val`` as a multiple of pi.

    The signature matches what ``FuncFormatter`` passes to its callback:
    the tick value and a position. ``position`` is ignored and defaults to
    ``None`` so the function can also be called with the value alone.

    Values close to zero give "$0$"; values close to an integer or
    half-integer multiple of pi give the corresponding symbolic label (as
    decided by ``is_multiple`` and formatted by ``format_val``). Any other
    value falls back to a two-decimal coefficient of pi instead of
    implicitly returning ``None``.
    """
    coef = val / np.pi
    if abs(coef) < 0.1:
        return "$0$"

    # Compare against None explicitly: is_multiple may legitimately return 0,
    # which a plain truthiness test would mistake for "not a multiple".
    if is_multiple(coef, 1.0) is not None:
        return r"${}\pi$".format(format_val(coef))

    if is_multiple(coef, 0.5) is not None:
        return r"${}\pi/{{2}}$".format(format_val(coef / 0.5))

    return r"${:.2f}\pi$".format(coef)
# Plot sin(x) on [0, 2*pi] and label the x axis through matplotlib.ticker.
x = np.linspace(0, 2*np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
# Locator built with a base of pi/2 decides where the major ticks go.
majorLocator = MultipleLocator(np.pi/2)
ax.xaxis.set_major_locator(majorLocator)
# Formatter wraps GetPiLabel, which produces the label text for each tick.
majorFormatter = FuncFormatter(GetPiLabel)
ax.xaxis.set_major_formatter(majorFormatter)
plt.show()