In [ ]:
import os
os.environ["WANDB_ENTITY"] = "mlclass"
In [1]:
import keras
import wandb
from wandb.keras import WandbCallback
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.applications.resnet50 import ResNet50, decode_predictions, preprocess_input
import groceries
import matplotlib.pyplot as plt
Using TensorFlow backend.
In [2]:
(x_train, y_train_raw), (x_test, y_test_raw), class_names = groceries.load_data()
In [3]:
# take a look at the kinds of images we're dealing with
plt.imshow(x_train[100].astype(int))
Out[3]:
<matplotlib.image.AxesImage at 0x7fca2debe240>
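The .astype(int) cast is there because the images appear to be stored as floats in the 0-255 range, and matplotlib expects either integers in 0-255 or floats in 0-1. A slightly more explicit version of the same check (the uint8 cast and the title lookup are illustrative additions, not part of the original cell):

# illustrative variant: cast to uint8 and show the class label alongside the image
plt.imshow(x_train[100].astype("uint8"))
plt.title(class_names[int(y_train_raw[100])])
plt.show()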
In [4]:
# Print out the classes we need to target
class_names
Out[4]:
['BEANS',
'CAKE',
'CANDY',
'CEREAL',
'CHIPS',
'CHOCOLATE',
'COFFEE',
'CORN',
'FISH',
'FLOUR',
'HONEY',
'JAM',
'JUICE',
'MILK',
'NUTS',
'OIL',
'PASTA',
'RICE',
'SODA',
'SPICES',
'SUGAR',
'TEA',
'TOMATO_SAUCE',
'VINEGAR',
'WATER']
In [5]:
# check how balanced our class distribution is
plt.hist(y_train_raw)
Out[5]:
(array([639., 439., 672., 199., 406., 602., 295., 504., 381., 560.]),
array([ 0. , 2.4, 4.8, 7.2, 9.6, 12. , 14.4, 16.8, 19.2, 21.6, 24. ]),
<a list of 10 Patch objects>)
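plt.hist defaults to 10 bins, so each bar above lumps two or three of the 25 classes together. A per-class count makes the imbalance easier to read; this is a small sketch (not in the original notebook) that assumes y_train_raw holds integer class indices:

import numpy as np

# count training examples per class
counts = np.bincount(np.ravel(y_train_raw).astype(int), minlength=len(class_names))
for name, count in zip(class_names, counts):
    print(name, count)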
In [6]:
# One-hot encode the output labels
y_train = keras.utils.to_categorical(y_train_raw)
y_test = keras.utils.to_categorical(y_test_raw)
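to_categorical turns each integer label into a length-25 indicator vector, and np.argmax inverts it. A quick sanity check (illustrative only, assuming integer class indices in y_train_raw):

import numpy as np

print(y_train.shape)  # (num_samples, 25): one indicator column per class
# argmax recovers the original integer labels from the one-hot rows
assert (np.argmax(y_train, axis=1) == np.ravel(y_train_raw)).all()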
In [7]:
# Build an extremely simple baseline: a single dense layer on the flattened, normalized pixels
x_train_normalized = x_train / 255.
x_test_normalized = x_test / 255.
very_simple_model = Sequential()
very_simple_model.add(Flatten())
very_simple_model.add(Dense(25, activation="sigmoid"))
very_simple_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
wandb.init(project="transfer learn")
very_simple_model.fit(x_train_normalized, y_train, epochs=10, validation_data=(x_test_normalized, y_test), callbacks=[WandbCallback()])
W&B Run: https://app.wandb.ai/mlclass/transfer learn/runs/qf6yjgj0
Call `%%wandb` in the cell containing your training loop to display live results.
Train on 4697 samples, validate on 250 samples
Epoch 1/10
4697/4697 [==============================] - 6s 1ms/step - loss: 8.6940 - acc: 0.0319 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 2/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 3/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 4/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 5/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 6/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 7/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 8/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 9/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Epoch 10/10
4697/4697 [==============================] - 5s 1ms/step - loss: 8.6992 - acc: 0.0321 - val_loss: 9.5742 - val_acc: 0.0400
Out[7]:
<keras.callbacks.History at 0x7fc978c93470>
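The baseline stalls at roughly 3-4% accuracy, which is chance level for 25 classes (1/25 = 4%): a single dense layer on raw flattened pixels has nowhere near enough capacity, and with a sigmoid output plus categorical crossentropy the loss saturates almost immediately. One tweak worth trying before reaching for transfer learning is a softmax output so the 25 scores form a proper probability distribution. This is a hedged sketch, not what was run above:

# Sketch: the same tiny model, but with a softmax output layer
softmax_model = Sequential()
softmax_model.add(Flatten(input_shape=x_train_normalized.shape[1:]))
softmax_model.add(Dense(25, activation="softmax"))
softmax_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# softmax_model.fit(x_train_normalized, y_train, epochs=10,
#                   validation_data=(x_test_normalized, y_test))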
In [8]:
# Load ResNet50 trained on ImageNet
resnet_model = ResNet50(weights="imagenet")
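Loading the full model, classification head included, lets us sanity-check the pretrained weights with decode_predictions below. If you only wanted the convolutional features, keras.applications can also build the headless model directly; this would be equivalent to the get_layer surgery done later, assuming a Keras version whose ResNet50 accepts the pooling argument:

# Hedged alternative: skip the 1000-way ImageNet head and pool straight to a 2048-d feature vector
headless_resnet = ResNet50(weights="imagenet", include_top=False, pooling="avg")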
In [9]:
resnet_model.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D) (None, 114, 114, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 56, 56, 64) 0 pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 56, 56, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 56, 56, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 56, 56, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 56, 56, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 56, 56, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 56, 56, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 56, 56, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 56, 56, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 56, 56, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 56, 56, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 56, 56, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 56, 56, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 56, 56, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 56, 56, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 56, 56, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0]
activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 bn3c_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0]
activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0]
activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0]
activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_32 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_33 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0]
activation_31[0][0]
__________________________________________________________________________________________________
activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_35 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_36 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0]
activation_34[0][0]
__________________________________________________________________________________________________
activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_38 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_39 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0]
activation_37[0][0]
__________________________________________________________________________________________________
activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_41 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_42 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_44 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_45 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_45[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0]
activation_43[0][0]
__________________________________________________________________________________________________
activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_47 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_48 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0]
activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 2048) 0 activation_49[0][0]
__________________________________________________________________________________________________
fc1000 (Dense) (None, 1000) 2049000 avg_pool[0][0]
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
In [10]:
from keras.preprocessing import image
import numpy as np
img = image.load_img('elephant.jpg', target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
preds = resnet_model.predict(x)
print('Predicted:', decode_predictions(preds, top=3)[0])
Predicted: [('n01871265', 'tusker', 0.48542652), ('n02504013', 'Indian_elephant', 0.3432938), ('n02504458', 'African_elephant', 0.17125048)]
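The same sanity check can be pointed at one of the grocery images to see how the ImageNet head interprets it; its 1000 classes won't match our 25 targets, which is exactly why we retrain the top. This is an illustrative sketch, not run in the original notebook, and it assumes the grocery images are already 224x224 RGB as the later cells imply:

# Hypothetical check: ImageNet's top-3 guesses for one of the grocery images
grocery = np.expand_dims(x_train[100].astype("float32"), axis=0)
grocery = preprocess_input(grocery)
print('Predicted:', decode_predictions(resnet_model.predict(grocery), top=3)[0])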
In [11]:
# Preprocess our images the same way the ResNet50 training images were preprocessed
x_train_preprocessed = preprocess_input(x_train)
x_test_preprocessed = preprocess_input(x_test)
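For the ResNet50 weights shipped with Keras, preprocess_input applies the original Caffe-style preprocessing: channels are reordered RGB to BGR and the ImageNet channel means are subtracted, with no scaling to [0, 1]. The snippet below is only a sketch of what the call does, based on the standard Keras implementation; use the library call in practice:

# Approximate manual equivalent of resnet50.preprocess_input ("caffe" mode)
manual = x_train.astype("float32")[..., ::-1]    # reorder channels RGB -> BGR
manual -= np.array([103.939, 116.779, 123.68])   # subtract the ImageNet BGR channel means

Note that in some Keras versions preprocess_input mutates float inputs in place, so it is safest not to call it twice on the same array.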
In [12]:
# Build a new model that is ResNet50 minus the final classification layer (fc1000)
last_layer = resnet_model.get_layer("avg_pool")
resnet_layers = keras.Model(inputs=resnet_model.inputs, outputs=last_layer.output)
resnet_layers.summary()
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D) (None, 114, 114, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 56, 56, 64) 0 pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 56, 56, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 56, 56, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 56, 56, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 56, 56, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 56, 56, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 56, 56, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 56, 56, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 56, 56, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 56, 56, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 56, 56, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 56, 56, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 56, 56, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 56, 56, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 56, 56, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 56, 56, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0]
activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 bn3c_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0]
activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0]
activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0]
activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_32 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_33 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0]
activation_31[0][0]
__________________________________________________________________________________________________
activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_35 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_36 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0]
activation_34[0][0]
__________________________________________________________________________________________________
activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_38 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_39 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0]
activation_37[0][0]
__________________________________________________________________________________________________
activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_41 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_42 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_44 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_45 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_45[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0]
activation_43[0][0]
__________________________________________________________________________________________________
activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_47 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_48 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0]
activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 2048) 0 activation_49[0][0]
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
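The cells that build the headless feature extractor and preprocess the images (the ones between the summary above and In [13]) are not shown, but the names resnet_layers, x_train_preprocessed, and x_test_preprocessed used below, together with the (None, 2048) output seen later, suggest a standard setup. A minimal sketch of what those cells likely contain, assuming include_top=False with average pooling and the preprocess_input helper imported earlier (the input_shape argument here is an assumption that the grocery images are already sized for ResNet):

# Hypothetical reconstruction of the missing cells; the variable names are
# taken from how they are used below, not from the original source.
resnet_layers = ResNet50(weights="imagenet", include_top=False,
                         input_shape=x_train.shape[1:], pooling="avg")

# ResNet50 expects its own channel-wise preprocessing, not the /255 scaling
# used for the simple perceptron above; astype() makes a copy so x_train
# itself is left untouched.
x_train_preprocessed = preprocess_input(x_train.astype("float32"))
x_test_preprocessed = preprocess_input(x_test.astype("float32"))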
In [13]:
# We use our ResNet to "predict", but because the top classification layer
# has been removed, this outputs the activations of the penultimate layer
# (the 2048-dimensional pooled features) for our dataset
x_train_features = resnet_layers.predict(x_train_preprocessed)
In [14]:
x_test_features = resnet_layers.predict(x_test_preprocessed)
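Running the full ResNet forward pass over roughly 5,000 images is the slow step, and the extracted features never change, so it can be worth caching them to disk and skipping the extraction on later runs. An optional sketch (the .npy filenames are made up here):

import numpy as np

# cache the extracted 2048-d features so the ResNet cost is paid only once
np.save("x_train_features.npy", x_train_features)
np.save("x_test_features.npy", x_test_features)

# on later runs:
# x_train_features = np.load("x_train_features.npy")
# x_test_features = np.load("x_test_features.npy")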
In [15]:
feature_model=Sequential()
feature_model.add(Dense(25, activation="sigmoid"))
feature_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
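Note that this classifier head uses a sigmoid output, mirroring the simple perceptron earlier. For a single-label, 25-class problem trained with categorical_crossentropy, a softmax output is the more conventional choice, since it forces the class probabilities to sum to 1. A hedged one-line variant (the model name and explicit input_dim are additions, not from the original):

# same feature classifier, but with the more conventional softmax output
softmax_feature_model = Sequential()
softmax_feature_model.add(Dense(25, activation="softmax", input_dim=2048))
softmax_feature_model.compile(loss="categorical_crossentropy",
                              optimizer="adam", metrics=["accuracy"])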
In [16]:
wandb.init(project="transfer learn")
feature_model.fit(x_train_features, y_train, epochs=50, validation_data=(x_test_features, y_test), callbacks=[WandbCallback()])
W&B Run: https://app.wandb.ai/mlclass/transfer learn/runs/fhj30zry
Call `%%wandb` in the cell containing your training loop to display live results.
Train on 4697 samples, validate on 250 samples
Epoch 1/50
4697/4697 [==============================] - 1s 315us/step - loss: 1.6686 - acc: 0.4880 - val_loss: 1.2663 - val_acc: 0.5880
Epoch 2/50
4697/4697 [==============================] - 1s 125us/step - loss: 0.6745 - acc: 0.8086 - val_loss: 1.0515 - val_acc: 0.6600
Epoch 3/50
4697/4697 [==============================] - 1s 127us/step - loss: 0.4548 - acc: 0.8844 - val_loss: 0.9101 - val_acc: 0.7080
Epoch 4/50
4697/4697 [==============================] - 1s 127us/step - loss: 0.3314 - acc: 0.9263 - val_loss: 0.8746 - val_acc: 0.7120
Epoch 5/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.2507 - acc: 0.9500 - val_loss: 0.8200 - val_acc: 0.7600
Epoch 6/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.1928 - acc: 0.9719 - val_loss: 0.8054 - val_acc: 0.7440
Epoch 7/50
4697/4697 [==============================] - 1s 121us/step - loss: 0.1535 - acc: 0.9821 - val_loss: 0.7945 - val_acc: 0.7600
Epoch 8/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.1219 - acc: 0.9904 - val_loss: 0.7969 - val_acc: 0.7520
Epoch 9/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0984 - acc: 0.9957 - val_loss: 0.7987 - val_acc: 0.7480
Epoch 10/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0835 - acc: 0.9962 - val_loss: 0.8141 - val_acc: 0.7520
Epoch 11/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0692 - acc: 0.9994 - val_loss: 0.7949 - val_acc: 0.7600
Epoch 12/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0588 - acc: 0.9996 - val_loss: 0.7749 - val_acc: 0.7800
Epoch 13/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0503 - acc: 0.9994 - val_loss: 0.7562 - val_acc: 0.7800
Epoch 14/50
4697/4697 [==============================] - 1s 116us/step - loss: 0.0428 - acc: 1.0000 - val_loss: 0.7551 - val_acc: 0.7880
Epoch 15/50
4697/4697 [==============================] - 1s 120us/step - loss: 0.0383 - acc: 1.0000 - val_loss: 0.7725 - val_acc: 0.7760
Epoch 16/50
4697/4697 [==============================] - 1s 114us/step - loss: 0.0336 - acc: 1.0000 - val_loss: 0.7552 - val_acc: 0.7760
Epoch 17/50
4697/4697 [==============================] - 1s 114us/step - loss: 0.0303 - acc: 1.0000 - val_loss: 0.7592 - val_acc: 0.7720
Epoch 18/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0262 - acc: 1.0000 - val_loss: 0.7784 - val_acc: 0.7760
Epoch 19/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0234 - acc: 1.0000 - val_loss: 0.7594 - val_acc: 0.7760
Epoch 20/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0210 - acc: 1.0000 - val_loss: 0.7668 - val_acc: 0.7720
Epoch 21/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0192 - acc: 1.0000 - val_loss: 0.7636 - val_acc: 0.7840
Epoch 22/50
4697/4697 [==============================] - 1s 122us/step - loss: 0.0172 - acc: 1.0000 - val_loss: 0.7720 - val_acc: 0.7760
Epoch 23/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0157 - acc: 1.0000 - val_loss: 0.7745 - val_acc: 0.7800
Epoch 24/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0143 - acc: 1.0000 - val_loss: 0.7682 - val_acc: 0.7840
Epoch 25/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0129 - acc: 1.0000 - val_loss: 0.7853 - val_acc: 0.7720
Epoch 26/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0119 - acc: 1.0000 - val_loss: 0.7855 - val_acc: 0.7800
Epoch 27/50
4697/4697 [==============================] - 1s 116us/step - loss: 0.0108 - acc: 1.0000 - val_loss: 0.7590 - val_acc: 0.7880
Epoch 28/50
4697/4697 [==============================] - 1s 114us/step - loss: 0.0099 - acc: 1.0000 - val_loss: 0.7731 - val_acc: 0.7840
Epoch 29/50
4697/4697 [==============================] - 1s 121us/step - loss: 0.0092 - acc: 1.0000 - val_loss: 0.7970 - val_acc: 0.7800
Epoch 30/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0084 - acc: 1.0000 - val_loss: 0.8024 - val_acc: 0.7880
Epoch 31/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0077 - acc: 1.0000 - val_loss: 0.7973 - val_acc: 0.7840
Epoch 32/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0072 - acc: 1.0000 - val_loss: 0.7964 - val_acc: 0.7840
Epoch 33/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0066 - acc: 1.0000 - val_loss: 0.7938 - val_acc: 0.7920
Epoch 34/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0061 - acc: 1.0000 - val_loss: 0.8105 - val_acc: 0.7880
Epoch 35/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0056 - acc: 1.0000 - val_loss: 0.8242 - val_acc: 0.7800
Epoch 36/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0052 - acc: 1.0000 - val_loss: 0.8039 - val_acc: 0.8000
Epoch 37/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0048 - acc: 1.0000 - val_loss: 0.8109 - val_acc: 0.7840
Epoch 38/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0045 - acc: 1.0000 - val_loss: 0.8186 - val_acc: 0.8000
Epoch 39/50
4697/4697 [==============================] - 1s 122us/step - loss: 0.0041 - acc: 1.0000 - val_loss: 0.8185 - val_acc: 0.7920
Epoch 40/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0038 - acc: 1.0000 - val_loss: 0.8174 - val_acc: 0.7880
Epoch 41/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0036 - acc: 1.0000 - val_loss: 0.8179 - val_acc: 0.7920
Epoch 42/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0033 - acc: 1.0000 - val_loss: 0.8350 - val_acc: 0.7800
Epoch 43/50
4697/4697 [==============================] - 1s 120us/step - loss: 0.0031 - acc: 1.0000 - val_loss: 0.8376 - val_acc: 0.7840
Epoch 44/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0028 - acc: 1.0000 - val_loss: 0.8385 - val_acc: 0.7920
Epoch 45/50
4697/4697 [==============================] - 1s 118us/step - loss: 0.0026 - acc: 1.0000 - val_loss: 0.8716 - val_acc: 0.7720
Epoch 46/50
4697/4697 [==============================] - 1s 120us/step - loss: 0.0025 - acc: 1.0000 - val_loss: 0.8523 - val_acc: 0.7880
Epoch 47/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0023 - acc: 1.0000 - val_loss: 0.8606 - val_acc: 0.7840
Epoch 48/50
4697/4697 [==============================] - 1s 116us/step - loss: 0.0021 - acc: 1.0000 - val_loss: 0.8487 - val_acc: 0.7840
Epoch 49/50
4697/4697 [==============================] - 1s 117us/step - loss: 0.0020 - acc: 1.0000 - val_loss: 0.8801 - val_acc: 0.7800
Epoch 50/50
4697/4697 [==============================] - 1s 119us/step - loss: 0.0018 - acc: 1.0000 - val_loss: 0.8767 - val_acc: 0.7680
Out[16]:
<keras.callbacks.History at 0x7fc848347780>
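Training accuracy reaches 1.0 within about 15 epochs while validation loss slowly creeps back up, a classic overfitting pattern, and validation accuracy plateaus just under 0.80. Once trained, the head can label new images by running them through the same two-stage pipeline; a small sketch that maps predictions back to the class names loaded earlier, using the test features already computed:

import numpy as np

# predict class indices for the held-out images and translate them to names
pred_probs = feature_model.predict(x_test_features)
pred_labels = [class_names[i] for i in np.argmax(pred_probs, axis=1)]
print(pred_labels[:10])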
In [17]:
# We can stitch the two models together directly
new_model=Sequential()
new_model.add(resnet_layers)
new_model.add(Dense(25, activation="sigmoid"))
new_model.layers[0].trainable=False
new_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
new_model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
model_1 (Model) (None, 2048) 23587712
_________________________________________________________________
dense_3 (Dense) (None, 25) 51225
=================================================================
Total params: 23,638,937
Trainable params: 51,225
Non-trainable params: 23,587,712
_________________________________________________________________
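Freezing the ResNet before compiling is what makes only the new head trainable: the 51,225 trainable parameters are exactly the Dense layer's 2048 × 25 weights plus 25 biases, while all 23,587,712 ResNet parameters stay fixed. A quick sanity check:

# the only trainable weights should be the new 25-way Dense head
assert 2048 * 25 + 25 == 51225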
In [18]:
wandb.init(project="transfer learn")
new_model.fit(x_train_preprocessed, y_train, epochs=50, validation_data=(x_test_preprocessed, y_test), callbacks=[WandbCallback()])
W&B Run: https://app.wandb.ai/mlclass/transfer learn/runs/xerthmqh
Call `%%wandb` in the cell containing your training loop to display live results.
Train on 4697 samples, validate on 250 samples
Epoch 1/50
4697/4697 [==============================] - 43s 9ms/step - loss: 1.7996 - acc: 0.4826 - val_loss: 2.0460 - val_acc: 0.4000
Epoch 2/50
4697/4697 [==============================] - 41s 9ms/step - loss: 0.8081 - acc: 0.7739 - val_loss: 2.0796 - val_acc: 0.4360
Epoch 3/50
4697/4697 [==============================] - 41s 9ms/step - loss: 0.5515 - acc: 0.8567 - val_loss: 1.9751 - val_acc: 0.4120
Epoch 4/50
4697/4697 [==============================] - 41s 9ms/step - loss: 0.4258 - acc: 0.8948 - val_loss: 1.9526 - val_acc: 0.4560
Epoch 5/50
4672/4697 [============================>.] - ETA: 0s - loss: 0.3375 - acc: 0.9195
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
<ipython-input-18-0b1572714d5e> in <module>
1 wandb.init(project="transfer learn")
----> 2 new_model.fit(x_train_preprocessed, y_train, epochs=50, validation_data=(x_test_preprocessed, y_test), callbacks=[WandbCallback()])
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1037 initial_epoch=initial_epoch,
1038 steps_per_epoch=steps_per_epoch,
-> 1039 validation_steps=validation_steps)
1040
1041 def evaluate(self, x=None, y=None,
/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
210 val_outs = test_loop(model, val_f, val_ins,
211 batch_size=batch_size,
--> 212 verbose=0)
213 val_outs = to_list(val_outs)
214 # Same labels assumed.
/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py in test_loop(model, f, ins, batch_size, verbose, steps)
390 ins_batch[i] = ins_batch[i].toarray()
391
--> 392 batch_outs = f(ins_batch)
393 if isinstance(batch_outs, list):
394 if batch_index == 0:
/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2713 return self._legacy_call(inputs)
2714
-> 2715 return self._call(inputs)
2716 else:
2717 if py_any(is_tensor(x) for x in inputs):
/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
2674 else:
-> 2675 fetched = self._callable_fn(*array_vals)
2676 return fetched[:len(self.outputs)]
2677
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1397 ret = tf_session.TF_SessionRunCallable(
1398 self._session._session, self._handle, args, status,
-> 1399 run_metadata_ptr)
1400 if run_metadata:
1401 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
KeyboardInterrupt:
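The run above was stopped by hand partway through epoch 5 (hence the KeyboardInterrupt), with validation accuracy hovering in the mid-40s. Rather than watching and interrupting manually, Keras can stop a run automatically once the validation loss stops improving; a minimal sketch using the built-in EarlyStopping callback alongside WandbCallback (the patience value is illustrative):

from keras.callbacks import EarlyStopping

# stop training if val_loss has not improved for 5 consecutive epochs
early_stop = EarlyStopping(monitor="val_loss", patience=5, verbose=1)
# new_model.fit(x_train_preprocessed, y_train, epochs=50,
#               validation_data=(x_test_preprocessed, y_test),
#               callbacks=[WandbCallback(), early_stop])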
In [ ]:
# We can allow some of the resnet layers to change as we train.
# Typically you would want to lower the learning rate in conjunction with this.
new_model.layers[0].trainable = True
# We let only the last residual block (res5c) and the pooling layer train --
# the slice [-11:] covers those final 11 layers of the headless ResNet
for layer in new_model.layers[0].layers[:-11]:
    layer.trainable = False
for layer in new_model.layers[0].layers[-11:]:
    layer.trainable = True
new_model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
wandb.init(project="transfer learn")
new_model.fit(x_train_preprocessed, y_train, epochs=50, validation_data=(x_test_preprocessed, y_test), callbacks=[WandbCallback()])
W&B Run: https://app.wandb.ai/mlclass/transfer learn/runs/udba14n9
Call `%%wandb` in the cell containing your training loop to display live results.
Train on 4697 samples, validate on 250 samples
Epoch 1/50
4697/4697 [==============================] - 53s 11ms/step - loss: 0.5487 - acc: 0.8427 - val_loss: 1.0094 - val_acc: 0.6840
Epoch 2/50
4697/4697 [==============================] - 47s 10ms/step - loss: 0.1611 - acc: 0.9576 - val_loss: 1.1052 - val_acc: 0.7120
Epoch 3/50
4697/4697 [==============================] - 47s 10ms/step - loss: 0.0963 - acc: 0.9742 - val_loss: 1.3093 - val_acc: 0.6720
Epoch 4/50
4697/4697 [==============================] - 47s 10ms/step - loss: 0.0596 - acc: 0.9862 - val_loss: 1.0734 - val_acc: 0.7160
Epoch 5/50
4697/4697 [==============================] - 47s 10ms/step - loss: 0.0439 - acc: 0.9896 - val_loss: 1.1619 - val_acc: 0.7320
Epoch 6/50
4697/4697 [==============================] - 47s 10ms/step - loss: 0.0240 - acc: 0.9949 - val_loss: 1.0267 - val_acc: 0.7360
Epoch 7/50
4697/4697 [==============================] - 47s 10ms/step - loss: 0.0376 - acc: 0.9894 - val_loss: 1.1485 - val_acc: 0.7400
Epoch 8/50
4697/4697 [==============================] - 46s 10ms/step - loss: 0.0505 - acc: 0.9862 - val_loss: 1.1099 - val_acc: 0.7120
Epoch 9/50
1664/4697 [=========>....................] - ETA: 28s - loss: 0.0326 - acc: 0.9922
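The code comment above recommends lowering the learning rate when unfreezing part of the backbone, but the cell still compiles with Adam at its default settings. A hedged sketch of the usual adjustment (the 1e-5 value is illustrative, not from the original):

from keras.optimizers import Adam

# fine-tune with a much smaller learning rate so the pretrained weights
# are only nudged, not overwritten
new_model.compile(loss="categorical_crossentropy",
                  optimizer=Adam(lr=1e-5),
                  metrics=["accuracy"])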
In [ ]: