In [7]:
%matplotlib inline
from __future__ import division

import collections

import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np

In [3]:
x1 = np.linspace(0, 10, 101)
x2 = np.linspace(0, 10, 11)
y1 = np.sin(x1)
y2 = np.sin(x2)

In [6]:
plt.plot(x1, y1, x2, y2)


Out[6]:
[<matplotlib.lines.Line2D at 0x10c967410>,
 <matplotlib.lines.Line2D at 0x10c967690>]

In [8]:
isinstance(x1, collections.Iterable)


Out[8]:
True

In [19]:
x = [x1, x2]
y = [y1, y2]


Out[19]:
[<matplotlib.lines.Line2D at 0x10cccb1d0>]

In [20]:
x


Out[20]:
[array([  0. ,   0.1,   0.2,   0.3,   0.4,   0.5,   0.6,   0.7,   0.8,
          0.9,   1. ,   1.1,   1.2,   1.3,   1.4,   1.5,   1.6,   1.7,
          1.8,   1.9,   2. ,   2.1,   2.2,   2.3,   2.4,   2.5,   2.6,
          2.7,   2.8,   2.9,   3. ,   3.1,   3.2,   3.3,   3.4,   3.5,
          3.6,   3.7,   3.8,   3.9,   4. ,   4.1,   4.2,   4.3,   4.4,
          4.5,   4.6,   4.7,   4.8,   4.9,   5. ,   5.1,   5.2,   5.3,
          5.4,   5.5,   5.6,   5.7,   5.8,   5.9,   6. ,   6.1,   6.2,
          6.3,   6.4,   6.5,   6.6,   6.7,   6.8,   6.9,   7. ,   7.1,
          7.2,   7.3,   7.4,   7.5,   7.6,   7.7,   7.8,   7.9,   8. ,
          8.1,   8.2,   8.3,   8.4,   8.5,   8.6,   8.7,   8.8,   8.9,
          9. ,   9.1,   9.2,   9.3,   9.4,   9.5,   9.6,   9.7,   9.8,
          9.9,  10. ]),
 array([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.])]

In [24]:
all(isinstance(i, collections.Iterable) for i in x)


Out[24]:
True

In [27]:
fig, ax = plt.subplots()
ax.plot(x1, y1)


Out[27]:
[<matplotlib.lines.Line2D at 0x10e019ad0>]

In [29]:
def plot(x, y, scale='linear', xlabel=None, ylabel=None, shape=None, xlim=None, ylim=None, rcParams={}):
    with mpl.rc_context(rc=rcParams):
        fig, ax = plt.subplots()

        plotting_functions = {
            'linear': ax.plot,
            'semilogy': ax.semilogy,
            'semilogx': ax.semilogx,
            'loglog': ax.loglog}

        plot = plotting_functions[scale]

        nested_iterable = all(isinstance(i, collections.Iterable) for i in x)
        
        if nested_iterable:
            for xi, yi in zip(x, y):
                plot(xi, yi)
        else:
            plot(x, y)

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)

        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

        fig.tight_layout()
        return fig, ax

In [31]:
fig, ax = plot(x, y)



In [32]:
plot(x=[[1, 2, 3], [1, 1.5, 2, 2.5, 3]], y=[[1, 4, 9], [1, 2, 4, 6, 9]])


Out[32]:
(<matplotlib.figure.Figure at 0x10e0b9910>,
 <matplotlib.axes.AxesSubplot at 0x10e016450>)

In [38]:
def nested_iterable(x):
    """Return true if x is (at least) list of lists, or a 2D numpy array, or list of 1D
    numpy arrays.

    Raises a TypeError if passed a non-iterable."""
    return all(isinstance(i, collections.Iterable) for i in x)


def plot(x, y, scale='linear', xlabel=None, ylabel=None, shape=None,
         xlim=None, ylim=None, rcParams={}):
    with mpl.rc_context(rc=rcParams):
        fig, ax = plt.subplots()

        plotting_functions = {
            'linear': ax.plot,
            'semilogy': ax.semilogy,
            'semilogx': ax.semilogx,
            'loglog': ax.loglog}

        plot = plotting_functions[scale]

        nested_x = nested_iterable(x)
        nested_y = nested_iterable(y)
        
        if nested_x and nested_x:
            for xi, yi in zip(x, y):
                plot(xi, yi)
        elif nested_y:
            for yi in y:
                plot(x, yi)
        elif nested_x:
            raise ValueError("Give multiple y arrays for multiple x arrays")
        else:
            plot(x, y)

        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)

        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

        fig.tight_layout()
        return fig, ax

In [43]:
plot(x=[1, 2, 3], y=[[1, 2, 3], [1, 4, 9]], xlabel='x [nm]', shape=)


Out[43]:
(<matplotlib.figure.Figure at 0x10e2bc3d0>,
 <matplotlib.axes.AxesSubplot at 0x10e88a210>)

In [48]:
x = 'abc'
isinstance(x, str)


Out[48]:
True

In [49]:
def iterable(x):
    """True if x is an iterable other than a string: some sort of list-like
    container"""
    if isinstance(x, str):
        return False
    else:
        return isinstance(x, collections.Iterable)

In [51]:
iterable(['ac', 'bd'])


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-51-25708609003e> in <module>()
----> 1 iterable(['ac', 'bd'])

<ipython-input-49-a1d0a0c6471f> in iterable(x)
      5         return False
      6     else:
----> 7         return isinstance(i, collections.Iterable)

NameError: global name 'i' is not defined

In [ ]: