In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fingerprint import GraphFingerprint
from wb import WeightsAndBiases
from itertools import combinations
from random import choice, sample
from numpy.random import permutation
from sklearn.ensemble import RandomForestRegressor
from sklearn.cross_validation import train_test_split, ShuffleSplit, cross_val_score
from sklearn.preprocessing import LabelBinarizer
from autograd import grad
from time import time

import autograd.numpy as np
import networkx as nx
import math
import matplotlib.pyplot as plt
from numba import jit

In [3]:
shapes = dict()
shapes[0] = 10
shapes[1] = 10
shapes[2] = 10
wb = WeightsAndBiases(2, shapes)
# wb[0]

In [9]:
def make_random_graph(nodes, n_edges, features_dict):
    """
    Makes a randomly connected graph. 
    """
    
    G = nx.Graph()
    for n in nodes:
        G.add_node(n, features=features_dict[n])
    
    for i in range(n_edges):
        u, v = sample(G.nodes(), 2)
        G.add_edge(u, v)
        
    return G

# features_dict will look like this:
# {0: array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
#  1: array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
#  2: array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0]),
#  3: array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]),
#  4: array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0]),
#  5: array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0]),
#  6: array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0]),
#  7: array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0]),
#  8: array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]),
#  9: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])}

all_nodes = [i for i in range(10)]    
lb = LabelBinarizer()
features_dict = {i:lb.fit_transform(all_nodes)[i] for i in all_nodes}

G = make_random_graph(sample(all_nodes, 6), 5, features_dict)
G.edges(data=True)
# G.nodes(data=True)


Out[9]:
[(0, 5, {}), (4, 8, {}), (4, 5, {}), (5, 8, {})]

In [10]:
def score(G):
    """
    The regressable score for each graph will be the sum of the 
    (square root of each node + the sum of its neighbors.)
    """
    sum_score = 0
    for n, d in G.nodes(data=True):
        sum_score += math.sqrt(n)
        
        for nbr in G.neighbors(n):
            sum_score += nbr
    return sum_score

score(G)


Out[10]:
49.21075947218795

In [11]:
syngraphs = [make_random_graph(sample(all_nodes, 6), 5, features_dict) for i in range(1000)]

In [12]:
len(syngraphs)


Out[12]:
1000

Set up a learning scenario where...


In [13]:
fingerprints = np.zeros((len(syngraphs), 10))

for i, g in enumerate(syngraphs):
    gfp = GraphFingerprint(g, 2, shapes)
    fp = gfp.compute_fingerprint(wb.vect, wb.unflattener)
    fingerprints[i] = fp

In [15]:
import pandas as pd
X = pd.DataFrame(np.array(fingerprints))
Y = [score(g) for g in syngraphs]
Y


Out[15]:
[48.928198407402256,
 74.20622924337965,
 54.33788174096706,
 47.15973615609376,
 57.74552259372066,
 48.099771282656064,
 56.47870866461908,
 38.24150542378974,
 56.29603285093747,
 50.00996767509825,
 52.88839199818388,
 54.21075947218795,
 57.51398484502916,
 37.62044280575275,
 44.29603285093748,
 22.382332347441764,
 49.57394971846685,
 34.74552259372066,
 71.47870866461908,
 37.44229722087945,
 35.83182209022495,
 43.21075947218795,
 48.79201568100656,
 43.06335983891644,
 46.685557720282965,
 70.15973615609374,
 50.337881740967056,
 57.88839199818388,
 54.974691494688166,
 60.331309031347566,
 33.028083658506354,
 48.33130903134756,
 49.028083658506354,
 56.24603565259803,
 50.33130903134756,
 67.33788174096706,
 75.57394971846685,
 73.62044280575276,
 48.69213042990247,
 42.57394971846685,
 61.928198407402256,
 51.61387009613326,
 52.21075947218795,
 70.06993254853595,
 55.12445997568367,
 47.24150542378975,
 68.92819840740225,
 59.18154055035206,
 46.47870866461908,
 46.21075947218796,
 43.59575411272515,
 58.66024921497113,
 46.337881740967056,
 22.595754112725153,
 71.51398484502916,
 70.51398484502917,
 41.44229722087945,
 42.47870866461908,
 62.47757340128953,
 59.62044280575276,
 64.44229722087945,
 38.47757340128953,
 42.337881740967056,
 69.69213042990248,
 55.09524105384777,
 37.099771282656064,
 45.028083658506354,
 51.424181237471345,
 35.928198407402256,
 31.099771282656064,
 70.92366817859397,
 46.24603565259804,
 40.85651078325254,
 59.47757340128954,
 58.337881740967056,
 41.928198407402256,
 46.424181237471345,
 65.06993254853595,
 41.20622924337965,
 45.59575411272515,
 42.15973615609376,
 39.62044280575275,
 48.24150542378975,
 75.15973615609374,
 41.15973615609375,
 37.79201568100657,
 36.009967675098245,
 30.37780211863347,
 61.57394971846684,
 60.69213042990246,
 47.181540550352054,
 51.424181237471345,
 37.50945461622087,
 76.89178696366264,
 45.79654590981486,
 62.928198407402256,
 57.82729186141665,
 30.613870096133258,
 77.33788174096705,
 55.06993254853594,
 63.337881740967056,
 73.44229722087945,
 52.923668178593964,
 54.79654590981486,
 46.028083658506354,
 44.79654590981486,
 46.82729186141665,
 67.89178696366264,
 55.06993254853594,
 71.21075947218796,
 51.424181237471345,
 49.82729186141665,
 48.59575411272515,
 61.028083658506354,
 57.513984845029164,
 41.099771282656064,
 39.47417843581078,
 67.65571898616284,
 60.88839199818388,
 45.47870866461908,
 38.099771282656064,
 47.827291861416654,
 38.337881740967056,
 55.20622924337966,
 59.83182209022494,
 72.33788174096706,
 60.560477932315074,
 75.65571898616284,
 51.29603285093747,
 37.24150542378975,
 41.20622924337966,
 36.88181928856438,
 52.74552259372065,
 39.74552259372066,
 32.59575411272515,
 57.61387009613326,
 42.62044280575275,
 60.06335983891643,
 41.50945461622087,
 52.92819840740225,
 62.65571898616284,
 57.337881740967056,
 67.66024921497113,
 49.24603565259804,
 58.47757340128953,
 53.83182209022495,
 70.00996767509824,
 48.424181237471345,
 53.24603565259804,
 29.595754112725153,
 40.685557720282965,
 43.38233234744176,
 44.69213042990246,
 61.44229722087945,
 46.12445997568366,
 55.66024921497114,
 55.424181237471345,
 33.83182209022495,
 51.61387009613326,
 46.37780211863347,
 47.29603285093747,
 64.71024641331057,
 45.24603565259803,
 57.33788174096706,
 41.59575411272515,
 62.21075947218795,
 86.57394971846685,
 58.20622924337966,
 49.12445997568367,
 56.62044280575275,
 56.337881740967056,
 64.42418123747134,
 52.424181237471345,
 39.928198407402256,
 61.51398484502916,
 58.06993254853594,
 55.62044280575276,
 64.97469149468816,
 43.47757340128954,
 49.62044280575276,
 55.47757340128953,
 53.24150542378975,
 44.74552259372066,
 41.21075947218795,
 62.009967675098245,
 38.337881740967056,
 53.62044280575275,
 44.71024641331057,
 37.83182209022494,
 32.424181237471345,
 46.24603565259804,
 46.65028153987288,
 48.69213042990246,
 68.89178696366264,
 43.009967675098245,
 37.650281539872886,
 54.29603285093748,
 66.33788174096705,
 64.68555772028296,
 49.424181237471345,
 38.83182209022494,
 48.856510783252546,
 59.51398484502916,
 40.028083658506354,
 55.856510783252546,
 44.88839199818388,
 45.86370330515628,
 48.83182209022494,
 59.41760852785184,
 57.62044280575276,
 52.62044280575276,
 56.974691494688166,
 56.62044280575276,
 35.65571898616284,
 40.47757340128953,
 64.24150542378975,
 42.028083658506354,
 40.69213042990246,
 37.71024641331057,
 48.06335983891644,
 53.06449510224598,
 58.62044280575276,
 69.09977128265606,
 42.61387009613326,
 65.74552259372066,
 51.59575411272515,
 40.47870866461908,
 54.923668178593964,
 57.923668178593964,
 45.12445997568367,
 68.92366817859397,
 49.513984845029164,
 44.69213042990246,
 55.74552259372065,
 45.66024921497114,
 45.06335983891643,
 31.382332347441764,
 64.15973615609374,
 43.424181237471345,
 53.099771282656064,
 42.337881740967056,
 31.595754112725153,
 43.028083658506354,
 72.30600052603573,
 60.56047793231507,
 35.96811878506867,
 75.92366817859397,
 49.69213042990246,
 39.62044280575276,
 58.424181237471345,
 48.66024921497113,
 68.89178696366264,
 45.61387009613326,
 53.37780211863347,
 44.06335983891644,
 56.509454616220864,
 55.24603565259803,
 63.337881740967056,
 50.61387009613326,
 38.974691494688166,
 41.099771282656064,
 66.47757340128953,
 61.88839199818388,
 44.41760852785185,
 50.50945461622087,
 52.74552259372066,
 40.27791686752937,
 63.20622924337967,
 50.88839199818388,
 62.12445997568367,
 45.009967675098245,
 65.65571898616284,
 33.24150542378975,
 62.21075947218795,
 57.57394971846685,
 40.69213042990246,
 69.74552259372066,
 55.509454616220864,
 60.337881740967056,
 50.79654590981486,
 53.09524105384777,
 54.424181237471345,
 54.509454616220864,
 49.06335983891643,
 38.099771282656064,
 55.891786963662625,
 42.65028153987288,
 42.66024921497114,
 46.83182209022494,
 54.24603565259804,
 60.028083658506354,
 56.69213042990247,
 58.47417843581078,
 50.856510783252546,
 36.24603565259804,
 35.74552259372066,
 64.57394971846685,
 47.424181237471345,
 65.33788174096705,
 45.20622924337967,
 51.24603565259804,
 44.69213042990247,
 34.12445997568366,
 22.246035652598035,
 81.44229722087945,
 52.099771282656064,
 59.337881740967056,
 42.66024921497113,
 68.12445997568366,
 37.47870866461908,
 32.29603285093748,
 34.38233234744176,
 53.442297220879446,
 37.928198407402256,
 66.15973615609374,
 55.65571898616283,
 42.15973615609376,
 52.15973615609376,
 52.20622924337966,
 46.974691494688166,
 36.29603285093748,
 54.424181237471345,
 45.24150542378975,
 38.69213042990246,
 47.24150542378975,
 48.377802118633475,
 46.65571898616283,
 66.92819840740225,
 50.88181928856438,
 56.79654590981485,
 49.337881740967056,
 26.124459975683667,
 41.62044280575276,
 67.27791686752937,
 37.24150542378975,
 54.24603565259804,
 46.424181237471345,
 28.595754112725153,
 58.79201568100656,
 46.88181928856438,
 77.88839199818388,
 45.65571898616284,
 74.00996767509824,
 48.47870866461908,
 57.88839199818388,
 53.974691494688166,
 30.92366817859396,
 42.79654590981486,
 64.66024921497113,
 34.83182209022495,
 27.863703305156275,
 63.009967675098245,
 38.71024641331057,
 58.33130903134756,
 43.37780211863347,
 63.47870866461907,
 55.69213042990247,
 66.33788174096706,
 54.83182209022494,
 25.509454616220864,
 72.57394971846685,
 66.44229722087945,
 48.62044280575276,
 48.86370330515628,
 28.595754112725153,
 47.06449510224598,
 41.12445997568366,
 57.442297220879446,
 44.68555772028297,
 58.059964873437686,
 49.12445997568366,
 35.181540550352054,
 42.009967675098245,
 57.009967675098245,
 55.44229722087945,
 70.51398484502916,
 35.41760852785184,
 68.57394971846685,
 63.88839199818388,
 63.12445997568367,
 62.424181237471345,
 77.92819840740225,
 52.856510783252546,
 32.79201568100656,
 53.62044280575275,
 59.38233234744176,
 37.41760852785185,
 39.337881740967056,
 36.33130903134756,
 71.65571898616284,
 49.62044280575275,
 30.745522593720654,
 59.009967675098245,
 57.29603285093748,
 39.928198407402256,
 50.24264068711929,
 73.21075947218796,
 56.65571898616283,
 54.513984845029164,
 42.337881740967056,
 57.59575411272515,
 53.09524105384777,
 60.424181237471345,
 50.65571898616284,
 33.96811878506867,
 49.242640687119284,
 34.337881740967056,
 58.79654590981486,
 60.20622924337966,
 41.88181928856438,
 65.09977128265606,
 56.59575411272515,
 50.028083658506354,
 71.92366817859397,
 54.44229722087945,
 59.928198407402256,
 45.59575411272515,
 69.51398484502917,
 67.15973615609374,
 47.50945461622087,
 48.85651078325254,
 44.15973615609375,
 53.509454616220864,
 33.24150542378975,
 46.47757340128953,
 48.337881740967056,
 47.891786963662625,
 66.21075947218796,
 47.856510783252546,
 53.41760852785185,
 63.74552259372066,
 53.21075947218795,
 53.33130903134756,
 56.06993254853594,
 49.57394971846685,
 39.29603285093748,
 59.82729186141665,
 56.923668178593964,
 51.923668178593964,
 46.424181237471345,
 69.50945461622086,
 65.41760852785185,
 54.62044280575276,
 44.12445997568367,
 60.33788174096706,
 61.891786963662625,
 34.06335983891643,
 67.30600052603573,
 43.099771282656064,
 43.65571898616284,
 67.85651078325255,
 57.06335983891644,
 47.424181237471345,
 73.06993254853595,
 52.33788174096706,
 59.62044280575276,
 62.856510783252546,
 50.47757340128953,
 43.79654590981485,
 62.51398484502916,
 50.29603285093748,
 48.47757340128954,
 54.66024921497113,
 54.028083658506354,
 50.88181928856438,
 69.51398484502917,
 54.337881740967056,
 34.69213042990246,
 73.21075947218796,
 61.12445997568366,
 59.74552259372066,
 32.41760852785184,
 61.47757340128953,
 47.928198407402256,
 42.923668178593964,
 45.242640687119284,
 52.83182209022494,
 55.424181237471345,
 62.29603285093748,
 48.974691494688166,
 40.891786963662625,
 57.44229722087945,
 49.59575411272515,
 50.09977128265606,
 57.06993254853594,
 55.50945461622087,
 43.61387009613326,
 53.62044280575275,
 53.89178696366263,
 49.928198407402256,
 50.12445997568367,
 44.83182209022494,
 54.96811878506867,
 43.424181237471345,
 31.974691494688166,
 59.21075947218795,
 53.86370330515628,
 44.65571898616284,
 51.92819840740225,
 56.88839199818388,
 29.417608527851844,
 48.891786963662625,
 47.928198407402256,
 63.509454616220864,
 75.57394971846685,
 42.06335983891643,
 74.44229722087945,
 38.06335983891643,
 67.83182209022493,
 40.41760852785185,
 33.83182209022495,
 36.974691494688166,
 72.65571898616284,
 34.62044280575275,
 74.79654590981485,
 51.21075947218795,
 48.79654590981486,
 40.56047793231507,
 39.62044280575276,
 47.337881740967056,
 46.33130903134756,
 45.21075947218795,
 54.57394971846684,
 71.33788174096706,
 27.974691494688166,
 33.38233234744176,
 41.509454616220864,
 69.71024641331057,
 62.181540550352054,
 53.06449510224598,
 26.02808365850635,
 47.00996767509825,
 68.51398484502916,
 68.51398484502917,
 52.82729186141665,
 67.33788174096706,
 58.06335983891643,
 58.009967675098245,
 46.009967675098245,
 32.37780211863347,
 33.38233234744176,
 52.09524105384777,
 35.65571898616283,
 52.856510783252546,
 23.595754112725153,
 34.028083658506354,
 76.33130903134756,
 45.38233234744176,
 68.51398484502917,
 53.74552259372066,
 69.88839199818388,
 55.424181237471345,
 51.65571898616283,
 58.12445997568367,
 53.66024921497113,
 46.337881740967056,
 77.12445997568366,
 66.12445997568366,
 52.61387009613326,
 39.099771282656064,
 57.06335983891643,
 44.24603565259803,
 41.79201568100656,
 64.62044280575276,
 72.21075947218796,
 37.306000526035724,
 28.974691494688166,
 79.24603565259804,
 65.12445997568366,
 77.71024641331057,
 46.337881740967056,
 61.44229722087945,
 44.06335983891643,
 36.71024641331057,
 27.888391998183877,
 34.974691494688166,
 46.06335983891643,
 29.210759472187956,
 54.099771282656064,
 53.24603565259804,
 62.24150542378975,
 62.028083658506354,
 56.74552259372066,
 47.74552259372066,
 47.74552259372066,
 49.028083658506354,
 58.009967675098245,
 35.59575411272515,
 45.424181237471345,
 63.928198407402256,
 49.12445997568367,
 39.12445997568367,
 41.66024921497114,
 50.20622924337967,
 60.61387009613326,
 39.79654590981485,
 61.06335983891644,
 32.06335983891643,
 27.613870096133258,
 55.509454616220864,
 44.028083658506354,
 60.424181237471345,
 65.89178696366264,
 46.38233234744176,
 72.51398484502917,
 60.06335983891644,
 41.82729186141665,
 48.69213042990246,
 24.888391998183877,
 46.099771282656064,
 54.62044280575276,
 60.099771282656064,
 49.74552259372066,
 41.47870866461908,
 53.88181928856438,
 39.15973615609375,
 75.74552259372066,
 57.21075947218795,
 40.83182209022495,
 50.47757340128954,
 60.79201568100657,
 31.796545909814856,
 28.831822090224943,
 48.24603565259804,
 37.47870866461908,
 41.65571898616283,
 50.974691494688166,
 72.42418123747134,
 42.33130903134756,
 34.099771282656064,
 53.44229722087945,
 46.92819840740225,
 51.51398484502916,
 43.29603285093748,
 33.79201568100656,
 38.62044280575275,
 58.337881740967056,
 64.82729186141665,
 45.47870866461908,
 32.74552259372066,
 30.83182209022494,
 45.337881740967056,
 53.82729186141665,
 48.66024921497113,
 40.424181237471345,
 54.79654590981486,
 60.928198407402256,
 58.44229722087945,
 42.509454616220864,
 38.974691494688166,
 34.974691494688166,
 69.15973615609374,
 28.069932548535935,
 65.12445997568366,
 39.24150542378974,
 36.33130903134756,
 43.62044280575276,
 47.82729186141665,
 59.15973615609376,
 37.88839199818388,
 38.47870866461908,
 58.509454616220864,
 52.71024641331057,
 55.20622924337966,
 45.181540550352054,
 53.65571898616283,
 43.33130903134756,
 39.88839199818388,
 60.89178696366263,
 74.15973615609374,
 29.595754112725153,
 67.42418123747134,
 50.474178435810785,
 44.513984845029164,
 63.028083658506354,
 31.33130903134756,
 64.97469149468816,
 56.928198407402256,
 49.41760852785185,
 66.57394971846685,
 50.59575411272515,
 25.099771282656064,
 40.79201568100656,
 27.595754112725153,
 30.146264369941974,
 39.74552259372066,
 73.12445997568366,
 33.66024921497113,
 42.29603285093748,
 70.65571898616284,
 72.06335983891644,
 50.06335983891644,
 38.059964873437686,
 50.82729186141665,
 69.15973615609376,
 49.79654590981486,
 76.44229722087945,
 52.44229722087945,
 55.62044280575276,
 55.12445997568367,
 42.66024921497113,
 50.47870866461908,
 50.50945461622087,
 42.337881740967056,
 58.79654590981486,
 49.20622924337966,
 61.89178696366263,
 60.62044280575276,
 40.099771282656064,
 51.47757340128953,
 52.028083658506354,
 53.88839199818388,
 37.88839199818388,
 62.337881740967056,
 54.064495102245985,
 59.27791686752937,
 43.337881740967056,
 41.47757340128954,
 60.337881740967056,
 53.24603565259804,
 43.15973615609375,
 60.099771282656064,
 65.09977128265606,
 63.38233234744176,
 44.12445997568367,
 43.83182209022494,
 73.85651078325255,
 69.33788174096705,
 54.856510783252546,
 35.099771282656064,
 69.89178696366264,
 51.29603285093748,
 42.83182209022494,
 55.09524105384777,
 61.61387009613326,
 40.974691494688166,
 63.29603285093748,
 53.06335983891644,
 43.24150542378975,
 60.306000526035724,
 47.66024921497113,
 37.74552259372066,
 49.028083658506354,
 59.69213042990246,
 38.337881740967056,
 30.745522593720654,
 62.74552259372066,
 44.62044280575276,
 71.15973615609374,
 55.66024921497113,
 38.009967675098245,
 61.827291861416654,
 48.21075947218795,
 46.099771282656064,
 60.74552259372066,
 52.57394971846685,
 58.028083658506354,
 56.424181237471345,
 39.24150542378975,
 52.424181237471345,
 49.028083658506354,
 53.009967675098245,
 63.33788174096706,
 72.92366817859397,
 44.59575411272515,
 72.89178696366264,
 41.86370330515628,
 46.24150542378975,
 40.56047793231507,
 56.06993254853594,
 65.33130903134756,
 32.41760852785184,
 77.06335983891644,
 50.028083658506354,
 53.09524105384777,
 46.79654590981485,
 55.064495102245985,
 41.06335983891643,
 38.009967675098245,
 63.69213042990247,
 47.331309031347566,
 43.57394971846685,
 57.92366817859396,
 44.15973615609376,
 76.24603565259804,
 63.65571898616284,
 69.41760852785185,
 76.71024641331057,
 52.059964873437686,
 49.83182209022494,
 71.12445997568366,
 64.20622924337965,
 62.89178696366263,
 54.509454616220864,
 43.66024921497114,
 83.92366817859397,
 59.69213042990247,
 54.06993254853594,
 42.509454616220864,
 53.47757340128953,
 70.47757340128953,
 49.62044280575275,
 41.856510783252546,
 39.424181237471345,
 59.24150542378975,
 46.69213042990246,
 55.424181237471345,
 47.06449510224598,
 58.83182209022494,
 52.57394971846685,
 66.50945461622086,
 69.24603565259804,
 67.44229722087945,
 52.79201568100656,
 68.06335983891644,
 43.61387009613326,
 30.831822090224943,
 63.89178696366263,
 45.974691494688166,
 73.15973615609376,
 44.009967675098245,
 55.79654590981486,
 45.59575411272515,
 57.009967675098245,
 45.79654590981486,
 52.59575411272515,
 55.82729186141665,
 50.424181237471345,
 42.59575411272515,
 64.00996767509824,
 68.92819840740225,
 44.38233234744176,
 22.831822090224943,
 53.47870866461908,
 44.47757340128953,
 56.424181237471345,
 56.33130903134756,
 68.57394971846685,
 77.24603565259804,
 73.44229722087945,
 34.974691494688166,
 54.57394971846685,
 57.61387009613326,
 57.71024641331057,
 40.82729186141665,
 46.47757340128953,
 37.337881740967056,
 67.92366817859397,
 46.69213042990247,
 47.15973615609375,
 62.424181237471345,
 43.974691494688166,
 51.69213042990246,
 66.71024641331057,
 35.56047793231507,
 74.89178696366264,
 73.69213042990248,
 73.62044280575276,
 47.424181237471345,
 73.62044280575276,
 58.88839199818388,
 41.74552259372066,
 49.856510783252546,
 62.928198407402256,
 59.89178696366263,
 46.099771282656064,
 77.62044280575276,
 47.06335983891643,
 18.246035652598035,
 38.028083658506354,
 48.61387009613326,
 37.62044280575276,
 51.82729186141664,
 64.24150542378975,
 33.24150542378975,
 44.21075947218795,
 48.74552259372066,
 41.21075947218795,
 42.442297220879446,
 62.331309031347566,
 51.41760852785185,
 31.863703305156275,
 48.21075947218795,
 79.12445997568366,
 42.424181237471345,
 39.38233234744176,
 47.331309031347566,
 55.59575411272515,
 45.86370330515627,
 61.15973615609376,
 52.509454616220864,
 34.86370330515628,
 70.15973615609374,
 75.71024641331057,
 48.51398484502916,
 57.009967675098245,
 45.06993254853594,
 54.79654590981486,
 61.65571898616284,
 64.24150542378975,
 60.509454616220864,
 51.009967675098245,
 42.59575411272515,
 47.71024641331057,
 52.47417843581078,
 51.974691494688166,
 50.059964873437686,
 56.06335983891644,
 27.478708664619077,
 54.65571898616284,
 39.560477932315074,
 49.923668178593964,
 43.66024921497113,
 63.57394971846685,
 73.06335983891644,
 41.928198407402256,
 42.83182209022494,
 64.65571898616284,
 64.88839199818388,
 62.028083658506354,
 16.146264369941974,
 47.50945461622087,
 52.21075947218795,
 37.009967675098245,
 33.62044280575276,
 61.24603565259804,
 60.424181237471345,
 67.71024641331057,
 68.21075947218796,
 41.331309031347566,
 82.85651078325255,
 38.424181237471345,
 46.83182209022494,
 59.337881740967056,
 44.24150542378975,
 72.89178696366264,
 45.71024641331057,
 57.65571898616284,
 45.12445997568367,
 63.57394971846685,
 35.38233234744176,
 51.83182209022494,
 58.33788174096706,
 57.82729186141665,
 53.928198407402256,
 38.650281539872886,
 50.028083658506354,
 52.74552259372066,
 44.20622924337967,
 61.61387009613326,
 76.62044280575276,
 36.47870866461908,
 63.24603565259804,
 54.974691494688166,
 71.00996767509825,
 51.83182209022494,
 63.61387009613326,
 52.57394971846685,
 46.65028153987288,
 50.79654590981486,
 63.62044280575275,
 54.29603285093748,
 62.15973615609375,
 81.51398484502916,
 65.06993254853595,
 55.57394971846685,
 57.69213042990246,
 42.028083658506354,
 45.50945461622087,
 49.856510783252546,
 67.15973615609374,
 51.242640687119284,
 67.51398484502917,
 46.66024921497114,
 60.33788174096706,
 66.42418123747134,
 63.424181237471345,
 43.88839199818388,
 63.06335983891643,
 54.15973615609376,
 54.06449510224598,
 63.65571898616284,
 44.47870866461908,
 40.59575411272515,
 62.47870866461908,
 43.928198407402256,
 59.891786963662625,
 65.61387009613325,
 85.65571898616284,
 40.66024921497113,
 48.88181928856438,
 47.74552259372066,
 36.79201568100657,
 69.00996767509824,
 52.24150542378975,
 37.44229722087945,
 47.44229722087945,
 39.33130903134756,
 32.62044280575276,
 45.65571898616284]

In [16]:
# A simple test - the weights are random, so given the random weights, what is the prediction accuracy using
# random forest?

cv = ShuffleSplit(n=len(X), n_iter=10)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y)

In [17]:
rfr = RandomForestRegressor()
rfr.fit(X_train, Y_train)
# preds = np.rint(rfr.predict(X_test))
preds = rfr.predict(X_test)

from sklearn.metrics import mean_squared_error as mse

print(preds)
mse(preds, Y_test)


[ 43.67695261  47.3385884   57.87213266  58.46154494  49.30283591
  45.1910689   49.62011753  68.51271188  55.97949475  56.87368079
  51.20740201  52.82036204  39.96014671  45.98629516  38.3348929
  47.1899939   51.29653208  55.64410408  68.73036695  39.97548537
  52.03298246  41.89859511  54.11502223  66.73106219  42.47465259
  56.41343114  50.86075826  53.47209566  48.60856569  36.44817545
  55.41267125  41.57927364  58.88869173  62.32523781  63.10703161
  75.18985164  52.90939361  55.56826118  62.29001719  55.95944492
  56.5585834   42.65502314  62.62820979  40.00982918  37.80423368
  57.50874823  46.29515216  55.92392279  54.15469834  39.64106818
  53.82318679  46.13854437  52.59131282  60.15372546  48.14311435
  56.10719572  49.29198149  55.15306164  66.20213352  57.56156667
  66.99253059  44.18506805  56.10373662  49.92821683  54.3671003
  53.42356048  53.99011181  56.83291065  46.76118799  36.61827233
  42.63446111  66.27867992  39.85271934  60.13191935  38.0692856
  54.57874676  47.37484701  43.84539631  56.70546889  51.35510077
  49.57920632  37.05520848  51.8964331   58.02737701  46.63600319
  49.86150577  38.63074864  46.66771136  53.12403096  50.07452438
  38.10190102  46.21459807  50.76870765  54.44087193  58.48301302
  53.3638422   38.14920485  47.68789218  39.78682811  41.51614392
  44.02377177  57.16225328  52.17764252  62.08802305  51.23182208
  42.98168611  55.73226976  43.59157251  54.45910591  48.62390652
  58.6966194   46.68249657  54.91175987  54.58914378  47.0263595
  53.14986866  59.60631246  45.95284222  57.10140832  54.551652
  61.84456414  46.52084284  42.99929642  55.24944984  54.93744547
  51.331387    56.61458012  35.97433638  35.32652155  58.3978437
  63.96394481  60.35873052  55.18002067  52.95500519  41.71859388
  52.46177248  57.37440916  52.22502167  61.93022564  51.12551533
  55.75407636  64.89811274  46.73129142  50.91310802  45.79080922
  54.98016342  51.5205999   56.08514083  50.99883344  41.07666844
  52.46254705  41.02494078  50.41339754  53.77407135  53.14281294
  40.35771906  55.05472757  55.20970484  55.56705361  38.24967739
  42.84886535  32.81772047  52.86897388  57.6306999   55.27712146
  58.39259208  50.24231092  39.61694688  60.09866789  52.72871291
  42.46395135  52.63980131  57.06162536  49.56257591  49.98415497
  54.14455225  45.5337472   45.95636378  38.29409184  53.64034941
  37.75264331  38.87625617  44.53377605  43.35382542  43.42452462
  54.42663606  55.31602288  44.74331812  54.19877001  42.43889526
  63.71139152  44.97707791  61.93909247  50.11318092  51.07912591
  59.03352341  48.90482897  42.19757506  44.76781252  41.40338621
  52.58164255  55.68421745  50.62157447  50.84418827  36.9660053
  50.72017466  55.84985883  44.77503235  66.66118848  42.59267843
  63.12411564  51.0458201   52.94559305  49.58698129  52.93726948
  63.62874748  54.9673481   61.02959068  46.04476839  61.27303384
  57.35417679  63.70272942  57.86209403  60.95570867  36.66955182
  50.52022014  60.3720891   51.62001492  62.7139744   63.50866296
  56.81569042  54.68122702  55.46190725  62.77748982  48.89404186
  44.97291703  56.68162412  49.62861127  51.34122282  48.1313756
  58.02575184  57.63612049  50.57284753  52.44441094  51.30654984
  58.97982442  44.49676606  57.85544027  65.20860158  63.27982442]
Out[17]:
97.212967823908713

In [18]:
# How does this compare with randomly shuffled data?
mse(permutation(Y_test), Y_test)


Out[18]:
255.20357633857947

In [19]:
[i for i in zip(Y_test, preds)]


Out[19]:
[(44.15973615609376, 43.676952614980749),
 (51.92819840740225, 47.338588399393792),
 (69.15973615609374, 57.872132655583229),
 (61.51398484502916, 58.461544944567756),
 (49.242640687119284, 49.302835909230751),
 (53.24603565259804, 45.191068904758133),
 (22.831822090224943, 49.620117530665411),
 (72.92366817859397, 68.512711883478048),
 (72.89178696366264, 55.979494751521848),
 (45.12445997568367, 56.873680790125398),
 (54.29603285093748, 51.20740201163337),
 (36.29603285093748, 52.82036203784115),
 (37.099771282656064, 39.960146714030557),
 (35.41760852785184, 45.986295155139111),
 (26.124459975683667, 38.334892896194745),
 (38.009967675098245, 47.189993902282637),
 (62.928198407402256, 51.296532083596801),
 (53.06335983891644, 55.644104084057894),
 (46.028083658506354, 68.730366945905601),
 (52.33788174096706, 39.97548537008025),
 (60.331309031347566, 52.032982460519904),
 (39.33130903134756, 41.898595111413329),
 (43.337881740967056, 54.11502223219086),
 (68.06335983891644, 66.731062192738605),
 (41.47870866461908, 42.474652591556932),
 (59.62044280575276, 56.413431135668418),
 (57.71024641331057, 50.860758257501338),
 (60.424181237471345, 53.472095655182009),
 (47.891786963662625, 48.608565688831391),
 (33.62044280575276, 36.448175446727895),
 (48.856510783252546, 55.412671252930991),
 (51.69213042990246, 41.579273638199375),
 (62.009967675098245, 58.888691732673372),
 (54.06449510224598, 62.325237809967177),
 (60.509454616220864, 63.107031605810029),
 (69.41760852785185, 75.18985163542996),
 (35.56047793231507, 52.909393610462267),
 (53.41760852785185, 55.568261177284398),
 (71.15973615609374, 62.290017189154739),
 (69.24603565259804, 55.959444916350016),
 (27.595754112725153, 56.55858339730694),
 (32.62044280575276, 42.655023142574656),
 (52.059964873437686, 62.628209793351878),
 (45.24150542378975, 40.009829178833272),
 (42.009967675098245, 37.804233676083292),
 (46.974691494688166, 57.508748229527633),
 (54.33788174096706, 46.295152157455455),
 (64.24150542378975, 55.923922785755963),
 (48.79654590981486, 54.154698335025728),
 (37.928198407402256, 39.641068183254021),
 (44.88839199818388, 53.823186786645827),
 (57.51398484502916, 46.138544365857705),
 (63.337881740967056, 52.591312822665394),
 (86.57394971846685, 60.153725462038381),
 (49.12445997568367, 48.143114349682925),
 (52.424181237471345, 56.107195716526611),
 (41.44229722087945, 49.29198149047216),
 (58.83182209022494, 55.153061640228884),
 (71.00996767509825, 66.202133522281741),
 (67.30600052603573, 57.561566666701559),
 (79.12445997568366, 66.992530585176098),
 (41.65571898616283, 44.185068053704434),
 (56.928198407402256, 56.103736619666378),
 (52.59575411272515, 49.928216826486747),
 (67.42418123747134, 54.367100299058009),
 (55.06993254853594, 53.42356048462328),
 (47.424181237471345, 53.990111807074456),
 (51.974691494688166, 56.832910647184995),
 (55.856510783252546, 46.761187993510049),
 (40.83182209022495, 36.618272334912419),
 (47.15973615609376, 42.634461113160384),
 (60.06335983891643, 66.278679920690848),
 (45.47870866461908, 39.852719344709378),
 (63.57394971846685, 60.131919348071719),
 (39.24150542378974, 38.069285595376279),
 (51.83182209022494, 54.578746758811725),
 (38.650281539872886, 47.374847007531564),
 (52.099771282656064, 43.845396314579503),
 (44.83182209022494, 56.705468893103827),
 (71.21075947218796, 51.355100773522928),
 (49.62044280575275, 49.579206321644911),
 (44.15973615609375, 37.05520848224252),
 (43.33130903134756, 51.896433101967162),
 (49.028083658506354, 58.027377014828559),
 (62.15973615609375, 46.636003190037528),
 (37.47870866461908, 49.861505767924108),
 (41.82729186141665, 38.630748637738833),
 (57.82729186141665, 46.66771136409519),
 (45.59575411272515, 53.124030962013954),
 (41.15973615609375, 50.074524376225042),
 (34.38233234744176, 38.101901015968203),
 (43.424181237471345, 46.214598071559848),
 (64.88839199818388, 50.768707649317697),
 (56.59575411272515, 54.440871929028674),
 (66.47757340128953, 58.48301301974368),
 (43.61387009613326, 53.363842202480669),
 (33.66024921497113, 38.149204850928996),
 (57.74552259372066, 47.687892180109813),
 (31.099771282656064, 39.786828106940717),
 (36.24603565259804, 41.516143922878442),
 (48.66024921497113, 44.023771765839264),
 (63.65571898616284, 57.162253282200652),
 (67.66024921497113, 52.177642522633981),
 (48.79201568100656, 62.088023050117705),
 (53.928198407402256, 51.231822075476018),
 (33.24150542378975, 42.981686105505148),
 (48.66024921497113, 55.732269759177086),
 (68.92819840740225, 43.591572512022587),
 (37.24150542378975, 54.459105912984647),
 (45.06335983891643, 48.62390652112353),
 (50.00996767509825, 58.69661940235595),
 (33.24150542378975, 46.682496572626192),
 (55.09524105384777, 54.911759867989566),
 (63.337881740967056, 54.58914378008437),
 (45.59575411272515, 47.026359500654749),
 (37.41760852785185, 53.1498686618433),
 (69.50945461622086, 59.606312464454561),
 (47.928198407402256, 45.952842220861513),
 (47.06449510224598, 57.101408323261239),
 (37.44229722087945, 54.551651998097967),
 (64.24150542378975, 61.844564140073587),
 (46.33130903134756, 46.520842840520821),
 (47.424181237471345, 42.99929641591136),
 (53.37780211863347, 55.249449840392074),
 (50.88839199818388, 54.937445468859721),
 (59.83182209022494, 51.331387000287535),
 (69.33788174096705, 56.614580123867505),
 (38.337881740967056, 35.974336381881173),
 (42.337881740967056, 35.326521553450505),
 (50.82729186141665, 58.397843704819003),
 (71.92366817859397, 63.963944814256152),
 (70.51398484502916, 60.358730521052074),
 (52.028083658506354, 55.180020670188568),
 (56.923668178593964, 52.955005194882311),
 (40.099771282656064, 41.718593883767944),
 (51.009967675098245, 52.461772479520981),
 (58.79201568100656, 57.374409155937009),
 (64.57394971846685, 52.225021671858748),
 (63.74552259372066, 61.930225639773809),
 (50.856510783252546, 51.125515325286628),
 (71.12445997568366, 55.754076355431039),
 (54.856510783252546, 64.898112743995355),
 (48.69213042990246, 46.73129142164693),
 (43.928198407402256, 50.913108022074574),
 (46.24603565259804, 45.790809215400955),
 (50.47757340128953, 54.980163419328598),
 (65.89178696366264, 51.52059989808339),
 (33.38233234744176, 56.085140827760242),
 (48.47870866461908, 50.998833439193959),
 (46.47757340128953, 41.07666843797886),
 (50.61387009613326, 52.462547051257353),
 (37.44229722087945, 41.024940782217065),
 (42.33130903134756, 50.413397538256234),
 (62.47757340128953, 53.774071353344212),
 (45.24603565259803, 53.142812943473963),
 (41.099771282656064, 40.357719064543332),
 (63.88839199818388, 55.054727568727969),
 (58.06993254853594, 55.209704844923998),
 (41.21075947218795, 55.567053610845925),
 (43.83182209022494, 38.249677386240485),
 (32.37780211863347, 42.848865354381125),
 (18.246035652598035, 32.817720474697566),
 (58.059964873437686, 52.86897387862723),
 (36.974691494688166, 57.630699902501114),
 (65.09977128265606, 55.277121460859703),
 (42.509454616220864, 58.392592084117197),
 (39.74552259372066, 50.242310915696734),
 (42.83182209022494, 39.616946878431023),
 (65.41760852785185, 60.098667885469162),
 (45.86370330515628, 52.728712907433817),
 (41.74552259372066, 42.463951354208504),
 (46.099771282656064, 52.639801305012043),
 (53.47757340128953, 57.061625355775831),
 (50.47870866461908, 49.562575913507168),
 (32.59575411272515, 49.984154974807915),
 (55.424181237471345, 54.144552250046274),
 (46.38233234744176, 45.53374720014979),
 (28.831822090224943, 45.956363781237478),
 (43.57394971846685, 38.294091843978897),
 (47.06449510224598, 53.640349413350975),
 (32.424181237471345, 37.752643311960156),
 (36.33130903134756, 38.876256167375161),
 (50.29603285093748, 44.533776051504425),
 (44.12445997568367, 43.353825419137841),
 (35.099771282656064, 43.424524615441825),
 (73.69213042990248, 54.42663605525329),
 (63.33788174096706, 55.316022881123295),
 (34.74552259372066, 44.743318115476733),
 (45.424181237471345, 54.198770014957304),
 (51.424181237471345, 42.43889525925038),
 (63.38233234744176, 63.711391520087048),
 (41.856510783252546, 44.977077909610934),
 (45.61387009613326, 61.939092474196642),
 (54.65571898616284, 50.113180917860475),
 (43.37780211863347, 51.079125910436822),
 (50.09977128265606, 59.033523406177281),
 (70.47757340128953, 48.904828965816833),
 (40.424181237471345, 42.197575057031386),
 (63.57394971846685, 44.767812518113502),
 (37.74552259372066, 41.403386208561031),
 (47.424181237471345, 52.581642554256391),
 (67.15973615609374, 55.684217449218842),
 (64.24150542378975, 50.62157446502583),
 (50.028083658506354, 50.844188265853703),
 (28.069932548535935, 36.966005298231558),
 (62.24150542378975, 50.720174655001387),
 (70.06993254853595, 55.84985882973681),
 (65.06993254853595, 44.775032347199428),
 (73.62044280575276, 66.661188475797388),
 (35.181540550352054, 42.592678431260246),
 (67.27791686752937, 63.124115636992272),
 (42.337881740967056, 51.045820103758089),
 (42.61387009613326, 52.945593051092182),
 (45.50945461622087, 49.586981288350877),
 (57.65571898616284, 52.937269479011022),
 (45.20622924337967, 63.628747480315852),
 (50.028083658506354, 54.96734809838788),
 (76.33130903134756, 61.029590680214142),
 (48.21075947218795, 46.044768388154708),
 (73.06993254853595, 61.273033844452449),
 (42.29603285093748, 57.354176790964118),
 (72.89178696366264, 63.70272941603266),
 (56.62044280575276, 57.862094028376518),
 (42.65028153987288, 60.955708671769024),
 (40.56047793231507, 36.669551824685826),
 (58.12445997568367, 50.520220142440891),
 (48.974691494688166, 60.372089104334478),
 (33.028083658506354, 51.620014918889936),
 (33.38233234744176, 62.713974398010102),
 (49.82729186141665, 63.508662961427106),
 (51.29603285093748, 56.815690417710286),
 (66.44229722087945, 54.681227024959298),
 (47.928198407402256, 55.461907245700424),
 (59.891786963662625, 62.77748981965739),
 (41.47757340128954, 48.894041863612671),
 (39.424181237471345, 44.972917025965351),
 (57.442297220879446, 56.681624124276723),
 (40.82729186141665, 49.628611271128179),
 (55.12445997568367, 51.341222822421344),
 (50.24264068711929, 48.131375603442734),
 (45.79654590981486, 58.025751838606595),
 (47.50945461622087, 57.636120490811926),
 (55.57394971846685, 50.572847533168478),
 (68.51398484502917, 52.444410943914491),
 (48.61387009613326, 51.30654984240946),
 (55.12445997568367, 58.979824415963229),
 (31.974691494688166, 44.496766058562677),
 (60.560477932315074, 57.855440270070801),
 (58.47757340128953, 65.208601584216083),
 (59.27791686752937, 63.279824415963219)]

Optimization with Autograd

Here, I try using autograd to do the optimizations required.


In [20]:
def predict(wb_vect, wb_unflattener, graph_fp):#, linweights):
    """
    Given the weights and biases for each layer, make a prediction for the graph.
    """
    fp = graph_fp.compute_fingerprint(wb_vect, wb_unflattener)
    wb = wb_unflattener(wb_vect)
    top_layer = max(wb.keys())
    linweights = wb[top_layer]['linweights']
    return np.dot(fp, linweights)

predict(wb.vect, wb.unflattener, gfp)


Out[20]:
array([[ 0.47141931]])

In [21]:
@jit
def train_loss(wb_vect, wb_unflattener):
    """
    Training loss function - should take in a vector.
    """
    sum_loss = 0
    for i, g in enumerate(syngraphs):
        gfp = GraphFingerprint(g, 2, shapes)
        pred = predict(wb_vect, wb_unflattener, gfp)
        loss = len(g.nodes()) - predict(wb_vect, wb_unflattener, gfp)
        sum_loss = sum_loss + loss ** 2
    
    return sum_loss / len(syngraphs)

train_loss(wb.vect, wb.unflattener)


Out[21]:
array([[ 30.52793686]])

In [22]:
def sgd(grad, wb_vect, wb_unflattener, callback=None, num_iters=200, step_size=0.1, mass=0.9):
    """
    Stochastic gradient descent with momentum.
    """
    velocity = np.zeros(len(wb_vect))
    for i in range(num_iters):
        print(i)
        g = grad(wb_vect, wb_unflattener)
        # if callback: callback(x, i, g)
        velocity = mass * velocity - (1.0 - mass) * g
        wb_vect += step_size * velocity
        print(train_loss(wb_vect, wb_unflattener))
    return wb_vect

In [23]:
train_loss(wb.vect, wb.unflattener)


Out[23]:
array([[ 30.52793686]])

In [24]:
grad_func = grad(train_loss)

In [36]:
sgd(grad_func, wb.vect, wb.unflattener, num_iters=200)


0
[[ 1.17223147]]
1
[[ 1.08935819]]
2
[[ 0.9782466]]
3
[[ 0.84884053]]
4
[[ 0.71082609]]
5
[[ 0.57295617]]
6
[[ 0.44257404]]
7
[[ 0.32533199]]
8
[[ 0.22508747]]
9
[[ 0.1439499]]
10
[[ 0.08244543]]
11
[[ 0.03976491]]
12
[[ 0.01406122]]
13
[[ 0.00276477]]
14
[[ 0.00289126]]
15
[[ 0.01132097]]
16
[[ 0.02503518]]
17
[[ 0.0413011]]
18
[[ 0.05780206]]
19
[[ 0.0727142]]
20
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-36-c2e645b70d89> in <module>()
----> 1 sgd(grad_func, wb.vect, wb.unflattener, num_iters=200)

<ipython-input-22-171b15cc5a04> in sgd(grad, wb_vect, wb_unflattener, callback, num_iters, step_size, mass)
      6     for i in range(num_iters):
      7         print(i)
----> 8         g = grad(wb_vect, wb_unflattener)
      9         # if callback: callback(x, i, g)
     10         velocity = mass * velocity - (1.0 - mass) * g

/Users/ericmjl/Documents/github/autograd/autograd/core.py in gradfun(*args, **kwargs)
     20     @attach_name_and_doc(fun, argnum, 'Gradient')
     21     def gradfun(*args,**kwargs):
---> 22         return backward_pass(*forward_pass(fun,args,kwargs,argnum))
     23     return gradfun
     24 

/Users/ericmjl/Documents/github/autograd/autograd/core.py in backward_pass(start_node, end_node, tape)
     59                 "Types are {0} and {1}".format(type(new_node(getval(cur_outgrad))), node.node_type)
     60             for gradfun, parent in node.parent_grad_ops:
---> 61                 og = cast_to_node_type(gradfun(cur_outgrad), parent.node_type, parent.node_value)
     62                 parent.outgrads.append(og)
     63     return cur_outgrad

/Users/ericmjl/Documents/github/autograd/autograd/numpy/numpy_grads.py in new_fun(g)
    456             for axis, size in enumerate(shape):
    457                 if size == 1:
--> 458                     result = anp.sum(result, axis=axis, keepdims=True)
    459             assert anp.shape(result) == shape
    460             return result

/Users/ericmjl/Documents/github/autograd/autograd/core.py in __call__(self, *args, **kwargs)
    131                         tapes.add(tape)
    132 
--> 133         result = self.fun(*argvals, **kwargs)
    134         if result is NotImplemented: return result
    135         if ops:

/Users/ericmjl/anaconda/lib/python3.5/site-packages/numpy/core/fromnumeric.py in sum(a, axis, dtype, out, keepdims)
   1833     else:
   1834         return _methods._sum(a, axis=axis, dtype=dtype,
-> 1835                              out=out, keepdims=keepdims)
   1836 
   1837 

/Users/ericmjl/anaconda/lib/python3.5/site-packages/numpy/core/_methods.py in _sum(a, axis, dtype, out, keepdims)
     30 
     31 def _sum(a, axis=None, dtype=None, out=None, keepdims=False):
---> 32     return umr_sum(a, axis, dtype, out, keepdims)
     33 
     34 def _prod(a, axis=None, dtype=None, out=None, keepdims=False):

KeyboardInterrupt: 

In [37]:
trained_weights = wb.unflattener(wb.vect)[2]['linweights']
trained_weights


Out[37]:
array([[ 1.08165481],
       [ 1.14253854],
       [ 5.90582944],
       [ 1.41045033],
       [ 1.4167689 ],
       [ 1.04994567],
       [ 0.96421758],
       [ 1.90682434],
       [ 1.67905653],
       [ 0.95807283]])

In [38]:
test_graphs = [make_random_graph(sample(all_nodes, 6), 5, features_dict) for i in range(100)]

test_fingerprints = np.zeros((len(test_graphs), 10))
# test_fingerprints
for i, g in enumerate(test_graphs):
    gfp = GraphFingerprint(g, 2, shapes)
    fp = gfp.compute_fingerprint(wb.vect, wb.unflattener)
    test_fingerprints[i] = fp

# test_fingerprints


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-38-83410cc0c3f8> in <module>()
      4 # test_fingerprints
      5 for i, g in enumerate(test_graphs):
----> 6     gfp = GraphFingerprint(g, 2, shapes)
      7     fp = gfp.compute_fingerprint(wb.vect, wb.unflattener)
      8     test_fingerprints[i] = fp

/Users/ericmjl/Documents/github/graph-fingerprint/graphfingerprint/fingerprint.py in __init__(self, graph, n_layers, shapes)
     41     """
     42     def __init__(self, graph, n_layers, shapes):
---> 43         super(GraphFingerprint, self).__init__()
     44         self.layers = self.set_layers(n_layers, graph, shapes)
     45 

TypeError: super(type, obj): obj must be an instance or subtype of type

In [35]:
preds = []
for i, g in enumerate(test_graphs):
    gfp = GraphFingerprint(g, 2, shapes)
#     fp = gfp.compute_fingerprint(wb.vect, wb.unflattener)
    preds.append(predict(wb.vect, wb.unflattener, gfp)[0])
# preds[0]

In [33]:
Y_test = [score(g) for g in syngraphs]

[i for i in zip(Y_test, preds)]


Out[33]:
[(48.928198407402256, array([ 7.14867291])),
 (74.20622924337965, array([ 7.09265279])),
 (54.33788174096706, array([ 7.13102324])),
 (47.15973615609376, array([ 7.14004252])),
 (57.74552259372066, array([ 7.15536322])),
 (48.099771282656064, array([ 7.08721323])),
 (56.47870866461908, array([ 7.09514357])),
 (38.24150542378974, array([ 7.14967854])),
 (56.29603285093747, array([ 7.14834682])),
 (50.00996767509825, array([ 7.03731036])),
 (52.88839199818388, array([ 7.11407215])),
 (54.21075947218795, array([ 7.00740655])),
 (57.51398484502916, array([ 7.1114426])),
 (37.62044280575275, array([ 7.15784038])),
 (44.29603285093748, array([ 7.10304284])),
 (22.382332347441764, array([ 7.158675])),
 (49.57394971846685, array([ 7.14029349])),
 (34.74552259372066, array([ 7.13034657])),
 (71.47870866461908, array([ 7.03841688])),
 (37.44229722087945, array([ 7.10595887])),
 (35.83182209022495, array([ 7.09429304])),
 (43.21075947218795, array([ 7.12786603])),
 (48.79201568100656, array([ 7.09037602])),
 (43.06335983891644, array([ 7.11814221])),
 (46.685557720282965, array([ 7.09232816])),
 (70.15973615609374, array([ 7.01373972])),
 (50.337881740967056, array([ 7.01228641])),
 (57.88839199818388, array([ 7.09837914])),
 (54.974691494688166, array([ 7.09483916])),
 (60.331309031347566, array([ 7.03965592])),
 (33.028083658506354, array([ 7.0522535])),
 (48.33130903134756, array([ 7.13030435])),
 (49.028083658506354, array([ 7.11119339])),
 (56.24603565259803, array([ 7.15329592])),
 (50.33130903134756, array([ 7.15014104])),
 (67.33788174096706, array([ 7.10864614])),
 (75.57394971846685, array([ 7.16037839])),
 (73.62044280575276, array([ 7.00790325])),
 (48.69213042990247, array([ 7.01155799])),
 (42.57394971846685, array([ 7.09961517])),
 (61.928198407402256, array([ 7.15232727])),
 (51.61387009613326, array([ 7.09673716])),
 (52.21075947218795, array([ 7.10867008])),
 (70.06993254853595, array([ 7.03883113])),
 (55.12445997568367, array([ 7.1187624])),
 (47.24150542378975, array([ 7.11834077])),
 (68.92819840740225, array([ 7.10733665])),
 (59.18154055035206, array([ 7.11127933])),
 (46.47870866461908, array([ 7.10969478])),
 (46.21075947218796, array([ 7.10597842])),
 (43.59575411272515, array([ 7.02725554])),
 (58.66024921497113, array([ 7.13002199])),
 (46.337881740967056, array([ 7.11056595])),
 (22.595754112725153, array([ 7.08825151])),
 (71.51398484502916, array([ 7.15577325])),
 (70.51398484502917, array([ 7.09186454])),
 (41.44229722087945, array([ 7.01992822])),
 (42.47870866461908, array([ 7.15656033])),
 (62.47757340128953, array([ 7.15739865])),
 (59.62044280575276, array([ 7.05363065])),
 (64.44229722087945, array([ 7.11592472])),
 (38.47757340128953, array([ 7.15449761])),
 (42.337881740967056, array([ 7.15235532])),
 (69.69213042990248, array([ 7.08769716])),
 (55.09524105384777, array([ 7.13597431])),
 (37.099771282656064, array([ 7.08859125])),
 (45.028083658506354, array([ 7.09591182])),
 (51.424181237471345, array([ 7.12318052])),
 (35.928198407402256, array([ 7.10176003])),
 (31.099771282656064, array([ 7.11395373])),
 (70.92366817859397, array([ 7.11501726])),
 (46.24603565259804, array([ 7.15882416])),
 (40.85651078325254, array([ 7.10851732])),
 (59.47757340128954, array([ 7.10298631])),
 (58.337881740967056, array([ 7.1539892])),
 (41.928198407402256, array([ 7.14602392])),
 (46.424181237471345, array([ 7.04180724])),
 (65.06993254853595, array([ 7.086408])),
 (41.20622924337965, array([ 7.03486792])),
 (45.59575411272515, array([ 7.06994849])),
 (42.15973615609376, array([ 7.12774099])),
 (39.62044280575275, array([ 7.11474207])),
 (48.24150542378975, array([ 7.01964172])),
 (75.15973615609374, array([ 7.05322737])),
 (41.15973615609375, array([ 7.08902032])),
 (37.79201568100657, array([ 7.07929729])),
 (36.009967675098245, array([ 7.15868789])),
 (30.37780211863347, array([ 7.15469859])),
 (61.57394971846684, array([ 7.10604694])),
 (60.69213042990246, array([ 7.14316626])),
 (47.181540550352054, array([ 7.05936376])),
 (51.424181237471345, array([ 7.09633475])),
 (37.50945461622087, array([ 7.11066585])),
 (76.89178696366264, array([ 7.07958998])),
 (45.79654590981486, array([ 7.10627])),
 (62.928198407402256, array([ 7.10502272])),
 (57.82729186141665, array([ 7.05426132])),
 (30.613870096133258, array([ 7.10443681])),
 (77.33788174096705, array([ 7.10593066])),
 (55.06993254853594, array([ 7.01209186]))]

In [ ]:


In [ ]:
plt.scatter(preds, n_nodes, alpha=0.3)
plt.xlabel('predictions')
plt.ylabel('actual')
plt.title('number of nodes')

In [ ]:
class Class(object):
    """docstring for ClassName"""
    def __init__(self, arg):
        super(Class, self).__init__()
        self.arg = arg
        
    def __iter__():
        pass
        
    def function(self, value, other_thing):
        return value['k']['v']['x'] ** 2 + value['y'] ** 3
    
    def function2(self, value):
        return np.sum(np.dot(value['arr1'], value['arr2'])) + 1
        
        
# def function(value):
#     return value ** 2

In [ ]:
c = Class(np.random.random((10,10)))

from collections import OrderedDict
value = dict({'k':{'v':{'x':3.0}}, 'y':2.0})
gradfunc = grad(c.function)
gradfunc(value, 'string')

In [ ]:
def fun2(value):
    return np.sum(np.dot(value['arr1'], value['arr2']))

value = {'arr1':np.random.random((10,10)), 'arr2':np.random.random((10,10))}
gradfunc = grad(fun2)(value)
gradfunc

In [ ]:
value = {'arr1':np.random.random((10,10)), 'arr2':np.random.random((10,10))}
# value
gradfunc = grad(c.function2)
gradfunc(value)
# np.dot(c.arg, value['arr1'])# , c.arg)
# c.function2(value)

In [ ]:
np.dot(value['arr1'], value['arr2'])

In [ ]:


In [ ]: