In [ ]:
import json

import plotly.graph_objs as go
import plotly.offline as py

# Initialise plotly's offline notebook mode so the py.iplot() calls made by
# printPlot can render figures inline in this notebook.
py.init_notebook_mode()

In [ ]:
def printPlot(title, data, intuitive=True):
    """Draw one grouped box plot comparing several series of union results.

    Parameters
    ----------
    title : str
        Figure title.
    data : iterable of (label, resultsByUnionSize) pairs
        One box-plot trace per entry, named ``label``.
        ``resultsByUnionSize`` is iterated as ``(unionSize, results)``
        pairs; ``results`` must support ``len()`` and iteration.
    intuitive : bool, optional
        If True (default), plot ``unionSize + result`` for every result and
        label the y axis "# SDRs out". If False, plot the raw results and
        label the y axis "# additional SDRs out".

    Returns
    -------
    None. The figure is displayed through ``py.iplot``.
    """
    traces = []
    for label, resultsByUnionSize in data:
        categories = []
        values = []
        # Visit the groups in ascending order of the union size itself, i.e.
        # before it is stringified for the categorical x axis.
        # NOTE(review): assumes the unionSize values are mutually comparable
        # numbers; string sizes would sort lexicographically ("10" < "2").
        orderedGroups = sorted(resultsByUnionSize, key=lambda group: group[0])
        for unionSize, results in orderedGroups:
            # One x category entry per observation so x and y stay aligned.
            categories.extend([str(unionSize)] * len(results))
            if intuitive:
                values.extend(unionSize + extra for extra in results)
            else:
                values.extend(results)
        traces.append(go.Box(x=categories, y=values, name=label,
                             boxpoints=False))

    yTitle = '# SDRs out' if intuitive else '# additional SDRs out'
    layout = {
        'title': title,
        'xaxis': {'title': '# SDRs in', 'zeroline': False},
        'yaxis': {'title': yTitle, 'zeroline': False},
        'boxmode': 'group',
    }
    py.iplot({"data": traces, "layout": layout})
    

def fetchData(folder, n, w, threshold, cellsPerColumn):
    """Load the saved results for one parameter combination.

    The results are read from
    ``<folder>/n_<n>_w_<w>_threshold_<threshold>_cellsPerColumn_<cellsPerColumn>.json``.

    Returns
    -------
    The decoded JSON document, as returned by ``json.load``. The plotting
    cells pass it to printPlot, which iterates it as
    ``(unionSize, results)`` pairs.

    Raises
    ------
    IOError / OSError if the file cannot be opened, and ValueError if its
    contents are not valid JSON.
    """
    basename = "n_{}_w_{}_threshold_{}_cellsPerColumn_{}.json".format(
        n, w, threshold, cellsPerColumn)
    path = "{}/{}".format(folder, basename)

    with open(path, "r") as handle:
        contents = json.load(handle)
    return contents

In [ ]:
# Union decoding with dense minicolumn SDRs (n=15, w=10, threshold=8),
# one series per cells-per-column setting.
denseMinicolumnSeries = [
    ("{} cells per column".format(cellsPerColumn),
     fetchData("data/default", 15, 10, 8, cellsPerColumn))
    for cellsPerColumn in (7, 10, 13, 16)
]
printPlot(
    "Operating on unions: Dense minicolumn SDRs (n=15, w=10, threshold=8)",
    denseMinicolumnSeries)

In [ ]:
# Union decoding as the minicolumn SDR density varies (n=15,
# cellsPerColumn=10), one series per (w, threshold) pair.
densitySeries = [
    ("w={}, threshold={}".format(w, threshold),
     fetchData("data/default", 15, w, threshold, 10))
    for w, threshold in ((3, 2), (4, 3), (5, 4), (10, 8))
]
printPlot(
    "Operating on unions: Varying density of minicolumn SDRs (n=15, cellsPerColumn=10)",
    densitySeries)

In [ ]:
# Regular vs. improved TM cell SDRs for one configuration
# (n=15, w=10, threshold=8, 10 cells per column). Each label maps to the
# folder its results were saved in.
tmSdrVariants = [
    ("Regular TM SDRs", "data/default"),
    ("Improved TM SDRs", "data/improved-tm-sdrs"),
]
printPlot(
    "Operating on unions: TM cell SDRs (n=15, w=10, threshold=8, 10 cells per column)",
    [(variantLabel, fetchData(variantFolder, 15, 10, 8, 10))
     for variantLabel, variantFolder in tmSdrVariants])

In [ ]:
# Same cells-per-column sweep as the dense-minicolumn plot, but using the
# results stored under data/improved-tm-sdrs.
improvedTmSeries = [
    ("{} cells per column".format(cellsPerColumn),
     fetchData("data/improved-tm-sdrs", 15, 10, 8, cellsPerColumn))
    for cellsPerColumn in (7, 10, 13, 16)
]
printPlot(
    "Operating on unions: With improved TM SDRs (n=15, w=10, threshold=8)",
    improvedTmSeries)