In [2]:
%matplotlib inline
from __future__ import division
import collections
try:  # the ABCs live in collections.abc; the collections.Iterable alias was removed in Python 3.10
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
In [15]:
# Sample data: a finely sampled sine (101 points) and a coarse one (11 points)
# over the same interval [0, 10].
x1, x2 = np.linspace(0, 10, 101), np.linspace(0, 10, 11)
y1, y2 = np.sin(x1), np.sin(x2)
In [4]:
# Pyplot state-machine interface: several (x, y) pairs in a single call.
# The trailing ';' suppresses the returned list-of-lines repr so only the figure shows.
plt.plot(x1, y1, x2, y2);
Out[4]:
In [5]:
# Is a numpy array iterable? Use the collections.abc ABC
# (the collections.Iterable alias was removed in Python 3.10).
isinstance(x1, Iterable)
Out[5]:
In [6]:
# Group the two series so they can be handed to a plotting helper together.
x, y = [x1, x2], [y1, y2]
In [7]:
# Display the list of x arrays to inspect its structure.
x
Out[7]:
In [8]:
# True when every element of x is itself iterable, i.e. x is a "nested" iterable.
all(isinstance(i, Iterable) for i in x)
Out[8]:
In [9]:
# Object-oriented interface: create the figure and axes explicitly, then draw on `ax`.
# The trailing ';' suppresses the returned list-of-lines repr.
fig, ax = plt.subplots()
ax.plot(x1, y1);
Out[9]:
In [10]:
def plot(x, y, scale='linear', xlabel=None, ylabel=None, shape=None,
         xlim=None, ylim=None, rcParams=None):
    """Draw y against x on a new figure (first draft; a later cell redefines ``plot``).

    If every element of ``x`` is itself iterable (e.g. a list of arrays), ``x``
    and ``y`` are treated as parallel sequences of series and paired with
    ``zip``; otherwise a single series ``(x, y)`` is drawn.

    Parameters
    ----------
    x, y : array-like or sequence of array-like
        Data to plot, flat or nested as described above.
    scale : {'linear', 'semilogy', 'semilogx', 'loglog'}
        Selects ``ax.plot`` / ``ax.semilogy`` / ``ax.semilogx`` / ``ax.loglog``.
    xlabel, ylabel : str, optional
        Axis labels; left unset when None.
    shape : ignored
        Currently unused; accepted so existing calls keep working.
    xlim, ylim : optional
        Passed to ``ax.set_xlim`` / ``ax.set_ylim`` when not None.
    rcParams : dict, optional
        Temporary matplotlib rc overrides applied while the figure is built.

    Returns
    -------
    (fig, ax)

    Raises
    ------
    KeyError
        For an unknown ``scale`` (raised before any figure is created).
    """
    plot_methods = {
        'linear': 'plot',
        'semilogy': 'semilogy',
        'semilogx': 'semilogx',
        'loglog': 'loglog'}
    # Look the scale up before creating a figure so a bad value does not
    # leave an empty figure open.
    method_name = plot_methods[scale]
    with mpl.rc_context(rc={} if rcParams is None else rcParams):
        fig, ax = plt.subplots()
        draw = getattr(ax, method_name)
        if all(isinstance(i, Iterable) for i in x):
            for xi, yi in zip(x, y):
                draw(xi, yi)
        else:
            draw(x, y)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        fig.tight_layout()
    return fig, ax
In [11]:
# Draw both sine series with the plot helper and keep the figure handles.
fig, ax = plot(x=x, y=y)
In [12]:
# Plain nested lists work too, and the series may have different lengths.
# The trailing ';' hides the repr of the returned (fig, ax) tuple.
plot(x=[[1, 2, 3], [1, 1.5, 2, 2.5, 3]],
     y=[[1, 4, 9], [1, 2, 4, 6, 9]]);
Out[12]:
In [13]:
def nested_iterable(x):
"""Return true if x is (at least) list of lists, or a 2D numpy array, or list of 1D
numpy arrays.
Raises a TypeError if passed a non-iterable."""
return all(isinstance(i, collections.Iterable) for i in x)
def plot(x, y, scale='linear', xlabel=None, ylabel=None, shape=None,
xlim=None, ylim=None, rcParams={}):
with mpl.rc_context(rc=rcParams):
fig, ax = plt.subplots()
plotting_functions = {
'linear': ax.plot,
'semilogy': ax.semilogy,
'semilogx': ax.semilogx,
'loglog': ax.loglog}
plot = plotting_functions[scale]
nested_x = nested_iterable(x)
nested_y = nested_iterable(y)
if nested_x and nested_x:
for xi, yi in zip(x, y):
plot(xi, yi)
elif nested_y:
for yi in y:
plot(x, yi)
elif nested_x:
raise ValueError("Give multiple y arrays for multiple x arrays")
else:
plot(x, y)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
if xlim is not None:
ax.set_xlim(xlim)
if ylim is not None:
ax.set_ylim(ylim)
fig.tight_layout()
return fig, ax
In [14]:
# One shared x with several y series: each y is drawn against the same x.
# The trailing ';' hides the repr of the returned (fig, ax) tuple.
plot(x=[1, 2, 3], y=[[1, 2, 3], [1, 4, 9]], xlabel='x [nm]');
In [ ]:
def iterable(x):
    """Return True if ``x`` is a list-like container rather than a string.

    Text (``str``) and byte strings (``bytes``) are technically iterable but
    are treated as scalars here. Anything else counts if it is an instance of
    ``Iterable``: lists, tuples, numpy arrays, generators and so on.
    """
    if isinstance(x, (str, bytes)):
        return False
    return isinstance(x, Iterable)
In [ ]:
# The outer list is itself a container, so this is True even though its items are strings.
iterable(x=['ac', 'bd'])
In [35]:
# Single sine series via the pyplot interface; ';' suppresses the returned list-of-lines repr.
plt.plot(x1, y1);
In [ ]: