In [4]:
import numpy as np

from Solver import Solver
from data_utils.data import get_CIFAR10_data

data = get_CIFAR10_data()
for k, v in data.items():
  print('%s: ' % k, v.shape)
from classifiers.cnn import ThreeLayerConvNet
num_train = 100
small_data = {
  'X_train': data['X_train'][:num_train],
  'y_train': data['y_train'][:num_train],
  'X_val': data['X_val'],
  'y_val': data['y_val'],
}


X_val:  (1000, 3, 32, 32)
X_train:  (49000, 3, 32, 32)
X_test:  (1000, 3, 32, 32)
y_val:  (1000,)
y_train:  (49000,)
y_test:  (1000,)

In [2]:
%load_ext autoreload
%autoreload 2


from classifiers import cnn_huge as cnn

#res = cnn.ResNet()
#res.loss(data['X_train'][:num_train],data['y_train'][:num_train])
# Zero-center the small training split (per-pixel mean over the 100 images).
small_data['X_train'] -= np.mean(small_data['X_train'], axis=0)
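
If only the line above is run, the validation split is left uncentered. A minimal sketch that instead applies a single training-set mean to both splits (an alternative, not part of the original run):

mean_image = np.mean(small_data['X_train'], axis=0)
small_data['X_train'] = small_data['X_train'] - mean_image
small_data['X_val'] = small_data['X_val'] - mean_image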

In [5]:
from classifiers import cnn_huge as cnn

res = cnn.ResNet(weight_scale=1.8e-3, reg=0.5)
solver = Solver(res, small_data,
                update_rule='adam',
                optim_config={
                  'learning_rate': 1e-3,
                },
                verbose=True,
                num_epochs=5, batch_size=50,
                print_every=1)

solver.train()


(Iteration 1 / 10) loss: 2.620324
(Epoch 0 / 5) train acc: 0.180000; val_acc: 0.119000
(Iteration 2 / 10) loss: 3.471372
(Epoch 1 / 5) train acc: 0.200000; val_acc: 0.130000
(Iteration 3 / 10) loss: 2.143543
(Iteration 4 / 10) loss: 3.082603
(Epoch 2 / 5) train acc: 0.280000; val_acc: 0.131000
(Iteration 5 / 10) loss: 2.422685
(Iteration 6 / 10) loss: 2.772205
(Epoch 3 / 5) train acc: 0.430000; val_acc: 0.156000
(Iteration 7 / 10) loss: 2.335083
(Iteration 8 / 10) loss: 1.710659
(Epoch 4 / 5) train acc: 0.470000; val_acc: 0.190000
(Iteration 9 / 10) loss: 1.915156
(Iteration 10 / 10) loss: 1.602199
(Epoch 5 / 5) train acc: 0.600000; val_acc: 0.210000
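
The solver above uses update_rule='adam' with learning_rate=1e-3. Below is a minimal sketch of a standard Adam step of the kind such a solver applies per parameter; the beta1/beta2/epsilon defaults and the config-dict convention are assumptions, not taken from the local Solver.py.

import numpy as np

def adam_step(w, dw, config):
  # Standard Adam update with bias correction; config carries per-parameter state.
  config.setdefault('learning_rate', 1e-3)
  config.setdefault('beta1', 0.9)
  config.setdefault('beta2', 0.999)
  config.setdefault('epsilon', 1e-8)
  config.setdefault('m', np.zeros_like(w))  # first moment (mean of gradients)
  config.setdefault('v', np.zeros_like(w))  # second moment (mean of squared gradients)
  config.setdefault('t', 0)                 # step counter for bias correction

  config['t'] += 1
  config['m'] = config['beta1'] * config['m'] + (1 - config['beta1']) * dw
  config['v'] = config['beta2'] * config['v'] + (1 - config['beta2']) * dw ** 2
  m_hat = config['m'] / (1 - config['beta1'] ** config['t'])
  v_hat = config['v'] / (1 - config['beta2'] ** config['t'])
  next_w = w - config['learning_rate'] * m_hat / (np.sqrt(v_hat) + config['epsilon'])
  return next_w, config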

In [ ]:
import matplotlib.pyplot as plt

plt.subplot(2, 1, 1)
plt.plot(solver.loss_history, 'o')
plt.xlabel('iteration')
plt.ylabel('loss')

plt.subplot(2, 1, 2)
plt.plot(solver.train_acc_history, '-o')
plt.plot(solver.val_acc_history, '-o')
plt.legend(['train', 'val'], loc='upper left')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.show()

In [4]:
from classifiers import cnn_huge as cnn

bn_solvers = {}
solvers = {}
weight_scales = np.logspace(-4, 0, num=20)
for i, weight_scale in enumerate(weight_scales):
  print('Running weight scale %d / %d' % (i + 1, len(weight_scales)))
  bn_model = cnn.ResNet(reg=0.5, weight_scale=weight_scale, use_batchnorm=True)
  #model = FullyConnectedNet(hidden_dims, weight_scale=weight_scale, use_batchnorm=False)

  bn_solver = Solver(bn_model, small_data,
                  num_epochs=10, batch_size=50,
                  update_rule='adam',
                  optim_config={
                    'learning_rate': 1e-3,
                  },
                  verbose=False, print_every=20)
  bn_solver.train()
  bn_solvers[weight_scale] = bn_solver

  #solver = Solver(model, small_data,
  #                num_epochs=10, batch_size=50,
  #                update_rule='adam',
  #                optim_config={
  #                  'learning_rate': 1e-3,
  #                },
  #                verbose=False, print_every=200)
  #solver.train()
  #solvers[weight_scale] = solver


Running weight scale 1 / 20
Running weight scale 2 / 20
Running weight scale 3 / 20
Running weight scale 4 / 20
Running weight scale 5 / 20
Running weight scale 6 / 20
Running weight scale 7 / 20
Running weight scale 8 / 20
Running weight scale 9 / 20
Running weight scale 10 / 20
Running weight scale 11 / 20
Running weight scale 12 / 20
Running weight scale 13 / 20
Running weight scale 14 / 20
Running weight scale 15 / 20
Layers.py:397: RuntimeWarning: divide by zero encountered in log
  loss = -(np.sum(np.log(prob[np.arange(N),y]))) / N
Running weight scale 16 / 20
Running weight scale 17 / 20
Running weight scale 18 / 20
Running weight scale 19 / 20
Running weight scale 20 / 20
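
The RuntimeWarning above comes from taking the log of a class probability that underflows to exactly zero once the weight scale is large. A minimal sketch of a numerically safer softmax loss using the log-sum-exp rewrite, assuming scores of shape (N, C) and integer labels y; this is an illustration, not the original Layers.py implementation:

import numpy as np

def softmax_loss_stable(scores, y):
  # Shift each row so its max is 0 before exponentiating (avoids overflow),
  # and compute log-probabilities directly so log(0) never occurs.
  shifted = scores - np.max(scores, axis=1, keepdims=True)
  log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))
  N = scores.shape[0]
  loss = -np.mean(log_probs[np.arange(N), y])
  dscores = np.exp(log_probs)        # softmax probabilities
  dscores[np.arange(N), y] -= 1
  dscores /= N
  return loss, dscores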

In [5]:
import matplotlib.pyplot as plt

best_train_accs, bn_best_train_accs = [], []
best_val_accs, bn_best_val_accs = [], []
final_train_loss, bn_final_train_loss = [], []

for ws in weight_scales:
  #best_train_accs.append(max(solvers[ws].train_acc_history))
  bn_best_train_accs.append(max(bn_solvers[ws].train_acc_history))
  
  #best_val_accs.append(max(solvers[ws].val_acc_history))
  bn_best_val_accs.append(max(bn_solvers[ws].val_acc_history))
  
  #final_train_loss.append(np.mean(solvers[ws].loss_history[-100:]))
  bn_final_train_loss.append(np.mean(bn_solvers[ws].loss_history[-100:]))
  
plt.subplot(3, 1, 1)
plt.title('Best val accuracy vs weight initialization scale')
plt.xlabel('Weight initialization scale')
plt.ylabel('Best val accuracy')
#plt.semilogx(weight_scales, best_val_accs, '-o', label='baseline')
plt.semilogx(weight_scales, bn_best_val_accs, '-o', label='batchnorm')
plt.legend(ncol=2, loc='lower right')

plt.subplot(3, 1, 2)
plt.title('Best train accuracy vs weight initialization scale')
plt.xlabel('Weight initialization scale')
plt.ylabel('Best training accuracy')
#plt.semilogx(weight_scales, best_train_accs, '-o', label='baseline')
plt.semilogx(weight_scales, bn_best_train_accs, '-o', label='batchnorm')
plt.legend()

plt.subplot(3, 1, 3)
plt.title('Final training loss vs weight initialization scale')
plt.xlabel('Weight initialization scale')
plt.ylabel('Final training loss')
#plt.semilogx(weight_scales, final_train_loss, '-o', label='baseline')
plt.semilogx(weight_scales, bn_final_train_loss, '-o', label='batchnorm')
plt.legend()

plt.gcf().set_size_inches(10, 15)
plt.show()



In [6]:
print(bn_final_train_loss)


[1.2755638388286108, 1.5275512339724377, 1.2625992944998934, 1.1614496680080015, 1.5654056694117151, 1.3051667355092313, 1.1835590471150068, 1.1933190746147753, 1.0484197926674885, 1.1644572034283149, 1.0947923094518854, 2.3701118697864363, 8.8613680810625493, 65.070526071725453, 518.36886400041158, inf, inf, inf, inf, inf]
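
Several of these final losses are inf, so the mean over the last 100 iterations (and the semilogx panel above) is uninformative at large weight scales. One option, sketched here and not part of the original notebook, is to average only the finite loss values:

bn_final_train_loss = []
for ws in weight_scales:
  tail = np.asarray(bn_solvers[ws].loss_history[-100:])
  finite = tail[np.isfinite(tail)]          # drop inf/nan entries before averaging
  bn_final_train_loss.append(np.mean(finite) if finite.size else np.nan)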

In [7]:
print(weight_scales)


[  1.00000000e-04   1.62377674e-04   2.63665090e-04   4.28133240e-04
   6.95192796e-04   1.12883789e-03   1.83298071e-03   2.97635144e-03
   4.83293024e-03   7.84759970e-03   1.27427499e-02   2.06913808e-02
   3.35981829e-02   5.45559478e-02   8.85866790e-02   1.43844989e-01
   2.33572147e-01   3.79269019e-01   6.15848211e-01   1.00000000e+00]

In [22]:
weight_scales = np.logspace(-1.5, 0, num=1000)
print(weight_scales)


[0.03162278 0.0317323  0.0318422  ... 0.99310918 0.99654863 1.        ]

In [ ]: