Classification

In this notebook, we will train and test a classifier. We will test both logistic regression (LR) and random forests (RF).

Generate a list of objects labelled by RGZ and Norris

Let's start by getting a list of all radio objects with both RGZ and Norris labels. For any given set of Franzen objects with the same ID, only the first component will be part of RGZ, so we will only look at first components. Norris et al. may have labelled the second (or third, etc.) component, so we need to associate each Norris label with the primary component.

The output of this section is a list of table keys for objects with a Norris and RGZ label.


In [3]:
import sklearn.model_selection, numpy, astropy.io.ascii as asc

In [4]:
table = asc.read('/Users/alger/data/Crowdastro/one-table-to-rule-them-all.tbl')
# clean = numpy.array(asc.read('clean-atlas.tbl')['Clean']).astype(bool)
# clean.shape

In [5]:
primary_component_to_norris_swire = {}
primary_component_to_rgz_swire = {}
# Only includes primary components (though labels may come from any component).
primary_components_with_both_labels = set()
for row in table:# table[clean]
    primary_component_id = row['Primary Component ID (RGZ)']
    if not primary_component_id:
        # If something does not have a primary component ID, it cannot have an RGZ label, so we don't care about it!
        continue
    
    norris_swire = row['Source SWIRE (Norris)']
    rgz_swire = row['Source SWIRE Name (RGZ)']
    
    if not norris_swire or not rgz_swire or not norris_swire.startswith('SWIRE'):
        continue
    
    primary_component_to_norris_swire[primary_component_id] = norris_swire
    primary_component_to_rgz_swire[primary_component_id] = rgz_swire
    primary_components_with_both_labels.add(primary_component_id)

print(len(primary_components_with_both_labels))

keys = numpy.array([r['Key'] for r in table
        if r['Component Zooniverse ID (RGZ)'] and
        r['Component ID (Franzen)'] in primary_components_with_both_labels], dtype=int)
print(', '.join(table[keys]['Component Zooniverse ID (RGZ)']))


477
ARG0003rb2, ARG0003rfr, ARG0003r8s, ARG0003r2j, ARG0003raz, ARG0003ro4, ARG0003r8e, ARG0003r3w, ARG0003r55, ARG0003rj2, ARG0003rj6, ARG0003r8r, ARG0003r6i, ARG0003sky, ARG0003ref, ARG0003rbl, ARG0003r3k, ARG0003r88, ARG0003r38, ARG0003r29, ARG0003r93, ARG0003rbv, ARG0003r9h, ARG0003rgq, ARG0003r8u, ARG0003rbk, ARG0003rul, ARG0003rfd, ARG0003r3a, ARG0003rie, ARG0003r4g, ARG0003r7c, ARG0003r5o, ARG0003ra5, ARG0003rbe, ARG0003r2i, ARG0003r5p, ARG0003rdd, ARG0003r92, ARG0003r4w, ARG0003r7p, ARG0003r24, ARG0003r7n, ARG0003rcv, ARG0003raf, ARG0003r42, ARG0003ras, ARG0003rb7, ARG0003r4z, ARG0003r7w, ARG0003r9b, ARG0003r2a, ARG0003r17, ARG0003rea, ARG0003r3i, ARG0003r3q, ARG0003r4n, ARG0003r2f, ARG0003r5j, ARG0003ra1, ARG0003r6c, ARG0003r22, ARG0003rcf, ARG0003rfi, ARG0003r87, ARG0003r8b, ARG0003rbw, ARG0003r4m, ARG0003r5r, ARG0003r52, ARG0003r54, ARG0003r5x, ARG0003rfa, ARG0003r7j, ARG0003rdg, ARG0003r4r, ARG0003r8n, ARG0003r6v, ARG0003r3s, ARG0003r64, ARG0003r8z, ARG0003rce, ARG0003r6g, ARG0003r6h, ARG0003r5u, ARG0003r2z, ARG0003r2d, ARG0003r4s, ARG0003raw, ARG0003r2h, ARG0003rcb, ARG0003r3j, ARG0003ri1, ARG0003r8o, ARG0003r97, ARG0003r7i, ARG0003rg0, ARG0003r65, ARG0003r6o, ARG0003r3f, ARG0003rvy, ARG0003rdf, ARG0003r72, ARG0003rec, ARG0003r45, ARG0003raq, ARG0003r76, ARG0003r75, ARG0003r7a, ARG0003r4f, ARG0003r4a, ARG0003r8p, ARG0003rdv, ARG0003r1e, ARG0003rc0, ARG0003r8l, ARG0003r62, ARG0003rcg, ARG0003rdt, ARG0003r49, ARG0003r8m, ARG0003r8k, ARG0003r7e, ARG0003r6m, ARG0003r8w, ARG0003rci, ARG0003r2n, ARG0003r6s, ARG0003r1z, ARG0003r5g, ARG0003r6w, ARG0003r6q, ARG0003r1k, ARG0003r27, ARG0003r6a, ARG0003r1r, ARG0003r8c, ARG0003ra6, ARG0003r7s, ARG0003rak, ARG0003rb4, ARG0003r9a, ARG0003r21, ARG0003rbi, ARG0003rk0, ARG0003rdm, ARG0003rif, ARG0003r6n, ARG0003rbt, ARG0003r3o, ARG0003reo, ARG0003r96, ARG0003r4o, ARG0003rar, ARG0003r1p, ARG0003rmc, ARG0003r3e, ARG0003r5v, ARG0003r9n, ARG0003r7q, ARG0003rq1, ARG0003r5y, ARG0003rex, ARG0003r2y, ARG0003r5c, ARG0003r4p, ARG0003rfc, ARG0003rd3, ARG0003r1o, ARG0003r9y, ARG0003r31, ARG0003r5b, ARG0003r9r, ARG0003rd0, ARG0003rfh, ARG0003req, ARG0003rax, ARG0003r8y, ARG0003rdu, ARG0003r1l, ARG0003rbd, ARG0003r8q, ARG0003r1j, ARG0003r6j, ARG0003r9c, ARG0003r2b, ARG0003r94, ARG0003r4j, ARG0003r9z, ARG0003r6e, ARG0003roz, ARG0003rwe, ARG0003rap, ARG0003r9p, ARG0003s0y, ARG0003r81, ARG0003r3m, ARG0003r5t, ARG0003reu, ARG0003rik, ARG0003rlt, ARG0003r73, ARG0003r26, ARG0003r51, ARG0003r56, ARG0003r74, ARG0003r36, ARG0003r98, ARG0003rb9, ARG0003re1, ARG0003r84, ARG0003rho, ARG0003r6t, ARG0003rh3, ARG0003r2s, ARG0003rac, ARG0003rlf, ARG0003rbn, ARG0003r5q, ARG0003rbj, ARG0003r7z, ARG0003r3r, ARG0003r9q, ARG0003rki, ARG0003rhq, ARG0003r6x, ARG0003rwy, ARG0003r1f, ARG0003rvu, ARG0003rro, ARG0003rae, ARG0003rm1, ARG0003red, ARG0003rs8, ARG0003r7r, ARG0003rm0, ARG0003rfo, ARG0003r4t, ARG0003rj5, ARG0003riy, ARG0003rjn, ARG0003r3v, ARG0003r4x, ARG0003ra4, ARG0003rbs, ARG0003rby, ARG0003r8g, ARG0003ryh, ARG0003r9v, ARG0003rs0, ARG0003r41, ARG0003r43, ARG0003rbc, ARG0003rcn, ARG0003r3x, ARG0003rhx, ARG0003rbo, ARG0003r5s, ARG0003rlp, ARG0003rju, ARG0003rdl, ARG0003r3b, ARG0003r6l, ARG0003r71, ARG0003r7f, ARG0003rsi, ARG0003r5d, ARG0003rn4, ARG0003r9o, ARG0003r2m, ARG0003s24, ARG0003r4e, ARG0003rk2, ARG0003r4u, ARG0003r3h, ARG0003r2u, ARG0003rdi, ARG0003rbp, ARG0003rcr, ARG0003r3l, ARG0003rl8, ARG0003r95, ARG0003rs2, ARG0003rbq, ARG0003rkz, ARG0003rve, ARG0003r5w, ARG0003sgb, ARG0003s3k, ARG0003rqy, ARG0003rko, ARG0003ra0, ARG0003rjm, ARG0003rtj, ARG0003r5m, ARG0003rb8, ARG0003r57, ARG0003s8d, ARG0003se7, ARG0003rd4, ARG0003rd1, ARG0003rsz, ARG0003r1h, ARG0003rb6, ARG0003rgw, ARG0003r40, ARG0003r83, ARG0003rpe, ARG0003rq0, ARG0003r7v, ARG0003riw, ARG0003r30, ARG0003rsm, ARG0003rf3, ARG0003r7m, ARG0003r8a, ARG0003r6b, ARG0003ral, ARG0003r23, ARG0003sc0, ARG0003sms, ARG0003rcp, ARG0003rl0, ARG0003rk5, ARG0003rot, ARG0003rf9, ARG0003r6u, ARG0003rfq, ARG0003r1m, ARG0003r8h, ARG0003rgz, ARG0003rus, ARG0003riz, ARG0003r8j, ARG0003rba, ARG0003r8i, ARG0003r8v, ARG0003r47, ARG0003rjk, ARG0003r6p, ARG0003rjd, ARG0003rqp, ARG0003rtt, ARG0003r35, ARG0003rfe, ARG0003rcy, ARG0003r9w, ARG0003r48, ARG0003r1q, ARG0003rb1, ARG0003rmo, ARG0003rcd, ARG0003r8x, ARG0003s1f, ARG0003rg2, ARG0003r2o, ARG0003rd2, ARG0003rgh, ARG0003r4h, ARG0003rm8, ARG0003sw7, ARG0003r3g, ARG0003ran, ARG0003rqj, ARG0003rcz, ARG0003rkt, ARG0003rii, ARG0003rmi, ARG0003r3p, ARG0003rgi, ARG0003r69, ARG0003r6k, ARG0003rfx, ARG0003rjo, ARG0003r9f, ARG0003r9i, ARG0003rel, ARG0003raj, ARG0003re8, ARG0003r80, ARG0003r6y, ARG0003rrk, ARG0003rme, ARG0003r5h, ARG0003r6d, ARG0003r5l, ARG0003rf1, ARG0003r7l, ARG0003rc7, ARG0003r2v, ARG0003rhk, ARG0003rw0, ARG0003rcx, ARG0003rah, ARG0003r3t, ARG0003r9u, ARG0003r2e, ARG0003rdh, ARG0003r2p, ARG0003r18, ARG0003rng, ARG0003r44, ARG0003r4q, ARG0003r4b, ARG0003r3u, ARG0003r90, ARG0003rab, ARG0003r8t, ARG0003r1t, ARG0003rtx, ARG0003r1s, ARG0003rey, ARG0003rbb, ARG0003r7h, ARG0003r2t, ARG0003r7x, ARG0003r4y, ARG0003rai, ARG0003r9k, ARG0003r32, ARG0003r2c, ARG0003rmn, ARG0003r8d, ARG0003rco, ARG0003ren, ARG0003r34, ARG0003r85, ARG0003rcq, ARG0003r79, ARG0003r78, ARG0003reb, ARG0003r67, ARG0003r53, ARG0003rfw, ARG0003r5e, ARG0003ra3, ARG0003rg7, ARG0003rc6, ARG0003r89, ARG0003r4l, ARG0003rk8, ARG0003r39, ARG0003rgg, ARG0003rgb, ARG0003rc8, ARG0003r3z, ARG0003r5z, ARG0003rd9, ARG0003r5f, ARG0003r1n, ARG0003r1y, ARG0003r50, ARG0003r9d, ARG0003r1d, ARG0003rbr, ARG0003ra9, ARG0003rbx, ARG0003ra7, ARG0003r59, ARG0003r6f, ARG0003rdx, ARG0003r1c, ARG0003rc9, ARG0003r1a, ARG0003r1b, ARG0003r1i, ARG0003r1w, ARG0003r1x, ARG0003r25, ARG0003r2k, ARG0003r2x, ARG0003r33, ARG0003r3d, ARG0003r4d, ARG0003r4v, ARG0003r5k, ARG0003r68

Generate training/testing subsets

We want to use 5-fold cross-validation. We will do this cross-validation for the radio objects, not the IR objects we are actually classifying — this is to avoid overlapping image data; SWIRE objects associated with the same radio object break the independence assumption.


In [6]:
kf = sklearn.model_selection.KFold(n_splits=5, shuffle=True)
# List of lists of one-table keys.
sets = [[keys[s] for s in f] for f in kf.split(keys)]

Associate SWIRE objects with each set

These sets correspond to ATLAS objects; we want SWIRE objects to classify. We will convert each of these sets into sets of SWIRE indices.


In [8]:
import h5py
crowdastro_f = h5py.File('/Users/alger/data/Crowdastro/crowdastro-swire.h5', 'r')

In [9]:
import scipy.spatial
swire_coords = crowdastro_f['/swire/cdfs/numeric'][:, :2]
swire_tree = scipy.spatial.KDTree(swire_coords)
swire_names = crowdastro_f['/swire/cdfs/string'].value
name_to_crowdastro = {j.decode('ascii'):i for i, j in enumerate(swire_names)}

In [10]:
swire_sets = []
norris_label_sets = []
rgz_label_sets = []
for split in sets:
    swire_split = []
    for ts in split:
        # ts is a list of keys.
        locs = numpy.array([(i[0], i[1]) for i in table['Component RA (Franzen)', 'Component DEC (Franzen)'][ts]])
        nearby_swire_indices = sorted(set(numpy.concatenate(swire_tree.query_ball_point(locs, 1 / 60))))  # 1 arcmin
        swire_split.append(nearby_swire_indices)

    swire_sets.append(swire_split)

Generate Features for Each SWIRE Object

For each SWIRE object in crowdastro $\cap$ the training/testing sets, we need the distance to the closest radio object, the stellarity, the band 1 - 2 and band 2 - 3 magnitude differences, and an image. The distances and images are generated by crowdastro generate_training_data; the magnitude differences and stellarity can be pulled from the SWIRE catalogue available here.


In [14]:
training_f = h5py.File('/Users/alger/data/Crowdastro/training-swire.h5', 'r')
swire_distances = training_f['raw_features'][:, 8]
swire_images = training_f['raw_features'][:, 9:]
assert swire_images.shape[1] == 32 * 32

In [31]:
# AstroPy can't deal with a large file like the SWIRE file, so we have to do this line-by-line...
SPITZER_SENSITIVITIES = {
    36: 7.3,
    45: 9.7,
    58: 27.5,
    80: 32.5,
    24: 450,
}
headers = []
swire_features = numpy.zeros((len(swire_coords),
                              6 +  # Magnitude differences
                              1 +  # S_3.6
                              2 +  # Stellarities
                              1 +  # Distances
                              32 * 32  # Image
                             ))
for row_num, line in enumerate(open('/Users/alger/data/SWIRE/SWIRE3_CDFS_cat_IRAC24_21Dec05.tbl')):
    if line.startswith('\\'):
        continue
    
    if line.startswith('|') and not headers:
        headers.extend(map(str.strip, line.split('|')[1:-1]))
        lengths = list(map(len, headers))
        continue
    
    if line.startswith('|'):
        continue
    
    line = dict(zip(headers, line.split()))
    
    name = line['object']
    if name not in name_to_crowdastro:
        continue  # Skip non-crowdastro SWIRE.

    crowdastro_index = name_to_crowdastro[name]

    fluxes = []
    for s in [36, 45, 58, 80]:
        aps = []
        for ap in range(1, 6):
            v = line['flux_ap{}_{}'.format(ap, s)]
            try:
                v = float(v)
                if v != -99.0:
                    aps.append(v)
            except:
                pass
        
        if aps:
            fluxes.append(numpy.mean(aps))
        else:
            fluxes.append(SPITZER_SENSITIVITIES[s])  # 5 sigma is an upper-bound for flux in each band.
    mags = [numpy.log10(s) for s in fluxes]
    mag_diffs = [mags[0] - mags[1], mags[0] - mags[2], mags[0] - mags[3],
                 mags[1] - mags[2], mags[1] - mags[3],
                 mags[2] - mags[3]]
    # Guaranteed a stellarity in the first two bands; not so much in the others.
    stellarities_ = [line['stell_{}'.format(s)] for s in [36, 45]]
    stellarities = []
    for s in stellarities_:
        if s != 'null' and s != '-9.00':
            stellarities.append(float(s))
        else:
            stellarities.append(float('nan'))
    # We will have nan stellarities - but we will replace those with the mean later.
    features = numpy.concatenate([
        mag_diffs,
        mags[:1],
        stellarities,
        [swire_distances[crowdastro_index]],
        swire_images[crowdastro_index],
    ])
    swire_features[crowdastro_index] = features

In [32]:
# Set nans to the mean.
for feature in range(swire_features.shape[1]):
    nan = numpy.isnan(swire_features[:, feature])
    swire_features[:, feature][nan] = swire_features[:, feature][~nan].mean()

In [37]:
import seaborn, matplotlib.pyplot as plt
plt.figure(figsize=(15, 8))
feature_names = ['$[3.6] - [4.5]$', '$[3.6] - [5.8]$', '$[3.6] - [8.0]$',
                 '$[4.5] - [5.8]$', '$[4.5] - [8.0]$',
                 '$[5.8] - [8.0]$', '$\log_{10} S_{3.6}$',
                 'Stellarity (3.6)', 'Stellarity (4.5)', 'Distance']
for feature in range(10):
    plt.subplot(2, 5, feature + 1)
    plt.title(feature_names[feature])
    seaborn.distplot(swire_features[:, feature])
plt.subplots_adjust(hspace=0.4)
plt.show()



In [34]:
# Normalise and centre the features.
swire_features -= swire_features.mean(axis=0)
swire_features /= swire_features.std(axis=0)

In [40]:
plt.figure(figsize=(15, 8))
feature_names = ['$[3.6] - [4.5]$', '$[3.6] - [5.8]$', '$[3.6] - [8.0]$',
                 '$[4.5] - [5.8]$', '$[4.5] - [8.0]$',
                 '$[5.8] - [8.0]$', '$\log_{10} S_{3.6}$',
                 'Stellarity (3.6)', 'Stellarity (4.5)', 'Distance']
for feature in range(10):
    plt.subplot(2, 5, feature + 1)
    plt.title(feature_names[feature])
    seaborn.distplot(swire_features[:, feature])
plt.subplots_adjust(hspace=0.4)
plt.show()


Generate labels for each SWIRE object


In [41]:
swire_norris_labels = numpy.zeros((len(swire_coords),), dtype=bool)
swire_rgz_labels = numpy.zeros((len(swire_coords),), dtype=bool)

In [42]:
import astropy.coordinates, re
for row in table:
    n = row['Source SWIRE (Norris)']
    if n and n.startswith('SWIRE'):
        if n in name_to_crowdastro:
            index = name_to_crowdastro[n]
            swire_norris_labels[index] = True
        else:
            m = re.match(r'SWIRE3_J(\d\d)(\d\d)(\d\d\.\d\d)(-\d\d)(\d\d)(\d\d\.\d)', n)
            ra, dec = ' '.join(m.groups()[:3]), ' '.join(m.groups()[3:])
            sc = astropy.coordinates.SkyCoord(ra=ra, dec=dec, unit=('hourangle', 'deg'))
            coord = (sc.ra.deg, sc.dec.deg)
            dist, index = swire_tree.query(coord)
            if dist < 5 / 60 / 60:
                swire_norris_labels[index] = True
    
    n = row['Source SWIRE Name (RGZ)']
    if n:
        index = name_to_crowdastro[n]
        swire_rgz_labels[index] = True

In [43]:
swire_norris_labels.sum(), swire_rgz_labels.sum()


Out[43]:
(507, 2211)

Experiment: Logistic regression

In this section, we will run logistic regression trained on RGZ/Norris and tested on RGZ/Norris.


In [44]:
import sklearn.linear_model, crowdastro.crowd.util, itertools
bas = {'RGZ': {'RGZ': [], 'Norris': []}, 'Norris': {'RGZ': [], 'Norris': []}}
for (train_name, train_labels), (test_name, test_labels) in itertools.product(
        [('RGZ', swire_rgz_labels), ('Norris', swire_norris_labels)], repeat=2):
    for train, test in swire_sets:
        X_train = swire_features[train, :]
        X_test = swire_features[test, :]
        T_train = train_labels[train]
        T_test = test_labels[test]
        lr = sklearn.linear_model.LogisticRegression(class_weight='balanced', penalty='l1')
        lr.fit(X_train, T_train)
        preds = lr.predict(X_test)
        ba = crowdastro.crowd.util.balanced_accuracy(T_test, preds)
        bas[train_name][test_name].append(ba)

In [45]:
for tr in ['RGZ', 'Norris']:
    for te in ['RGZ', 'Norris']:
        print('LR({:^6}) vs {:^6}: ({:.02f} +- {:.02f})%'.format(
            tr, te,
            numpy.mean(bas[tr][te]) * 100,
            numpy.std(bas[tr][te]) * 100))


LR( RGZ  ) vs  RGZ  : (93.34 +- 1.74)%
LR( RGZ  ) vs Norris: (94.14 +- 1.50)%
LR(Norris) vs  RGZ  : (93.92 +- 1.33)%
LR(Norris) vs Norris: (96.13 +- 1.63)%

In [46]:
import astropy.table
def plot_bas(bas):
    dx = []
    dy = []
    dh = []
    for train in bas:
        for test in bas[train]:
            for i in bas[train][test]:
                dx.append(train)
                dh.append(test)
                dy.append(i)
    data = astropy.table.Table(data=[dx, dy, dh], names=('train', 'BA', 'test')).to_pandas()
    ax = plt.figure(figsize=(15, 7))
    vp = seaborn.violinplot(
        scale='width',
        orient='v',
        x='train',
        y='BA',
        hue='test',
        data=data,
        split=True)
    plt.show()

plot_bas(bas)


Experiment: Random forests

Same as above, with random forests.


In [52]:
import sklearn.ensemble
bas_rf = {'RGZ': {'RGZ': [], 'Norris': []}, 'Norris': {'RGZ': [], 'Norris': []}}
for (train_name, train_labels), (test_name, test_labels) in itertools.product(
        [('RGZ', swire_rgz_labels), ('Norris', swire_norris_labels)], repeat=2):
    for train, test in swire_sets:
        X_train = swire_features[train, :]
        X_test = swire_features[test, :]
        T_train = train_labels[train]
        T_test = test_labels[test]
        rf = sklearn.ensemble.RandomForestClassifier(class_weight='balanced', criterion='entropy',
                                                     min_samples_leaf=40)
        rf.fit(X_train, T_train)
        preds = rf.predict(X_test)
        ba = crowdastro.crowd.util.balanced_accuracy(T_test, preds)
        bas_rf[train_name][test_name].append(ba)

In [53]:
for tr in ['RGZ', 'Norris']:
    for te in ['RGZ', 'Norris']:
        print('RF({:^6}) vs {:^6}: ({:.02f} +- {:.02f})%'.format(
            tr, te,
            numpy.mean(bas_rf[tr][te]) * 100,
            numpy.std(bas_rf[tr][te]) * 100))


RF( RGZ  ) vs  RGZ  : (94.24 +- 1.50)%
RF( RGZ  ) vs Norris: (96.46 +- 0.67)%
RF(Norris) vs  RGZ  : (89.42 +- 2.09)%
RF(Norris) vs Norris: (96.92 +- 0.92)%

In [54]:
plot_bas(bas_rf)


Experiment: Feature ablation

We will now repeat this experiment with different subsets of features to determine which features are most useful in making our predictions.

In particular, we expect distance to be the most important predictor by far — for compact objects, the centre of the Gaussian fit to the radio object will most likely be the location of the host galaxy.

We will use Norris labels for both training and testing.

We will test subsets where we remove:

  • Distance
  • Magnitude differences
  • Stellarity
  • Image
  • All combinations thereof.

In [55]:
def powerset(iterable: [1,2,3]) -> [(), (1,), (2,), (3,), (1,2), (1,3), (2,3), (1,2,3)]:
    s = list(iterable)
    return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s)+1))

In [56]:
featuresets = {
    'colour': swire_features[:, :2],
    'stellarity': swire_features[:, 2:4],
    'distance': swire_features[:, 4:5],
    'image': swire_features[:, 5:],
}

import collections
bas_ablation = collections.defaultdict(list)  # Maps features -> balanced accuracies.

for i in powerset(['distance', 'colour', 'stellarity', 'image']):
    if not i:
        continue
    
    print('Testing features:', ', '.join(i))
    for train, test in swire_sets:
        this_featureset = numpy.concatenate([featuresets[j] for j in i], axis=1)
        X_train = this_featureset[train, :]
        X_test = this_featureset[test, :]
        T_train = swire_norris_labels[train]
        T_test = swire_norris_labels[test]
        lr = sklearn.linear_model.LogisticRegression(class_weight='balanced', penalty='l1')
        lr.fit(X_train, T_train)
        preds = lr.predict(X_test)
        ba = crowdastro.crowd.util.balanced_accuracy(T_test, preds)
        bas_ablation[i].append(ba)


Testing features: distance
Testing features: colour
Testing features: stellarity
Testing features: image
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-56-c3cd0a0e9318> in <module>()
     21         T_test = swire_norris_labels[test]
     22         lr = sklearn.linear_model.LogisticRegression(class_weight='balanced', penalty='l1')
---> 23         lr.fit(X_train, T_train)
     24         preds = lr.predict(X_test)
     25         ba = crowdastro.crowd.util.balanced_accuracy(T_test, preds)

/usr/local/lib/python3.6/site-packages/sklearn/linear_model/logistic.py in fit(self, X, y, sample_weight)
   1184                 self.class_weight, self.penalty, self.dual, self.verbose,
   1185                 self.max_iter, self.tol, self.random_state,
-> 1186                 sample_weight=sample_weight)
   1187             self.n_iter_ = np.array([n_iter_])
   1188             return self

/usr/local/lib/python3.6/site-packages/sklearn/svm/base.py in _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight, penalty, dual, verbose, max_iter, tol, random_state, multi_class, loss, epsilon, sample_weight)
    910         X, y_ind, sp.isspmatrix(X), solver_type, tol, bias, C,
    911         class_weight_, max_iter, rnd.randint(np.iinfo('i').max),
--> 912         epsilon, sample_weight)
    913     # Regarding rnd.randint(..) in the above signature:
    914     # seed for srand in range [0..INT_MAX); due to limitations in Numpy

KeyboardInterrupt: 

In [ ]:
# for i in powerset(['distance', 'colour', 'stellarity', 'image']):
#     if not i:
#         continue

#     print('{:<40}: ({:.02f} +- {:.02f})%'.format(
#         ', '.join(i),
#         numpy.mean(bas_ablation[i]) * 100,
#         numpy.std(bas_ablation[i]) * 100))
ax = plt.figure(figsize=(15, 7))
vp = seaborn.violinplot(
    scale='width',
    orient='v',
    data=[bas_ablation[i] for i in sorted(bas_ablation.keys())])
vp.set_xticklabels(sorted([', '.join(k) for k in bas_ablation.keys()]), rotation='vertical')
plt.show()

Export

We will now export our sets, labels, and features.


In [58]:
import h5py

# Convert SWIRE/ATLAS sets into boolean arrays.
n_swire = len(swire_coords)
n_atlas = crowdastro_f['/atlas/cdfs/numeric'].shape[0]
swire_sets_train_bool = numpy.zeros((5, n_swire), dtype=bool)
swire_sets_test_bool = numpy.zeros((5, n_swire), dtype=bool)
for k in range(5):
    swire_sets_train_bool[k, swire_sets[k][0]] = True
    swire_sets_test_bool[k, swire_sets[k][1]] = True

with h5py.File('/Users/alger/data/Crowdastro/swire_11_05_17.h5', 'w') as f:
    f.create_dataset('features', data=swire_features)
    f.create_dataset('rgz_labels', data=swire_rgz_labels)
    f.create_dataset('norris_labels', data=swire_norris_labels)
    f.create_dataset('clean_swire_sets_train', data=swire_sets_train_bool)
    f.create_dataset('clean_swire_sets_test', data=swire_sets_test_bool)

In [ ]: