Feature Extraction

In this notebook, I will try to learn features of the radio images.


In [2]:
import sys

import keras.layers
import keras.models
import matplotlib.pyplot
import numpy
import scipy.ndimage.filters

sys.path.insert(1, '..')
import crowdastro.config
import crowdastro.data
import crowdastro.show

%matplotlib inline

NEIGHBOURHOOD_RADIUS = 20

In [4]:
input_shape = (1, NEIGHBOURHOOD_RADIUS * 2, NEIGHBOURHOOD_RADIUS * 2)
n_conv_filters = 100
conv_width = 4
hidden_dim = 256

In [17]:
model = keras.models.Sequential()

encoder = keras.models.Sequential()
encoder.add(keras.layers.Convolution2D(n_conv_filters, conv_width, conv_width, border_mode='valid', input_shape=input_shape))
encoder.add(keras.layers.Activation('relu'))
encoder.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
encoder.add(keras.layers.Dropout(0.25))
encoder.add(keras.layers.Flatten())
encoder.add(keras.layers.Dense(hidden_dim))

decoder = keras.models.Sequential([
        keras.layers.Dense(hidden_dim * 2, input_shape=(hidden_dim,)),
        keras.layers.Dense(input_shape[1] * input_shape[2])
])

ae = keras.layers.AutoEncoder(encoder=encoder, decoder=decoder, output_reconstruction=True)
model.add(ae)
model.compile(optimizer='sgd', loss='mse')

In [23]:
# Get some images, get some neighbourhoods, and hence find a training set.

def potential_hosts(subject, sigma=2, threshold=0.05):
    ir = crowdastro.data.get_ir(subject)

    neighborhood = numpy.ones((10, 10))
    blurred_ir = scipy.ndimage.filters.gaussian_filter(ir, sigma) > threshold
    local_max = scipy.ndimage.filters.maximum_filter(blurred_ir, footprint=neighborhood) == blurred_ir
    region_labels, n_labels = scipy.ndimage.measurements.label(local_max)
    maxima = numpy.array(
            [numpy.array((region_labels == i + 1).nonzero()).T.mean(axis=0)
             for i in range(n_labels)]
    )
    maxima = maxima[numpy.logical_and(maxima[:, 1] != 0, maxima[:, 1] != 499)]
    return maxima

training_inputs = []
training_outputs = []
n = 150
for subject in crowdastro.data.db.radio_subjects.find({'metadata.survey': 'atlas'}).limit(n):
    hosts = potential_hosts(subject)
    
    ir = crowdastro.data.get_ir(subject)
    ir = numpy.pad(ir, NEIGHBOURHOOD_RADIUS, mode='constant')
    
    for host_y, host_x in hosts:
            ir_neighbourhood = ir[int(host_x) : int(host_x + 2 * NEIGHBOURHOOD_RADIUS),
                                  int(host_y) : int(host_y + 2 * NEIGHBOURHOOD_RADIUS)]
            training_inputs.append(ir_neighbourhood)
            training_outputs.append(ir_neighbourhood.flatten())

training_inputs = numpy.array(training_inputs)
training_outputs = numpy.array(training_outputs)

print('Found {} training inputs.'.format(len(training_inputs)))


Found 7092 training inputs.
K:\Languages\Anaconda3\lib\site-packages\astropy\io\fits\util.py:578: UserWarning: Could not find appropriate MS Visual C Runtime library or library is corrupt/misconfigured; cannot determine whether your file object was opened in append mode.  Please consider using a file object opened in write mode instead.
  'Could not find appropriate MS Visual C Runtime '

In [24]:
training_inputs = training_inputs.reshape((training_inputs.shape[0], 1, training_inputs.shape[1], training_inputs.shape[2]))

In [28]:
import IPython.display

for i in range(200):
    print(i)
    model.fit(training_inputs, training_outputs, nb_epoch=1)
#     matplotlib.pyplot.figure(figsize=(15, 15))
    for i, kernel in enumerate(encoder.get_weights()[0]):
        kernel = kernel[0]
        matplotlib.pyplot.subplot(10, 10, i + 1)
        matplotlib.pyplot.axis('off')
        matplotlib.pyplot.imshow(kernel, cmap='gray')
        matplotlib.pyplot.subplots_adjust(hspace=0, wspace=0)
    IPython.display.clear_output(wait=True)
    IPython.display.display(matplotlib.pyplot.gcf())


166
Epoch 1/1
 384/7092 [>.............................] - ETA: 115s - loss: 0.0636
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-29-99a6238ac12c> in <module>()
      3 for i in range(200):
      4     print(i)
----> 5     model.fit(training_inputs, training_outputs, nb_epoch=1)
      6 #     matplotlib.pyplot.figure(figsize=(15, 15))
      7     for i, kernel in enumerate(encoder.get_weights()[0]):

K:\Languages\Anaconda3\lib\site-packages\keras\models.py in fit(self, X, y, batch_size, nb_epoch, verbose, callbacks, validation_split, validation_data, shuffle, show_accuracy, class_weight, sample_weight)
    644                          verbose=verbose, callbacks=callbacks,
    645                          val_f=val_f, val_ins=val_ins,
--> 646                          shuffle=shuffle, metrics=metrics)
    647 
    648     def predict(self, X, batch_size=128, verbose=0):

K:\Languages\Anaconda3\lib\site-packages\keras\models.py in _fit(self, f, ins, out_labels, batch_size, nb_epoch, verbose, callbacks, val_f, val_ins, shuffle, metrics)
    278                 batch_logs['size'] = len(batch_ids)
    279                 callbacks.on_batch_begin(batch_index, batch_logs)
--> 280                 outs = f(ins_batch)
    281                 if type(outs) != list:
    282                     outs = [outs]

K:\Languages\Anaconda3\lib\site-packages\keras\backend\theano_backend.py in __call__(self, inputs)
    382     def __call__(self, inputs):
    383         assert type(inputs) in {list, tuple}
--> 384         return self.function(*inputs)
    385 
    386 

K:\Languages\Anaconda3\lib\site-packages\theano\compile\function_module.py in __call__(self, *args, **kwargs)
    857         t0_fn = time.time()
    858         try:
--> 859             outputs = self.fn()
    860         except Exception:
    861             if hasattr(self.fn, 'position_of_error'):

K:\Languages\Anaconda3\lib\site-packages\theano\gof\op.py in rval(p, i, o, n)
    910             # default arguments are stored in the closure of `rval`
    911             def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 912                 r = p(n, [x[0] for x in i], o)
    913                 for o in node.outputs:
    914                     compute_map[o][0] = True

KeyboardInterrupt: 
Traceback (most recent call last):

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2232, in print_figure
    **kwargs)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py", line 527, in print_png
    FigureCanvasAgg.draw(self)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py", line 474, in draw
    self.figure.draw(self.renderer)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\artist.py", line 61, in draw_wrapper
    draw(artist, renderer, *args, **kwargs)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\figure.py", line 1159, in draw
    func(*args)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\artist.py", line 61, in draw_wrapper
    draw(artist, renderer, *args, **kwargs)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\axes\_base.py", line 2324, in draw
    a.draw(renderer)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\artist.py", line 61, in draw_wrapper
    draw(artist, renderer, *args, **kwargs)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\image.py", line 394, in draw
    renderer.draw_image(gc, l, b, im)

ValueError: object too deep for desired array


During handling of the above exception, another exception occurred:


Traceback (most recent call last):

  File "K:\Languages\Anaconda3\lib\site-packages\ipykernel\ipkernel.py", line 175, in do_execute
    shell.run_cell(code, store_history=store_history, silent=silent)

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2908, in run_cell
    self.events.trigger('post_execute')

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\events.py", line 74, in trigger
    func(*args, **kwargs)

  File "K:\Languages\Anaconda3\lib\site-packages\ipykernel\pylab\backend_inline.py", line 113, in flush_figures
    return show(True)

  File "K:\Languages\Anaconda3\lib\site-packages\ipykernel\pylab\backend_inline.py", line 36, in show
    display(figure_manager.canvas.figure)

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\display.py", line 159, in display
    format_dict, md_dict = format(obj, include=include, exclude=exclude)

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\formatters.py", line 175, in format
    data = formatter(obj)

  File "<decorator-gen-9>", line 2, in __call__

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\formatters.py", line 220, in catch_format_error
    r = method(self, *args, **kwargs)

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\formatters.py", line 337, in __call__
    return printer(obj)

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\pylabtools.py", line 207, in <lambda>
    png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))

  File "K:\Languages\Anaconda3\lib\site-packages\IPython\core\pylabtools.py", line 117, in print_figure
    fig.canvas.print_figure(bytes_io, **kw)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2237, in print_figure
    self.figure.dpi = origDPI

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\figure.py", line 410, in _set_dpi
    self.dpi_scale_trans.clear().scale(dpi, dpi)

  File "K:\Languages\Anaconda3\lib\site-packages\matplotlib\transforms.py", line 1966, in scale
    self._mtx = np.dot(scale_mtx, self._mtx)

KeyboardInterrupt

In [29]:
print([i.shape for i in encoder.get_weights()])
matplotlib.pyplot.figure(figsize=(15, 15))
for i, kernel in enumerate(encoder.get_weights()[0]):
    kernel = kernel[0]
    matplotlib.pyplot.subplot(10, 10, i + 1)
    matplotlib.pyplot.axis('off')
    matplotlib.pyplot.imshow(kernel, cmap='gray')
    matplotlib.pyplot.subplots_adjust(hspace=0, wspace=0)
matplotlib.pyplot.show()


[(100, 1, 8, 8), (100,), (40000, 256), (256,)]
ERROR! Session/line number was not unique in database. History logging moved to new session 99

In [44]:
# crowdastro.show.image(training_inputs[:1, 0])
# matplotlib.pyplot.show()
encoder.compile(optimizer='sgd', loss='mse')
im = encoder.predict(training_inputs[:1])
matplotlib.pyplot.imshow(training_inputs[:1].reshape(48, 48), cmap='gray')
matplotlib.pyplot.show()
print(im.shape)
matplotlib.pyplot.imshow(im.reshape(16, 16), cmap='gray')
matplotlib.pyplot.show()


(1, 256)

In [45]:
model.load_weights('features.h5')
model.save_weights('features.h5')


[WARNING] features.h5 already exists - overwrite? [y/n]n

In [49]:
encoder.save_weights('features_encoder.h5')

In [ ]: