In [260]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
from matrix_factorization import matrix_factorization
from graph_init import *
from similarity import *
from create_R import *
from ALS import *
from hard_hfs import *
import copy
import matplotlib.pyplot as plt


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

In [261]:
def RMSE(ground, predict):
    
    error = 0
    n = 0
    
    for i in range(len(ground)):
        for j in range(len(ground[0])):
            if ground[i,j] != 0:
                error += (ground[i,j] - predict[i,j])**2
                n += 1
                
    return np.sqrt(error/n)

def RMSEvec(ground, predict):
    
    error = 0
    n = 0
    
    for i in range(len(ground)):
        if ground[i] != 0:
            error += (ground[i] - predict[i])**2
            n += 1
                
    return np.sqrt(error/n)

def meanError(ground_truth,new_res):
    return np.mean(abs((new_res - ground_truth)[ground_truth!=0]))

def dictfromR(R):

    R_dict = {"Users": np.empty([0]), "Movies": np.empty([0]), "Ratings": np.empty([0])}

    for i in range(len(R)):
        for j in range(len(R[0])):
            if R[i,j] != 0:
                R_dict["Users"] = np.append(R_dict["Users"],i)
                R_dict["Movies"] = np.append(R_dict["Movies"],j)
                R_dict["Ratings"] = np.append(R_dict["Ratings"],R[i,j])

    return R_dict

In [262]:
R,R_dict = create_R()

print(R_dict)


/home/marc/Documents/MVA/ProjetGraphes/src/create_R.py:21: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  R[ratingsnp[i,0]-1, ratingsnp[i,-1]] = ratingsnp[i,2]
{'Ratings': array([ 2.5,  3. ,  3. , ...,  4. ,  2.5,  3.5]), 'Users': array([   0.,    0.,    0., ...,  670.,  670.,  670.]), 'Movies': array([   30.,   833.,   859., ...,  4597.,  4610.,  4696.])}

In [263]:
P_dict = copy.deepcopy(R_dict)
P_dict["Ratings"] = np.ones([len(R_dict["Ratings"])])
P = R > 0
print(P)


[[False False False ..., False False False]
 [False False False ..., False False False]
 [False False False ..., False False False]
 ..., 
 [False False False ..., False False False]
 [ True False False ..., False False False]
 [ True False False ..., False False False]]

In [264]:
to_keep = 99.85/100

iss, js = np.where(R > 0.1)
n_ratings = len(js)
ground_truth = copy.deepcopy(R)
shuf = np.array(range(n_ratings))
iss = iss[shuf]
js = js[shuf]

deleted_i = []
deleted_j = []
deleted = 0
i = 0
while deleted < len(iss)*(1-to_keep):
    if np.sum(R[iss[i]] > 0.1) > 1 and np.sum(R[:,js[i]] > 0.1) > 1:
        R[iss[i],js[i]] = 0
        deleted += 1
        deleted_i.append(iss[i])
        deleted_j.append(js[i])
    i += 1
        
    
# deleted_i = iss[:len(iss)*(1-to_keep)]
# deleted_j = js[:len(js)*(1-to_keep)]
# R[deleted_i, deleted_j] = 0

In [265]:
np.mean(R[R != 0])


Out[265]:
3.5437993850960914

In [266]:
R_dictCopy = copy.deepcopy(R_dict)
R_dict = dictfromR(R)

In [267]:
print(RMSEvec(ground_truth[deleted_i,deleted_j],3.55*np.ones(len(deleted_i))))


0.929207809006

In [112]:
N = len(R)
M = len(R[0])
K = 4

als = ALS(K,N,M,"Users","Movies","Ratings",lbda = 0.1,lbda2 = 0.1)
print("Als created")
ans = als.fit(R_dict)


Als created

In [113]:
R_rec = np.dot(als.U,np.transpose(als.V))

In [114]:
print(RMSEvec(ground_truth[deleted_i,deleted_j],R_rec[deleted_i,deleted_j]))


4.11149823383

In [ ]:


In [115]:
R_rec[deleted_i,deleted_j]


Out[115]:
array([  2.29700832e+00,   2.76217137e+00,   2.60425165e+00,
         2.47738546e+00,   2.94912485e+00,   2.91459307e+00,
         2.88271487e+00,   3.01562477e+00,   2.43390917e+00,
         2.71487981e+00,   2.21470134e+00,   2.33572217e+00,
         2.89259165e+00,   2.47322566e+00,   2.56385035e+00,
         2.44990704e+00,   2.46034045e+00,   2.56659376e+00,
         2.74186580e+00,   2.55110723e+00,   3.48679975e+00,
         3.12108784e+00,   2.98175593e+00,   3.49128958e+00,
         3.04289626e+00,   2.94288669e+00,   3.10820129e+00,
         2.98123225e+00,   3.23647658e+00,   2.16114459e+00,
         2.88611052e+00,   2.37032158e+00,   2.40893741e+00,
         2.29817837e+00,   2.03613858e+00,   2.05343990e+00,
         3.08862444e+00,   3.13599319e+00,   2.34395586e+00,
         3.05137409e+00,   2.25054251e+00,   2.45278708e+00,
         3.31962215e+00,   3.21187618e+00,   2.80992392e+00,
         3.38475186e+00,   2.48840992e+00,   2.38070646e+00,
         3.35124536e+00,   3.30810419e+00,   2.86457624e+00,
         2.27570317e+00,   3.03714837e+00,   3.10674403e+00,
         2.92108954e+00,   2.61713781e+00,   3.17233112e+00,
         3.18260186e+00,   3.26870059e+00,   2.40334941e+00,
         2.36720831e+00,   2.76855993e+00,   2.36681474e+00,
         2.81003697e+00,   2.17854899e+00,   1.69061078e+00,
         2.62695945e+00,   2.59273558e+00,   3.21565025e+00,
         2.75253218e+00,   3.12720703e+00,   3.01321264e+00,
         2.13544218e+00,   3.63381614e+00,   2.90174193e+00,
         2.89032569e+00,   3.36027796e+00,   3.55088749e+00,
         3.82925127e+00,   2.44817656e+00,   2.95656945e+00,
         2.53447110e+00,   2.68668806e+00,   2.22852762e+00,
         2.70566330e+00,   2.50805543e+00,   2.65011627e+00,
         3.16851004e+00,   3.19199759e+00,   3.12008187e+00,
         2.64144332e+00,   3.33990031e+00,   3.14186038e+00,
         3.07152464e+00,  -2.86123540e+00,  -3.75628517e+00,
        -3.46650071e+00,  -1.96312336e+00,  -4.08584250e+00,
        -4.43989390e+00,  -2.33009009e+00,  -3.76989929e+00,
        -3.47217588e+00,  -4.62322234e+00,  -3.84867681e+00,
        -3.30264246e+00,  -4.10170030e+00,  -4.18543319e+00,
        -3.13677017e+00,  -4.02760990e+00,  -3.81648701e+00,
        -4.09884739e+00,  -4.15929987e+00,  -3.52046641e+00,
        -4.19020034e+00,  -3.08122509e+00,  -3.73117402e+00,
        -3.30519428e+00,  -3.00234339e+00,  -4.20655128e+00,
        -4.13297881e+00,  -3.02241052e+00,  -3.32202685e+00,
        -2.40290500e+00,  -3.91017306e+00,  -3.93447521e+00,
        -2.99844504e+00,  -4.00339158e+00,  -3.90779615e+00,
        -2.53372715e+00,  -3.88265022e+00,  -3.93141305e+00,
        -3.37099689e+00,  -4.01687068e+00,  -4.10545262e+00,
        -4.15897189e+00,  -3.86407049e+00,  -4.03153203e+00,
        -3.86543381e+00,  -3.16983378e+00,  -3.52573013e+00,
        -3.98670959e+00,  -3.53079778e+00,  -4.21531835e+00,
        -2.36668613e-02,   1.41027069e+00,   1.31991526e+00,
         6.08910605e-01,   1.15076378e+00,   3.95336744e-01,
         4.73791626e-01,   7.34629294e-01,   1.93631668e+00,
        -3.14415388e-01,   1.51154579e+00,   2.83219286e-01,
        -1.03589740e+00,   1.30587715e-01,   6.99975962e-01,
         1.94754245e-02,   6.34568400e-01,   8.18465735e-01,
         5.34951436e-01,   1.26986732e+00,  -5.78313129e-01,
         6.02348364e-01,   5.78960438e-01,   7.27977735e-01,
         4.67148037e-01,   6.76231143e-01,   5.04878634e-01,
         9.67682894e-01,  -5.71994673e-01,   9.39971643e-01,
         1.17782995e+00,   9.10244899e-01,   7.59535933e-01,
         2.07718062e-01,   7.15855232e-01,   1.33225175e+00,
         9.90362758e-01,  -1.26677513e-01,   1.07080726e+00,
         1.89846060e+00,   9.28754134e-01,  -4.43904939e-01,
         1.07372285e+00,   1.41078015e+00,   1.51336293e+00,
         1.21946973e+00,   5.16204002e-01,   2.67743224e-01,
         7.81182994e-01,  -3.45255284e-01,  -1.24125421e-01,
        -5.51491566e-01,   7.31750642e-03,   5.19327743e-01,
         2.35790652e+00,   8.78161434e-01,  -2.20656337e-01,
         9.40578165e-01,   9.14185445e-01,   1.05208735e-02,
         7.53363104e-01,   1.02930717e+00,   1.37547094e-02,
         8.23850174e-01,   5.21130172e-01,   9.27535300e-01,
         3.81486465e-01,   3.07041537e-01,   8.42547280e-01,
         9.63920100e-01,  -8.21332320e-01,   8.73185079e-02,
         9.41927428e-02,   8.18284885e-02,  -1.54261891e-01,
         2.67677894e-01,   1.58707717e-01,   8.48378239e-01,
        -3.36047719e-01,   8.19303161e-01,   6.94557484e-02,
        -1.48176578e+00,   6.52973499e-01,   2.01835782e+00,
         7.75799623e-01,   4.98537719e-01,   1.36323595e+00,
         6.78518878e-02,   1.18751284e+00,   1.68429555e+00,
         6.47557780e-01,   4.86927559e-01,   4.28571553e-01,
         5.03268772e-01,   7.18047246e-01,   6.19517220e-01,
        -3.71113094e-01,   7.31278898e-01,   8.16921575e-01,
         1.02624774e+00,   4.35927108e-02,   1.25466434e+00,
        -9.81215353e-01,   2.08666013e-01,   3.65454698e-01,
         9.29160913e-01,   1.92447004e-01,   3.16952334e-01,
        -5.65326656e-03,   6.10162755e-02,   1.00142128e-01,
         3.84196566e-01,   3.81130334e-01,  -1.40385665e-01,
         1.51904006e+00,   7.17814869e-01,   1.10581262e+00,
        -3.93096461e-01,   1.19451289e+00,   1.94765720e+00,
         3.05758345e-01,   1.34184693e+00,   3.98634891e-01,
         7.69597754e-01,   1.08730653e+00,   8.46503612e-01,
         1.30089487e+00,   7.79420240e-01,   1.37138744e+00,
         1.43778735e+00,   1.24273264e+00,   8.36788238e-02,
         1.41569839e+00,   1.18034562e+00,  -4.69989299e-01,
         9.01877634e-01,   2.78609678e-02,   2.81610568e-01,
         2.03039803e-01,   2.97171070e-01,  -4.93754878e-01,
         1.15621615e+00,  -4.82175925e-01,   1.40298112e-01,
         4.65313481e-01,   1.14171969e+00,   5.53140180e-01,
        -7.40757011e-02,   2.05391287e+00,   4.81946172e-02,
        -2.64163545e-01,   1.18646654e+00,   5.24231071e-01,
         7.94041649e-01,   7.68376835e-01,  -4.84919752e-01,
         1.59520642e+00,  -1.31878019e+00,   5.58468128e-01,
         1.25063328e+00,  -1.48821414e-01,   8.48286348e-01,
         3.18197600e-01,   4.38353554e-01,   3.46747217e-01,
         6.50660550e-01,  -7.17086547e-01,   6.46062613e-01,
         1.42601222e+00,  -3.02586180e-01,   9.21561575e-01,
         7.40032949e-01,   4.22181272e-01,   5.97066971e-01,
         1.43902012e-01,   1.37141492e+00,   1.07642021e+00,
         2.19389783e+00,   4.55315802e+00,  -1.87730050e+00,
         1.75417922e-01,  -2.98762914e-01,   8.54596556e-01,
        -5.06314333e-01,   8.55997917e-01,   1.40682444e-01,
         1.36023585e+00,   1.02678430e+00,   6.35166337e-01,
         1.26809347e+00,  -2.05095501e-01,   6.02345365e-01,
         1.04064991e+00,   1.20634018e+00,   6.55003549e-01,
        -3.88109523e-01,   1.96302387e+00,   1.26536746e+00,
         1.87506779e+00,   9.75214322e-01,  -4.46295223e-01,
         8.61131722e-01,  -7.21783582e-04,   3.31865983e+00,
         3.61246769e+00,   3.35850315e+00,   3.63098692e+00])

In [ ]:


In [116]:
lp = LaplacianParams()

# sim = similarity(als.U)
sim = build_graph(als.U, GraphParams())
# Seems to work better with U... 
L = build_laplacian(sim,lp)

In [117]:
lhfs = []
lconf = []
for i in range(len(R[0])):
    if i%1000 == 0:
        print(i)
    hfs0, confidence = simple_hfs(als.U, R[:,i]*2, L, sim)
    maxconfidences = np.array([max(confidence[i,:]) for i in range(len(confidence))])
    
    lim = np.percentile(maxconfidences, 90)
    
    lhfs.append(hfs0/2)
    lconf.append(maxconfidences > lim)

R_barre = np.vstack(lhfs).T
confs = np.vstack(lconf).T


0
/home/marc/anaconda3/lib/python3.5/site-packages/numpy/core/numeric.py:190: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  a = empty(shape, dtype, order)
1000
2000
3000
4000
5000
6000
7000
8000
9000

In [118]:
print(RMSEvec(ground_truth[deleted_i,deleted_j],R_barre[deleted_i,deleted_j]))


1.33119486631

In [ ]:


In [119]:
R_barre_limited = R_barre * confs
R_barre_final = copy.deepcopy(R_barre_limited)
R_barre_final[R != 0] = 0
# R_dict_barre = dictfromR(R_barre_final)

In [120]:
N = len(R)
M = len(R[0])
K = 1

als_trans = ALS(K,N,M,"Users","Movies","Ratings",lbda = 0.1,lbda2 = 0.1)
print("Als created")
ans = als_trans.fitTransductive(R_dict,R_barre_final,C1=1,C2=.1)

R_rec_trans = np.dot(als_trans.U,np.transpose(als_trans.V))


Als created

In [121]:
print(RMSEvec(ground_truth[deleted_i,deleted_j],R_rec_trans[deleted_i,deleted_j]))


1.21227852531

In [ ]:


In [268]:
enleves = np.array([10, 30, 5, 5, 1, 0.1, 0.01, 20, 1, 0.5, 0.5, 0.1, 0.2, 0.01, 0.05, 0.3, 0.4, 0.25, 0.15, 0.35])
RMSEhfs = np.array([1.32674374885, 1.30905728547, 1.44788836719, 1.4249715405, 1.43950610356,1.26059863196, 1.78376517003, 1.3206959484, 1.35854861733, 1.39001285193, 1.52719852323, 1.06357327853, 1.08529290817, 1.28805702866, 1.23867418052, 
1.1801936887, 1.48496202627, 1.13833946302, 1.15589476823, 1.33119486631]) #HFS
RMSErand = np.array([1.10827678053, 1.06910116524, 1.21354273217, 1.21354273217, 1.05455697314, 0.998278220708, 1.16881836212, 1.06500160262, 1.05455697314, 1.05057013321, 1.05057013321, 0.998278220708, 1.02676253025, 1.16881836212, 1.08760845752, 1.09035716054, 1.0698806655, 1.08158711036, 0.929207809006, 1.11193665912]) #Constant
RMSEals = np.array([3.34103109698, 3.6567784054, 3.11900504497, 3.36009286288, 2.97774025308,1.06654905746, 0.91409271283, 3.48617106862, 4.30114920462, 3.04552449344, 3.81279272213, 0.915592088969, 2.02556708006, 0.80595040701, 0.957383505365, 1.08584093858, 4.44853531724, 1.83389908845, 2.82346349481, 4.11149823383]) #ALS
RMSEtrans = np.array([1.29307089433,1.32952142571, 1.32641438877, 1.43307918684, 1.62984802727, 1.07843682215, 0.95859828853, 1.27306921394, 1.71125206849, 1.31887592815, 1.14583621658, 0.947932371713, 0.878965723091, 1.1518595346,0.994812640054, 0.930838724808, 1.08594083348, 0.864330876328, 0.886734080467, 1.21227852531]) #Transductive

Transductive

10% - 1 - 0.1 - 1.29307089433

5% - 1 - 0.1 - 1.32641438877

5% - 1 - 1 - 1.43307918684

1% - 1 - 0.1 - 1.62984802727

0.1% - 1 - 0.1 - 1.07843682215

0.01% - 1 - 0.1 - 0.95859828853

20% - 1 - 0.1 - 1.27306921394

30% - 1 - 0.1 - 1.32952142571

1% - 1 - 0.1 - 1.71125206849

.5% - 1 - 0.1 - 1.31887592815


In [271]:
plt.plot(np.unique(enleves), [np.mean(RMSEhfs[np.where(enleves == i)]) for i in np.unique(enleves)], "k-.+", label = "HFS")
plt.plot(np.unique(enleves), [np.mean(RMSErand[np.where(enleves == i)]) for i in np.unique(enleves)], "k-+", label = "Constant (3.55)")
plt.plot(np.unique(enleves), [np.mean(RMSEals[np.where(enleves == i)]) for i in np.unique(enleves)], "k--+", label = "ALS")
plt.plot(np.unique(enleves), [np.mean(RMSEtrans[np.where(enleves == i)]) for i in np.unique(enleves)], "k:+", label = "Transduction")

plt.xlabel("Percentage of data taken out")
plt.ylabel("RMSE on taken out data")

plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.show()



In [272]:
plt.plot(np.unique(enleves), [np.mean(RMSEhfs[np.where(enleves == i)]) for i in np.unique(enleves)], "k-.+", label = "HFS")
plt.plot(np.unique(enleves), [np.mean(RMSErand[np.where(enleves == i)]) for i in np.unique(enleves)], "k-+", label = "Constant (3.55)")
plt.plot(np.unique(enleves), [np.mean(RMSEals[np.where(enleves == i)]) for i in np.unique(enleves)], "k--+", label = "ALS")
plt.plot(np.unique(enleves), [np.mean(RMSEtrans[np.where(enleves == i)]) for i in np.unique(enleves)], "k:+", label = "Transduction")

plt.xlabel("Percentage of data taken out")
plt.ylabel("RMSE on taken out data")

plt.xlim(xmin = 0.01, xmax = 0.5)

plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.show()



In [ ]:


In [ ]: