In [1]:
import numpy as np
import theano
import pickle
import matplotlib.pyplot as plt
import numpy.linalg as la
%matplotlib inline
%precision 3
def load(name):
    """Unpickle and return the single object stored in the file *name*.

    Only use this on trusted files: ``pickle.load`` can execute arbitrary
    code embedded in the pickle stream.
    """
    with open(name, 'rb') as handle:
        return pickle.load(handle)
def plotaux(d):
    """Scatter-plot the last layer's third weight array and print its fourth.

    Plots row 0 against row 1 of ``d['allwts'][-1][2]`` as red dots on the
    current pyplot axes, then prints ``d['allwts'][-1][3]``.
    """
    last_layer = d['allwts'][-1]
    aux = last_layer[2]
    plt.plot(aux[0], aux[1], 'ro')
    print(last_layer[3])
def print_layers(dump):
    """Print every layer's name, its config fields (sorted), and its weights.

    Parameters
    ----------
    dump : dict
        Must contain ``'layers'`` -- a sequence of ``(name, config_dict)``
        pairs -- and ``'allwts'`` -- one sequence of weight arrays per layer.
        The two are paired with ``zip``, so surplus entries in the longer
        sequence are silently ignored.

    Each weight array is printed as ``W<i>) <kind><bits> <shape>``, e.g.
    ``f32 (64, 10)`` for a float32 matrix or ``i64 (10,)`` for an int64 vector.
    """
    for (name, config), weights in zip(dump['layers'], dump['allwts']):
        print("################")
        print(name, ":", end="\n\t")
        for field, val in sorted(config.items(), key=lambda kv: kv[0]):
            print(field, ":", val, end='\n\t')
        print()
        for i, w in enumerate(weights):
            # dtype.kind is 'f' (float), 'i' (int), 'u' (uint), 'b' (bool), ...
            # Taking the last two characters of str(dtype) behind a fixed 'f'
            # would mislabel e.g. int64 as "f64" and int8 as "ft8".
            dtype_tag = "{}{}".format(w.dtype.kind, w.dtype.itemsize * 8)
            print("\t W{}) {} {}".format(i, dtype_tag, w.shape))
def print_norms(dump):
    """Print the column-wise L2 norms of every 2-D weight array in *dump*.

    For each entry of ``dump['allwts']`` a ``Layer : <index>`` header is
    printed. Then, for each matrix in that entry, the output is its index,
    the shape of its norm vector (one norm per column, ``axis=0``) and the
    norms themselves. Arrays that are not 2-D are skipped.
    """
    for layer_idx, weights in enumerate(dump['allwts']):
        print("Layer :", layer_idx)
        matrices = ((k, arr) for k, arr in enumerate(weights) if arr.ndim == 2)
        for k, arr in matrices:
            col_norms = la.norm(arr, axis=0)
            print(k, col_norms.shape)
            print(col_norms)
In [7]:
# Load the three pickled model dumps used below: the aux-only ("lines") model,
# the basic model, and the soft-aux CNN that receives the merged weights.
linesonly, basic, softaux = [load(path) for path in (
    'aux_only34174_96.pkl',
    '0all555555_01.pkl',
    'cnn_3softaux640038_02.pkl',
)]
In [8]:
# Inspect the lines-only dump: layer configs and weight dtypes/shapes.
print_layers(dump=linesonly)
In [9]:
# Inspect the basic dump: layer configs and weight dtypes/shapes.
print_layers(dump=basic)
In [11]:
# Inspect the soft-aux dump before merging.
print_layers(dump=softaux)
In [12]:
# Compare the auxiliary weights of the soft-aux and lines-only models.
# Each dump gets its own titled figure so the two scatter plots can be told apart.
for label, dump in (("softaux", softaux), ("linesonly", linesonly)):
    plotaux(dump)
    plt.title(label)
    plt.show()
In [13]:
# Build the merged model inside `softaux`:
#  - the last layer's weights from index 2 on (the aux arrays) and `n_aux`
#    come from `linesonly`;
#  - every other weight array comes from `basic`.
# Arrays are copied: list slice assignment alone only copies references, so a
# later in-place edit of `softaux` (e.g. `*=`) would otherwise also modify
# `linesonly` / `basic`.
softaux['allwts'][-1][2:] = [w.copy() for w in linesonly['allwts'][-1][2:]]
softaux['layers'][-1][1]['n_aux'] = linesonly['layers'][-1][1]['n_aux']
for tgt, src in zip(softaux['allwts'][:-1], basic['allwts'][:-1]):
    tgt[:] = [w.copy() for w in src]
softaux['allwts'][-1][:2] = [w.copy() for w in basic['allwts'][-1][:2]]
# Regularisation rates for the last two layers.
# NOTE(review): hand-picked values -- document the rationale.
softaux['layers'][-1][1]['reg']['rate'] = 4
softaux['layers'][-2][1]['reg']['rate'] = 2
In [14]:
# Re-inspect the soft-aux dump after merging in the copied weights.
print_layers(dump=softaux)
In [22]:
# Training settings stored inside the merged dump.
# NOTE(review): values look hand-picked (SEED matches the '555555' in the
# basic checkpoint's filename) -- confirm intent.
for key, value in (('NUM_EPOCHS', 41), ('SEED', 555555), ('CUR_EPOCH', 100)):
    softaux['training_params'][key] = value
softaux['training_params']
Out[22]:
In [20]:
Out[20]:
In [23]:
## Checks: for each weight array, count the elements that differ between the
## merged model and a source model (0 means the array was copied unchanged).
def _print_diff_counts(layer_pairs):
    """Print a separator per layer pair and a '# <count>' line per weight array."""
    for merged_wb, source_wb in layer_pairs:
        print("#########")
        for merged_w, source_w in zip(merged_wb, source_wb):
            print("#", (merged_w != source_w).sum())

print("Basic")
_print_diff_counts(zip(softaux['allwts'], basic['allwts']))
print("Lines only")
# linesonly is compared layer-by-layer starting from the last layer backwards.
_print_diff_counts(zip(reversed(softaux['allwts']), reversed(linesonly['allwts'])))
In [25]:
# Augment the last-layer AuxCross values: write merged_1/2/4/8.pkl, each with
# the last two weight arrays of the final layer scaled by 1, 2, 4 and 8.
aux_w, aux_b = softaux['allwts'][-1][-2], softaux['allwts'][-1][-1]
print(aux_w.shape, aux_b.shape)
# Scale from pristine copies and rebind the list slots instead of using `*=`.
# In-place scaling would also mutate any array shared with another dump, such
# as `linesonly` if the merge did not copy. It would also compound the scale
# every time this cell is re-run.
base_w, base_b = aux_w.copy(), aux_b.copy()
for i in range(4):
    scale = 2 ** i
    softaux['allwts'][-1][-2] = base_w * scale
    softaux['allwts'][-1][-1] = base_b * scale
    name = "merged_{}.pkl".format(scale)
    print(name)
    with open(name, 'wb') as f:
        pickle.dump(softaux, f, -1)
# Restore the unscaled weights so this cell leaves `softaux` as it found it.
softaux['allwts'][-1][-2] = base_w
softaux['allwts'][-1][-1] = base_b
In [ ]: