In [7]:
from __future__ import division

%matplotlib inline

import collections
try:
    from collections.abc import Iterable  # Python 3.3+; collections.Iterable was removed in 3.10
except ImportError:  # Python 2
    from collections import Iterable

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
In [3]:
# Sample sin(x) on [0, 10] twice: finely (101 points, step 0.1) and coarsely (11 points, step 1).
x1, x2 = (np.linspace(0, 10, num) for num in (101, 11))
y1, y2 = np.sin(x1), np.sin(x2)
In [6]:
# Compare the finely sampled (101 points) and coarsely sampled (11 points) sine.
# Separate calls advance the colour cycle exactly like the single multi-pair call.
plt.plot(x1, y1, label='101 samples')
plt.plot(x2, y2, label='11 samples')
plt.xlabel('x')
plt.ylabel('sin(x)')
# Trailing ';' suppresses the Legend repr; the figure itself still renders.
plt.legend();
Out[6]:
In [8]:
# Is a 1D numpy array iterable? (collections.Iterable was removed in Python 3.10;
# Iterable is imported from collections.abc.)
isinstance(x1, Iterable)
Out[8]:
In [19]:
# Bundle both samplings into lists of arrays, i.e. nested sequences of x and y values.
x, y = [x1, x2], [y1, y2]
Out[19]:
In [20]:
# Display the list of x arrays built in the previous cell (last expression = cell output).
x
Out[20]:
In [24]:
# Every element of x is itself iterable, so x counts as "nested" data.
# (Uses collections.abc.Iterable; the collections.Iterable alias no longer exists.)
all(isinstance(i, Iterable) for i in x)
Out[24]:
In [27]:
# Object-oriented interface: create the Figure/Axes pair explicitly, then draw on the Axes.
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot(x1, y1)
Out[27]:
In [29]:
def plot(x, y, scale='linear', xlabel=None, ylabel=None, shape=None, xlim=None, ylim=None, rcParams=None):
    """Draw y against x on a new figure and return ``(fig, ax)``.

    NOTE(review): a later cell redefines ``plot`` (adding a shared-x / many-y
    mode); this draft is shadowed once that cell runs.

    Parameters
    ----------
    x, y : array-like, or sequence of array-likes
        If every element of ``x`` is a non-string iterable (list of lists,
        2D array, list of 1D arrays), ``x`` and ``y`` are zipped and each
        ``(xi, yi)`` pair is drawn as its own line; otherwise a single line
        ``(x, y)`` is drawn.
    scale : {'linear', 'semilogy', 'semilogx', 'loglog'}
        Selects the ``Axes`` method used to draw.
    xlabel, ylabel : str, optional
        Axis labels; left unset when None.
    shape : optional
        Accepted but currently unused.
    xlim, ylim : optional
        Passed to ``ax.set_xlim`` / ``ax.set_ylim`` when not None.
    rcParams : dict, optional
        Temporary matplotlib rc overrides active while the figure is built.
        None (the default) means no overrides.

    Returns
    -------
    (fig, ax) : the new Figure and its single Axes.

    Raises
    ------
    KeyError
        If ``scale`` is not one of the names above. Raised before any figure
        is created, so no orphan figure is left open.
    TypeError
        If ``x`` is not iterable.
    """
    # Map scale names to Axes method names so input is validated before
    # plt.subplots() registers a figure.
    method_names = {
        'linear': 'plot',
        'semilogy': 'semilogy',
        'semilogx': 'semilogx',
        'loglog': 'loglog',
    }
    method_name = method_names[scale]

    # Strings are iterable but are labels, not nested data.
    nested = all(isinstance(i, Iterable) and not isinstance(i, (str, bytes))
                 for i in x)

    with mpl.rc_context(rc=rcParams or {}):
        fig, ax = plt.subplots()
        plot_fn = getattr(ax, method_name)
        if nested:
            for xi, yi in zip(x, y):
                plot_fn(xi, yi)
        else:
            plot_fn(x, y)

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        fig.tight_layout()
    return fig, ax
In [31]:
# Draw both sine samplings as separate lines on one figure; keep the handles.
fig, ax = plot(x=x, y=y)
In [32]:
# Two (x, y) series of different lengths, passed as lists of lists.
plot(
    x=[[1, 2, 3], [1, 1.5, 2, 2.5, 3]],
    y=[[1, 4, 9], [1, 2, 4, 6, 9]],
)
Out[32]:
In [38]:
def nested_iterable(x):
    """Return True if every element of ``x`` is itself a non-string iterable.

    True for a list of lists, a 2D numpy array, or a list of 1D numpy arrays.
    False for a flat sequence of scalars (e.g. a 1D array) and for a sequence
    of strings, so text labels are not mistaken for nested data.
    An empty ``x`` vacuously returns True.

    Raises a TypeError if passed a non-iterable.
    """
    # collections.Iterable was removed in Python 3.10; Iterable is the
    # collections.abc ABC. str/bytes are excluded: they iterate per character.
    return all(isinstance(i, Iterable) and not isinstance(i, (str, bytes))
               for i in x)
def plot(x, y, scale='linear', xlabel=None, ylabel=None, shape=None,
         xlim=None, ylim=None, rcParams=None):
    """Draw one or more lines on a new figure and return ``(fig, ax)``.

    Input combinations (nesting decided by ``nested_iterable``):

    * x nested, y nested -> one line per ``(xi, yi)`` pair (zip).
    * x flat,   y nested -> one line per ``yi``, all sharing ``x``.
    * x nested, y flat   -> ValueError.
    * x flat,   y flat   -> a single line ``(x, y)``.

    Parameters
    ----------
    x, y : array-like, or sequence of array-likes
    scale : {'linear', 'semilogy', 'semilogx', 'loglog'}
        Selects the ``Axes`` method used to draw.
    xlabel, ylabel : str, optional
        Axis labels; left unset when None.
    shape : optional
        Accepted but currently unused.
    xlim, ylim : optional
        Passed to ``ax.set_xlim`` / ``ax.set_ylim`` when not None.
    rcParams : dict, optional
        Temporary matplotlib rc overrides active while the figure is built.
        None (the default) means no overrides.

    Returns
    -------
    (fig, ax) : the new Figure and its single Axes.

    Raises
    ------
    KeyError
        If ``scale`` is not one of the names above.
    ValueError
        If ``x`` is nested but ``y`` is not.
    TypeError
        If ``x`` or ``y`` is not iterable.
    All three are raised before any figure is created, so bad input does not
    leave an orphan figure open.
    """
    method_names = {
        'linear': 'plot',
        'semilogy': 'semilogy',
        'semilogx': 'semilogx',
        'loglog': 'loglog',
    }
    method_name = method_names[scale]

    nested_x = nested_iterable(x)
    nested_y = nested_iterable(y)
    # Previously `nested_x and nested_x`, which made this branch unreachable.
    if nested_x and not nested_y:
        raise ValueError("Give multiple y arrays for multiple x arrays")

    with mpl.rc_context(rc=rcParams or {}):
        fig, ax = plt.subplots()
        plot_fn = getattr(ax, method_name)
        if nested_x and nested_y:
            for xi, yi in zip(x, y):
                plot_fn(xi, yi)
        elif nested_y:
            # One shared x for every y series.
            for yi in y:
                plot_fn(x, yi)
        else:
            plot_fn(x, y)

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        fig.tight_layout()
    return fig, ax
In [43]:
# One shared x with two y series (the "flat x, nested y" case of `plot`).
# The dangling `shape=` (a SyntaxError) was dropped; `plot` ignores `shape` anyway.
plot(x=[1, 2, 3], y=[[1, 2, 3], [1, 4, 9]], xlabel='x [nm]')
Out[43]:
In [48]:
# A plain string, bound to its own name so the list of x arrays stored in `x`
# by earlier cells is not clobbered (re-running those cells stays valid).
text = 'abc'
isinstance(text, str)
Out[48]:
In [49]:
def iterable(x):
    """Return True if ``x`` is a list-like container: iterable but not text.

    ``str`` and ``bytes`` are iterable element by element but are treated as
    scalars here, so they return False. Lists, tuples, dicts, sets and numpy
    arrays return True; numbers and other non-iterables return False.
    """
    if isinstance(x, (str, bytes)):
        return False
    # collections.Iterable was removed in Python 3.10; use the collections.abc ABC.
    return isinstance(x, Iterable)
In [51]:
# A list of strings is still a list-like container, so this is True even though
# each element on its own would be rejected by `iterable`.
iterable(['ac', 'bd'])
In [ ]: