Keras implementation of the triplet loss of [1].
Commonly, a machine learning problem consists of four components: data (Section 1.2), a training objective (Section 1.3), a model (a parametrizable function, Section 1.4), and a training procedure (Section 1.5).
After defining the problem, we train the model in Section 2 and then evaluate it in Section 3.
In [1]:
# Stretch the notebook container to the full browser width.
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
In [2]:
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import numpy as np
import os
import keras.backend as K
import tensorflow as tf
# The notebook relies on APIs of these specific releases (e.g. `keep_dims`),
# so fail early with a message that names the right library and the version found.
assert tf.__version__.startswith("1.3"), \
    "This example was written for TensorFlow 1.3.x (found %s)" % tf.__version__
assert keras.__version__ == "2.0.8", \
    "This example was written for Keras 2.0.8 (found %s)" % keras.__version__
In [3]:
# Training schedule.
epochs = 30
batch_size = 100

# Problem dimensions.
num_classes = 10  # how many categories there are
embedding_dim = 32  # how many dimensions the embedded space has

# Location and file name used when the trained model is saved.
model_name = 'keras_cifar10_trained_model.h5'
save_dir = os.path.join(os.getcwd(), 'saved_models')
In [4]:
# Load CIFAR-10, which comes pre-split into a training and a test set.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert the images to float32 and divide by 255
# (presumably the raw pixels are 0..255, giving a [0, 1] range -- confirm).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
Some remarks on the implementation (see Section 3.2):
Calculation of the triplet loss per batch can roughly be divided into five steps:
Step 1) Calculation of the pairwise euclidean distances $\|z_i - z_j\|$ for all pairs $i,j$ in the batch of embedded images. The euclidean distances are used on multiple occasions in the paper, e.g. Equations (1), (2) and (3) in [1]. The result will be a [batch_size x batch_size] shaped 2d tensor, which contains the pairwise distances (the code below works with the squared distances, i.e. without taking the square root). See [4] for more details on the implementation and the results.
Step 2) Next, we calculate pairwise label equality. The result will be a [batch_size x batch_size] shaped 2d tensor, which states for each pair of training examples whether they belong to the same class or not. For example, if the pairwise label equality matrix is $L$, and $L_{(3,4)}=1$, this means that image $x_3$ has the same label as $x_4$. We need this pairwise label equality to determine positive and negative pairs.
Step 3) Next, we get the euclidean distances from all positive and from all negative examples. We check which pairs are positive examples (using the label equality matrix), and select these examples from the euclidean distance matrix. We set all invalid entries, including the elements on the diagonal, to negative values so that distances from an image to itself, i.e. anchor-to-anchor cases, are ignored in a later step. The results of step 3 are two [batch_size x batch_size] shaped 2d tensors. The positive one contains all $\|z_i^a-z_j^p\|$, whereas the negative one contains all $\|z_i^a-z_j^n\|$ for all $i,j$ pairs in this batch.
Step 4) Build all possible (anchor, positive, negative) triplet combinations for each row of the batch.
Step 5) Use only the ones that violate the triplet constraint to calculate the loss.
In [5]:
def triplet_loss(y_true, y_pred, alpha=0.8, batch_size=batch_size):
    """Triplet hinge loss over all (anchor, positive, negative) triplets of a batch.

    Every example of the batch acts as an anchor. Examples with the same label
    (other than the anchor itself) are its positives, examples with a different
    label are its negatives. For every triplet the hinge term
    ``max(d(a, p) - d(a, n) + alpha, 0)`` is computed on *squared* euclidean
    distances, and the result is the mean over the triplets with a non-zero
    hinge term.

    Args:
        y_true: Labels of the batch. The tensor is flattened, so it must hold
            exactly ``batch_size`` values.
        y_pred: Embedded images; assumed to be 2-D ``[batch_size, embedding_dim]``
            because it is multiplied with its own transpose.
        alpha: Margin by which the anchor-negative distance should exceed the
            anchor-positive distance.
        batch_size: Static number of examples per batch. It sizes the tiled
            label matrix and the identity masks, so every batch fed to the loss
            must contain exactly this many examples.

    Returns:
        A float32 scalar tensor: the sum of the hinge terms of all violating
        triplets divided by their count, or 0 when no triplet violates the
        margin.
    """
    print("compiling triplet loss: %0.5f"%alpha)
    print("Y_pred(these are the embedded images) shape: %s"% y_pred.shape)
    print("Y true(these are the labels of the images )shape : %s"% y_true.shape)
    z = tf.cast(y_pred, tf.float64)

    # 1) Pairwise squared euclidean distances via
    #    |z_i - z_j|^2 = |z_i|^2 - 2 * z_i . z_j + |z_j|^2
    z_row_norm = tf.reduce_sum(tf.pow(z, 2), axis=1, keep_dims=True)  # [batch_size, 1]
    gram = tf.matmul(a=z, b=z, transpose_a=False, transpose_b=True)  # [batch_size, batch_size]
    # z_row_norm broadcasts as a column vector, its transpose as a row vector.
    # tf.abs removes tiny negative values caused by floating point round-off.
    pw_sqrd_euclid_dists = tf.abs(z_row_norm - 2 * gram + tf.transpose(z_row_norm))

    # 2) Pairwise label equality: entry (i, j) is 1 if examples i and j share a label.
    y_row = tf.expand_dims(K.flatten(y_true), 0)  # [1, batch_size]
    y_row_ary = tf.tile(y_row, [batch_size, 1])  # [batch_size, batch_size]
    pw_label_equality = tf.cast(tf.equal(y_row_ary, tf.transpose(y_row_ary)), tf.int32)

    # 3) Distances to positives and to negatives. Invalid entries are marked with
    #    negative values so they can be filtered out in step 5.
    self_mask = tf.eye(batch_size, dtype=tf.float64) * -1
    invalid = tf.ones_like(pw_sqrd_euclid_dists) * -1

    # Comparing against the identity is True exactly for "same label, not the
    # anchor itself": the diagonal is 1 in both matrices, different-label pairs
    # are 0 in both.
    positive_labels_cond = tf.not_equal(pw_label_equality, tf.eye(batch_size, dtype=tf.int32))
    positive_ed = tf.where(condition=positive_labels_cond, x=pw_sqrd_euclid_dists, y=invalid)
    positive_ed = tf.add(positive_ed, self_mask)  # exclude self distances

    negative_labels_cond = tf.equal(pw_label_equality, tf.zeros_like(pw_label_equality, dtype=tf.int32))
    negative_ed = tf.where(condition=negative_labels_cond, x=pw_sqrd_euclid_dists, y=invalid)
    negative_ed = tf.add(negative_ed, self_mask)  # exclude self distances

    # 4) All triplet combinations per anchor. Both tensors have shape
    #    [batch_size * batch_size, batch_size]; entry (i * batch_size + j, k)
    #    pairs d(anchor_i, candidate positive j) with d(anchor_i, candidate negative k).
    pos_row = tf.tile(tf.reshape(positive_ed, [-1, 1]), [1, batch_size])
    neg_col = tf.reshape(tf.tile(negative_ed, [1, batch_size]), [-1, batch_size])

    # 5) Keep only valid triplets that violate the margin, i.e. where
    #    d(a, n) - alpha < d(a, p).
    neg_greater_zero = tf.greater_equal(neg_col, tf.zeros_like(neg_col))  # valid negatives
    pos_greater_zero = tf.greater_equal(pos_row, tf.zeros_like(pos_row))  # valid positives
    d_pos_less_than_d_neg = tf.less(x=neg_col - alpha, y=pos_row)
    hinge_loss = tf.maximum(pos_row - neg_col + alpha, 0)
    valid_violation = tf.logical_and(
        tf.logical_and(neg_greater_zero, d_pos_less_than_d_neg),
        pos_greater_zero)
    permutations_loss = tf.where(valid_violation, hinge_loss, tf.zeros_like(pos_row))

    # 6) Mean over the violating triplets only; a zero entry means "invalid or
    #    already satisfied". The count is clamped to 1 so that a batch without
    #    any violating triplet yields a loss of 0 instead of 0 / 0 = NaN.
    num_non_zero_perms = tf.reduce_sum(
        tf.cast(tf.greater(x=permutations_loss, y=tf.zeros_like(permutations_loss)), tf.float64))
    safe_count = tf.maximum(num_non_zero_perms, tf.constant(1.0, dtype=tf.float64))
    total_hinge_loss = tf.reduce_sum(permutations_loss) / safe_count
    return tf.cast(total_hinge_loss, tf.float32)
In [ ]:
In [6]:
# Embedding network: two convolutional stages followed by a dense head that
# outputs `embedding_dim` values per image.
model = Sequential([
    # First convolutional stage, 32 filters.
    Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    Activation('relu'),
    Conv2D(32, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    # Second convolutional stage, 64 filters.
    Conv2D(64, (3, 3), padding='same'),
    Activation('relu'),
    Conv2D(64, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    # Dense head ending in a linear embedding layer.
    Flatten(),
    Dense(512),
    Activation('relu'),
    Dropout(0.5),
    Dense(embedding_dim),
    Activation('linear'),
])

# RMSprop optimizer used for training.
opt = keras.optimizers.rmsprop(lr=0.001, decay=1e-6)
In [7]:
# Compile the model with the triplet loss and the RMSprop optimizer; no additional metrics are tracked.
model.compile(optimizer=opt, loss=triplet_loss, metrics=[])
In [8]:
# Sanity check: show the shapes of the test labels and the test images.
print(y_test.shape, x_test.shape)
In [9]:
print('Using real-time data augmentation.')

# Preprocessing / real-time augmentation settings.
augmentation_options = dict(
    featurewise_center=False,  # set input mean to 0 over the dataset
    samplewise_center=False,  # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,  # divide each input by its std
    zca_whitening=False,  # apply ZCA whitening
    rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
    horizontal_flip=True,  # randomly flip images horizontally
    vertical_flip=False,  # randomly flip images vertically
)
datagen = ImageDataGenerator(**augmentation_options)

# Compute quantities required for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).
datagen.fit(x_train)

# Train on the batches produced by datagen.flow(); one epoch covers the
# training set once. No validation data is passed
# (validation_data=(x_test, y_test) is intentionally left out).
train_batches = datagen.flow(x_train, y_train, batch_size=batch_size)
batches_per_epoch = int(np.ceil(x_train.shape[0] / float(batch_size)))
model.fit_generator(train_batches,
                    steps_per_epoch=batches_per_epoch,
                    epochs=epochs,
                    workers=1)

# Persist the trained model, creating the target directory on first use.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)
In [ ]:
# https://stackoverflow.com/questions/33436221/displaying-rotatable-3d-plots-in-ipython-or-ipython-notebook
In [10]:
# Embed the complete test set. `predict` splits the input into chunks of
# `batch_size` itself and also handles a final, smaller chunk, so no row is
# left unfilled when the number of test images is not a multiple of
# `batch_size` (a manually pre-allocated buffer filled by a floor-divided loop
# would keep uninitialised rows in that case).
embedded_test = np.asarray(model.predict(x_test, batch_size=batch_size), dtype="float32")
In [11]:
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics import adjusted_mutual_info_score
# Cluster the embedded test images into one cluster per class and compare the
# cluster assignment with the true labels.
num_classes = 10
true_labels = y_test.flatten()

kmeans = KMeans(n_clusters=num_classes, random_state=0, n_init=20)
kmeans.fit(embedded_test)
cluster_labels = kmeans.labels_

print("KMeans V-Measure", v_measure_score(labels_true=true_labels, labels_pred=cluster_labels))
print("KMeans AMI", adjusted_mutual_info_score(labels_true=true_labels, labels_pred=cluster_labels))
In [12]:
# Import the submodule explicitly: a bare `import sklearn` does not guarantee
# that `sklearn.metrics` is loaded, so this cell would otherwise only work
# after another cell has imported something from `sklearn.metrics`.
import sklearn.metrics

# Optional: uncomment to silence NumPy divide/invalid warnings.
# np.seterr(divide='ignore', invalid='ignore')

# Silhouette score of the embedded test images, grouped by their true labels.
silhouette = sklearn.metrics.silhouette_score(
    X=embedded_test,
    labels=y_test.flatten(),
    metric='euclidean')
print("Silhouette: %0.3f"%silhouette)
In [13]:
def get_N_HexCol(N=5):
    """Return ``N`` hex colour strings ("#RRGGBB") with evenly spaced hues.

    The hues ``0/N, 1/N, ..., (N-1)/N`` are combined with full saturation and
    value, converted to RGB, scaled to 0..255 (truncating) and formatted as
    upper-case hex. ``N=0`` yields an empty list.
    """
    import colorsys  # only needed by this helper

    hex_colors = []
    for step in range(N):
        red, green, blue = colorsys.hsv_to_rgb(step * 1.0 / N, 1, 1)
        channels = (int(red * 255), int(green * 255), int(blue * 255))
        hex_colors.append("#%.2X%.2X%.2X" % channels)
    return hex_colors
In [17]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import decomposition
from matplotlib.markers import MarkerStyle
# Project the embeddings onto their first three principal components.
pca = decomposition.PCA(n_components=3)
pca.fit(embedded_test)
test_reduced = pca.transform(embedded_test)

# 3D scatter plot of the first 1000 test images; each class gets its own
# colour and marker.
fig = plt.figure(figsize=(20, 20))
ax = fig.add_subplot(111, projection='3d')
colors = get_N_HexCol(N=10)
markers = list(MarkerStyle().markers.keys())

for point, label in zip(test_reduced[:1000], y_test[:1000]):
    class_id = label[0]
    ax.scatter(point[0], point[1], point[2],
               c=colors[class_id], marker=markers[class_id])

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
In [ ]: