In [1]:
##### HW 3, problem 3 #####
# Sean Lubner
%matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
plt.close('all') # get rid of other plots
# Load the flower measurements as a record array, so that every column is
# reachable both by key (flowers['slen']) and by attribute (flowers.species).
# The first row of the CSV is skipped as a header.
dt = [('slen', np.float64), ('swid', np.float64), ('plen', np.float64),
      ('pwid', np.float64), ('species', 'S10')]
try:
    flowers = np.loadtxt('hw_3_data/flowers.csv', dt, delimiter=',',
                         skiprows=1).view(np.recarray)
except ValueError:
    print("Data is not the right type, please reformat.")
    raise
except EnvironmentError:
    # EnvironmentError covers IOError/OSError on Python 2 and is an alias of
    # OSError on Python 3: the file is missing or unreadable, which is a
    # different problem from a badly formatted file.
    print("Could not read 'hw_3_data/flowers.csv'.")
    raise
except Exception:
    print("The data is not correctly formatted.")
    # Nothing below can work without `flowers`; re-raise so the real cause is
    # reported instead of a NameError further down.
    raise
# Build the 4x4 scatter-matrix figure.
ctable = {'setosa': 'r', 'versicolor': 'g', 'virginica': 'b'}  # species -> point colour
fields = ['slen', 'swid', 'plen', 'pwid']  # one grid row and one grid column per field
fig, faxs = plt.subplots(4, 4, figsize=(15, 10))
fig.suptitle('Brushable Flower Data \nclick & drag to select a region of data in a subplot', fontsize=20)

# Initial plotting: the subplot at [row, col] shows fields[col] on x against
# fields[row] on y.  Every scatter artist is kept in data_list, in
# column-major order, so it can be restyled later.
data_list = []
point_colors = [ctable[name] for name in flowers.species[:]]
last_row = len(fields) - 1
for col, x_field in enumerate(fields):
    for row, y_field in enumerate(fields):
        cell = faxs[row, col]
        data_list.append(cell.scatter(flowers[x_field], flowers[y_field],
                                      c=point_colors, edgecolor='none',
                                      alpha=0.4, s=20))
        # Tick labels only along the bottom row and the left-hand column.
        if row < last_row:
            cell.set_xticklabels([])
        if col > 0:
            cell.set_yticklabels([])

# Name each field on its diagonal subplot.
diagonal_titles = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
for k, title in enumerate(diagonal_titles):
    faxs[k, k].annotate(title, xy=(0.07, .83), xycoords='axes fraction', size=16)
# Setup brusher
class brusher(object):
    """Click-and-drag rectangular brushing over a grid of linked scatter plots.

    Pressing a mouse button inside a subplot drops a red guide marker;
    releasing it inside the same subplot draws a translucent gray rectangle
    between the two points.  Every record whose values fall inside that
    rectangle is re-plotted in colour on all subplots, while the points that
    were already on the axes when the brusher was created are grayed out.

    The axes grid is assumed to be 2-D, with the subplot at [row, col] showing
    _FIELDS[col] on x against _FIELDS[row] on y.
    """

    # Record fields, in grid order (column index -> x field, row index -> y field).
    _FIELDS = ['slen', 'swid', 'plen', 'pwid']
    # Species name -> colour of a highlighted point.
    _SPECIES_COLORS = {'setosa': 'r', 'versicolor': 'g', 'virginica': 'b'}

    def __init__(self, plot_figure, plot_axes, plot_data):
        """Store the plot state and connect the mouse event handlers.

        plot_figure -- figure whose canvas delivers the mouse events
        plot_axes   -- 2-D array of axes, laid out as described on the class
        plot_data   -- record array with the _FIELDS columns plus 'species'
        """
        self.data = plot_data
        self.fig = plot_figure
        self.faxs = plot_axes
        self.ax = None            # axes of the drag in progress
        self.marker = None        # guide marker; None means no press is pending
        self.x0 = None
        self.y0 = None
        self.x1 = None
        self.y1 = None
        self.rect = None
        self.rect_drawing = None
        self.brushed_data = []    # scatter artists of the current highlight
        # Snapshot the collections already drawn on the axes: these are the
        # "background" points that get dimmed when a selection is made.
        self.base_data = [artist
                          for axes in np.ravel(self.faxs)
                          for artist in axes.collections]
        self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.fig.canvas.mpl_connect('button_release_event', self.on_release)

    def _discard_marker(self):
        """Remove the press-point guide marker from its axes, if one is showing."""
        if self.marker is not None:
            self.marker.remove()
            self.marker = None

    def on_press(self, event):
        """Record the axes and start corner of a drag, and drop a guide marker."""
        if event.inaxes is None:
            return  # the click landed outside every subplot
        self._discard_marker()  # left over from a press that was never released
        self.ax = event.inaxes
        self.x0 = event.xdata
        self.y0 = event.ydata
        self.marker, = self.ax.plot(self.x0, self.y0, 'ro')  # guide for the eye during selection
        self.ax.figure.canvas.draw()

    def on_release(self, event):
        """Draw the selection rectangle and highlight the data inside it."""
        if self.ax is None or self.marker is None:
            return  # release without a matching press
        self._discard_marker()  # the guide is no longer needed either way
        if event.inaxes is not self.ax:
            # Drag ended outside the starting subplot: abandon the selection,
            # but redraw so the guide marker disappears.
            self.fig.canvas.draw()
            return
        if self.rect_drawing is not None:
            self.rect_drawing.remove()
        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rect = Rectangle((self.x0, self.y0),
                              self.x1 - self.x0, self.y1 - self.y0,
                              alpha=0.3, fc='gray')
        self.rect_drawing = self.ax.add_patch(self.rect)
        self.replot_selection()  # finishes with a canvas redraw

    def replot_selection(self):
        """Find the records inside the dragged rectangle and re-plot them."""
        # Locate the brushed subplot in the grid by identity.
        hits = [index for index, axes in np.ndenumerate(np.asarray(self.faxs))
                if axes is self.ax]
        if not hits:
            return  # the drag did not happen on one of our subplots
        row, col = hits[0]
        x_series, y_series = self._FIELDS[col], self._FIELDS[row]
        # Sort the corners in case the user dragged the region backwards.
        x_lo, x_hi = sorted([self.x0, self.x1])
        y_lo, y_hi = sorted([self.y0, self.y1])
        xs = self.data[x_series]
        ys = self.data[y_series]
        inside = (xs >= x_lo) & (xs <= x_hi) & (ys >= y_lo) & (ys <= y_hi)
        self.plot_points(np.flatnonzero(inside))
        self.fig.canvas.draw()

    def plot_points(self, points, dull_out=True):
        """Highlight the records indexed by 'points' on every subplot.

        points   -- indices into the data (anything NumPy accepts as an index)
        dull_out -- when true, gray out and fade the background points
        """
        for artist in self.brushed_data:  # wipe out the previous brushing
            artist.remove()
        self.brushed_data = []
        if dull_out:
            for artist in self.base_data:
                artist.set_color('gray')
                artist.set_alpha(0.1)
        index = np.asarray(points)
        if index.size == 0:
            return  # empty selection: nothing to highlight
        colors = [self._SPECIES_COLORS[name] for name in self.data['species'][index]]
        for i, x_field in enumerate(self._FIELDS):
            for j, y_field in enumerate(self._FIELDS):
                highlighted = self.faxs[j, i].scatter(
                    self.data[x_field][index], self.data[y_field][index],
                    c=colors, edgecolor='none', alpha=0.4, s=20)
                self.brushed_data.append(highlighted)

    def get_brushed_data(self):
        """Return the list of scatter artists that make up the current highlight."""
        return self.brushed_data
if __name__ == "__main__":
    # Keep the brusher bound to a name: the canvas holds only weak references
    # to bound-method callbacks, so an unreferenced brusher would stop
    # receiving mouse events once it is garbage-collected.
    a = brusher(fig, faxs, flowers)
In [ ]: