In [1]:
# ipyplotly
from plotly.graph_objs import FigureWidget
from plotly.callbacks import Points, InputDeviceState

# pandas
import pandas as pd

# numpy
import numpy as np

# scikit learn
from sklearn import datasets

# ipywidgets
from ipywidgets import HBox, VBox, Button

# functools
from functools import partial


/usr/local/lib/python3.7/site-packages/sklearn/utils/__init__.py:4: DeprecationWarning:

Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working


In [2]:
# Load the iris measurements into a DataFrame whose column names are
# snake_case versions of sklearn's feature names
# (e.g. 'sepal length (cm)' -> 'sepal_length').
iris_data = datasets.load_iris()
feature_names = ['_'.join(raw_name.replace(' (cm)', '').split(' '))
                 for raw_name in iris_data.feature_names]
iris_df = pd.DataFrame(iris_data.data, columns=feature_names)
N = len(iris_df)  # number of points; sizes the per-point colour arrays below
iris_class = iris_data.target + 1  # target labels shifted by +1
iris_df.head()


Out[2]:
sepal_length sepal_width petal_length petal_width
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2

In [3]:
# First linked view: sepal_length vs petal_width.
# Marker colour encodes selection state: 0 = unselected, 1 = selected.
# The colorscale repeats the 0.5 stop, so it is a hard two-colour step
# (lightgray below 0.5, red above). With cmin=-0.5 and cmax=1.5, the
# values 0 and 1 map to 0.25 and 0.75 of the scale, one in each colour
# band. The colorbar settings only matter if showscale is enabled; here
# it is disabled.
f1 = FigureWidget(**{
    'data': [{'marker': {'cmax': 1.5,
                         'cmin': -0.5,
                         'color': np.zeros(N),
                         'colorbar': {'ticks': 'outside', 'ticktext': ['unselected', 'selected'], 'tickvals': [0, 1]},
                         'colorscale': [[0, 'lightgray'], [0.5, 'lightgray'],
                                        [0.5, 'red'], [1, 'red']],
                         'showscale': False,
                         'size': 8},
              'mode': 'markers',
              'type': 'scatter',
              'uid': '9fb32f14-6f15-11e8-973c-645aede86e5b',
              'x': iris_df.sepal_length,
              'y': iris_df.petal_width}],
    # Label both axes (as f2 does) so this plot can be read on its own and
    # next to f2 in the dashboard.
    'layout': {'dragmode': 'lasso', 'width': 500,
               'xaxis': {'title': 'sepal_length'}, 'yaxis': {'title': 'petal_width'}}
})
scatt1 = f1.data[0]
f1



In [4]:
# Second linked view: petal_length vs sepal_width. It uses the same
# gray/red two-step selection colouring as f1, with both axes labelled.
f2 = FigureWidget(
    data=[dict(
        marker=dict(
            cmax=1.5,
            cmin=-0.5,
            color=np.zeros(N),
            colorbar=dict(ticks='outside', ticktext=['unselected', 'selected'], tickvals=[0, 1]),
            colorscale=[[0, 'lightgray'], [0.5, 'lightgray'],
                        [0.5, 'red'], [1, 'red']],
            size=8,
        ),
        mode='markers',
        type='scatter',
        uid='e3f13218-6f15-11e8-b2a9-645aede86e5b',
        x=iris_df.petal_length,
        y=iris_df.sepal_width,
    )],
    layout=dict(
        dragmode='lasso',
        width=500,
        xaxis=dict(title='petal_length'),
        yaxis=dict(title='sepal_width'),
    ),
)
(scatt2,) = f2.data
f2



In [5]:
# Linked brushing: a selection on either scatter plot highlights the
# selected points on BOTH plots.
def brush(trace, points, state):
    """Mark the points selected on one plot as selected on both plots.

    Registered below as the ``on_selection`` callback of ``scatt1`` and
    ``scatt2``. The selected indices are set to 1 in a copy of
    ``scatt1``'s current colour array. That copy is then assigned to both
    traces, so selections accumulate until ``reset_brush`` clears them.

    Parameters
    ----------
    trace : the trace that emitted the event (unused).
    points : plotly ``Points``; only ``point_inds`` is read.
    state : third callback argument supplied by plotly (unused;
        presumably the selector object for selection events).
    """
    picked = np.array(points.point_inds)
    if not picked.size:
        # Nothing selected: keep the current highlighting as-is.
        return
    # Edit a copy rather than the trace's stored array.
    highlighted = scatt1.marker.color.copy()
    highlighted[picked] = 1
    for linked in (scatt1, scatt2):
        linked.marker.color = highlighted


# The same callback on both traces keeps the two plots in sync.
scatt1.on_selection(brush)
scatt2.on_selection(brush)

In [8]:
# "clear" button: un-highlight every point on both linked plots.
def reset_brush(btn):
    """Button click handler that sets every marker colour back to 0.

    ``btn`` is the clicked ``Button`` passed by ipywidgets (unused).
    One all-zeros array sized by ``iris_class`` is assigned to both
    ``scatt1`` and ``scatt2``.
    """
    cleared = np.zeros(iris_class.size)
    for linked in (scatt1, scatt2):
        linked.marker.color = cleared


button = Button(description="clear")
button.on_click(reset_brush)  # wire the handler to the button
button



In [9]:
# Final dashboard: the two linked scatter plots side by side, with the
# "clear" button underneath.
dashboard = VBox(children=[HBox(children=[f1, f2]), button])
dashboard



In [ ]: