In [2]:
import numpy as np
import itertools
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

%matplotlib inline

In [5]:
# Draw a labelled 3D scatter of every (x, y, z) combination on a 2 x 2 x 3
# grid and save the figure to a PDF file.

# --- Configuration ---------------------------------------------------------
filename = '3d_plot.pdf'
label_size = 15

# Coordinate values along each axis; x and y share the same two positions.
points = [0, 1]
z_points = [0, 1, 2]

# One-letter captions used in the per-point annotation, indexed by the
# point's coordinate value on the corresponding axis.
x_cap = ['B', 'N']
y_cap = ['A', 'P']
z_cap = ['M', 'T', 'C']

# Full names shown as tick labels, in the same order as the captions above.
x_labels = ['Bayesian', 'Non-Bayesian']
y_labels = ['Active', 'Passive']
z_labels = ['Matrix', 'Tensor', 'Comp.']

# --- Figure ----------------------------------------------------------------
# NOTE: called as a plain function (not as `with plt.xkcd():`), so the
# sketch-style rcParams stay in effect for the rest of the kernel session.
plt.xkcd()
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

# Enumeration indices of the grid points that are drawn in blue; every other
# point is drawn in black.
highlighted = (2, 5)

grid = itertools.product(points, points, z_points)
for no, (x, y, z) in enumerate(grid):
    color = 'b' if no in highlighted else 'k'
    caption = '%d(%s,%s,%s)' % (no, x_cap[x], y_cap[y], z_cap[z])

    ax.scatter(x, y, z, s=40, c=color)
    # Nudge the annotation slightly left of and above its marker.
    ax.text(x-0.1, y, z+0.1, caption, size=label_size, color=color)

# --- Axis ticks ------------------------------------------------------------
# Tick positions are fixed on all three axes first, then the text labels are
# attached, keeping the same call order as a hand-written sequence would.
for set_ticks, ticks in (
    (ax.set_xticks, points),
    (ax.set_yticks, points),
    (ax.set_zticks, z_points),
):
    set_ticks(ticks)

for set_ticklabels, labels in (
    (ax.set_xticklabels, x_labels),
    (ax.set_yticklabels, y_labels),
    (ax.set_zticklabels, z_labels),
):
    set_ticklabels(labels, size=label_size)

# --- Output ----------------------------------------------------------------
plt.savefig(filename, format='PDF', bbox_inches='tight', pad_inches=0.1)



In [ ]: