In [2]:
from os.path import expanduser, join
import numpy as np
from sklearn import utils
import matplotlib.pyplot as plt
from skimage.feature import peak_local_max
from localizer import config, util, visualization, models, keras_helpers
%matplotlib inline
#%matplotlib notebook
import seaborn as sns
sns.set(color_codes=True)
In [199]:
import importlib
for module in (config, util, visualization, models, keras_helpers):
    importlib.reload(module)
In [4]:
data_dir = join(expanduser("~"), 'deeplocalizer_data', 'data_mxnet2')
In [5]:
X_train, y_train, X_test, y_test, X_val, y_val = util.load_or_restore_data(data_dir)
In [6]:
print(X_train.shape)
print(X_test.shape)
print(X_val.shape)
In [7]:
fig = visualization.plot_sample_images(X_train, y_train, random=True)
In [8]:
Xs_train = util.resize_data(X_train, config.filtersize)
Xs_val = util.resize_data(X_val, config.filtersize)
Xs_test = util.resize_data(X_test, config.filtersize)
print(Xs_train.shape)
print(Xs_test.shape)
print(Xs_val.shape)
In [9]:
fig = visualization.plot_sample_images(Xs_train, y_train)
In [36]:
saliency_network = models.get_saliency_network(train=True, compile=False)
In [37]:
saliency_datagen = keras_helpers.get_datagen(Xs_train)
saliency_class_weight = [1., 1.]
saliency_weight_file = join(expanduser("~"), 'saliency-localizer-models', 'season_2015', 'saliency_weights-wobn')
In [38]:
saliency_history = keras_helpers.fit_model(saliency_network, saliency_datagen, Xs_train, y_train, Xs_val, y_val,
                                           saliency_weight_file, saliency_class_weight, batchsize=128, categorial=False)
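get_datagen and fit_model are project helpers; a minimal sketch of what they plausibly wrap (Keras featurewise normalization plus a per-batch training loop that records the loss history) — the augmentation settings, compile arguments and loop structure here are assumptions, not the repository's code:
In [ ]:
from keras.preprocessing.image import ImageDataGenerator

def get_datagen_sketch(X):
    # featurewise normalization fitted on the training images, plus flips as augmentation
    datagen = ImageDataGenerator(featurewise_center=True,
                                 featurewise_std_normalization=True,
                                 horizontal_flip=True,
                                 vertical_flip=True)
    datagen.fit(X)
    return datagen

def fit_model_sketch(model, datagen, X, y, class_weight=None, batchsize=128, nb_batches=5000):
    # record the loss of every batch, analogous to saliency_history.batch_hist below
    model.compile(optimizer='adam', loss='binary_crossentropy')
    batch_hist = []
    for i, (X_batch, y_batch) in enumerate(datagen.flow(X, y, batch_size=batchsize)):
        batch_hist.append(model.train_on_batch(X_batch, y_batch, class_weight=class_weight))
        if i + 1 >= nb_batches:
            break
    return batch_hist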
In [39]:
batch_error = np.array([hist[0] for hist in saliency_history.batch_hist])
In [40]:
plt.figure(figsize=(16, 6))
_ = plt.plot(batch_error)
In [41]:
# batches trained divided by batches per epoch (assuming 658145 training samples, 256 per batch)
batch_error.shape[0] / (658145 / 256)
Out[41]:
In [42]:
batch_error[-1]
Out[42]:
In [43]:
saliency_network.load_weights(saliency_weight_file)
In [52]:
ys_out = keras_helpers.predict_model(saliency_network, Xs_test, saliency_datagen)
In [55]:
max(ys_out[:, 1])
Out[55]:
In [56]:
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(12, 6))
axes.flat[0].hist(y_test[:])
axes.flat[1].hist(ys_out[:, 1])
plt.tight_layout()
In [57]:
precision, recall, average_precision, thresholds, fpr, tpr, roc_auc = keras_helpers.evaluate_model(
    y_test > 0.8, ys_out, visualize=True)
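evaluate_model returns the standard precision/recall and ROC curves; a hedged sketch of how those quantities could be computed with sklearn.metrics (a guess at the implementation, only the returned tuple matches the call above):
In [ ]:
from sklearn.metrics import (precision_recall_curve, average_precision_score,
                             roc_curve, auc)

def evaluate_model_sketch(y_true, y_score):
    # y_true: binary labels, y_score: predicted probability of the positive class
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    average_precision = average_precision_score(y_true, y_score)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    return precision, recall, average_precision, thresholds, fpr, tpr, roc_auc

# e.g. evaluate_model_sketch(y_test > 0.8, ys_out[:, 1])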
In [58]:
saliency_threshold = keras_helpers.select_threshold(precision, recall, thresholds, min_value=0.98, optimize='recall')
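select_threshold picks an operating point from the precision/recall curve; a sketch under the assumption that it maximizes recall subject to precision >= min_value:
In [ ]:
def select_threshold_sketch(precision, recall, thresholds, min_value=0.98):
    # precision/recall have one more entry than thresholds, hence the [:-1]
    ok = precision[:-1] >= min_value
    idx = np.argmax(np.where(ok, recall[:-1], -1.))
    return thresholds[idx]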
In [59]:
saliency_threshold = 0.5
In [150]:
# image_filtersize (a full frame scaled to the filter resolution) must already be defined here;
# it is produced by util.preprocess_image a few cells below (In [153]).
convolution_container = models.get_saliency_network(train=False, shape=image_filtersize.shape)
In [151]:
convolution_function = keras_helpers.get_convolution_function(saliency_network, convolution_container)
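get_convolution_function turns the trained patch classifier into something that can be slid over a whole frame. A rough sketch of one way to do this in Keras — the weight transfer and backend function are assumptions about what the helper does:
In [ ]:
from keras import backend as K

def get_convolution_function_sketch(trained_model, container):
    # reuse the trained weights in the fully convolutional container, whose input
    # shape is the full frame instead of a single filter-sized patch
    container.set_weights(trained_model.get_weights())
    return K.function([container.input], [container.output])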
In [152]:
image_path = join(expanduser("~"), 'deeplocalizer_data')
with open(join(image_path, 'test.txt'), 'r') as f:
    image_files = [line.rstrip('\n') for line in f.readlines()]
In [153]:
imfile = image_files[0]
image, image_filtersize, targetsize = util.preprocess_image(join(image_path, imfile), config.filtersize)
In [156]:
%%timeit
convolution_function(image_filtersize.reshape((1, 1, image_filtersize.shape[0], image_filtersize.shape[1])))
In [212]:
saliency = convolution_function(image_filtersize.reshape((1, 1, image_filtersize.shape[0], image_filtersize.shape[1])))[0]
In [213]:
saliency.shape
Out[213]:
In [214]:
from scipy.ndimage import gaussian_filter

# smooth the raw network output to suppress isolated spurious responses
saliency = gaussian_filter(saliency[0, 0], sigma=3.)
In [ ]:
def get_saliency_image(self, image_fname):
    # scale the full frame to the filter resolution, run the convolutional
    # saliency network over it and smooth the resulting saliency map
    image, image_filtersize, targetsize = util.preprocess_image(
        image_fname, config.filtersize)
    saliency = self.convolution_function(
        image_filtersize.reshape((1, 1, image_filtersize.shape[0],
                                  image_filtersize.shape[1])))
    saliency = gaussian_filter(saliency[0, 0], sigma=3.)
    return saliency, image

def detect_tags(self, image_path, saliency_threshold=0.5):
    # threshold the saliency map, extract candidate positions, scale them back
    # to image coordinates and crop the corresponding regions of interest
    saliency, image = self.get_saliency_image(image_path)
    candidates = util.get_candidates(saliency, saliency_threshold)
    saliencies = util.extract_saliencies(candidates, saliency)
    candidates_img = util.scale_candidates(candidates, saliency)
    rois, mask = util.extract_rois(candidates_img, image)
    return saliencies[mask], candidates_img, rois
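The two functions above are written as methods (they take self and expect a convolution_function attribute); a hypothetical minimal wrapper to exercise them:
In [ ]:
class TagDetectorSketch:
    # hypothetical holder for the methods above; the real project may structure this differently
    def __init__(self, convolution_function):
        self.convolution_function = convolution_function

    get_saliency_image = get_saliency_image
    detect_tags = detect_tags

detector = TagDetectorSketch(convolution_function)
saliencies_det, candidates_det, rois_det = detector.detect_tags(join(image_path, imfile))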
In [216]:
_ = visualization.plot_saliency_image(image_filtersize, saliency, config.filtersize, figsize=(12, 6))
#plt.savefig('saliency.png', dpi=300, bbox_inches='tight')
In [217]:
candidates = util.get_candidates(saliency, saliency_threshold, dist=config.filtersize[0] // 4)
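util.get_candidates is essentially non-maximum suppression on the saliency map; a sketch using the peak_local_max import from the first cell (a guess at the implementation):
In [ ]:
def get_candidates_sketch(saliency, threshold, dist=config.filtersize[0] // 4):
    # keep local maxima that are above the threshold and at least `dist` pixels apart
    return peak_local_max(saliency, min_distance=dist, threshold_abs=threshold)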
In [218]:
saliencies = util.extract_saliencies(candidates, saliency)
In [219]:
candidates_img = util.scale_candidates(candidates, saliency)
In [220]:
rois, mask = util.extract_rois(candidates_img, image)
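util.extract_rois crops a fixed-size window around every candidate; a sketch under the assumption that candidates too close to the image border are dropped and reported via the mask:
In [ ]:
def extract_rois_sketch(candidates, image, roi_size=config.filtersize[0]):
    half = roi_size // 2
    rois, mask = [], []
    for r, c in np.asarray(candidates, dtype=int):
        # a candidate is kept only if its full window lies inside the image
        inside = half <= r < image.shape[0] - half and half <= c < image.shape[1] - half
        mask.append(inside)
        if inside:
            rois.append(image[r - half:r + half, c - half:c + half])
    return np.asarray(rois), np.asarray(mask, dtype=bool)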
In [221]:
fig = visualization.plot_sample_images(rois, saliencies)
In [224]:
plt.figure(figsize=(16, 16))
_ = plt.imshow(visualization.get_roi_overlay(candidates_img, image))