In [1]:
%run setup.ipynb
%matplotlib inline
In [2]:
@functools.lru_cache(maxsize=None)
def compute_information_gain(start_index=0, stop_index=200000):
# # load the data on cluster assignments
# import pickle
# with open('../data/clust_dict.pickle', mode='rb') as f:
# clust_dict = pickle.load(f)
# # define the classes - 'WT' means any susceptible
# classes = ['WT'] + sorted(clust_dict)
# n_classes = len(classes)
#let's try loading the cluster assignments another way
# use the network membership to define haplotype groups
vgsc_clusters = np.load('../data/median_joining_network_membership.npy').astype('U')
clust_dict = {(l if l else 'wt'): set(np.nonzero(vgsc_clusters == l)[0])
for l in np.unique(vgsc_clusters)}
# merge the "other resistant" groups
clust_dict['other_resistant'] = clust_dict['FX'] | clust_dict['SX']
del clust_dict['FX']
del clust_dict['SX']
#define classes ??
classes = sorted(clust_dict)
n_classes = len(classes)
# load haplotypes
callset_haps = np.load('../data/haps_phase1.npz')
haps = allel.HaplotypeArray(callset_haps['haplotypes'])[start_index:stop_index]
pos = allel.SortedIndex(callset_haps['POS'])[start_index:stop_index]
n_haps = haps.shape[1]
# set up target attribute
target_attr = np.zeros(n_haps, dtype=int)
for i, cls in enumerate(classes):
if i > 0:
hap_indices = sorted(clust_dict[cls])
target_attr[hap_indices] = i
# compute entropy for the target attribute
target_freqs = np.bincount(target_attr, minlength=n_classes) / target_attr.shape[0]
target_entropy = scipy.stats.entropy(target_freqs)
# setup output array
gain = np.zeros(pos.shape[0])
# work through the variants one by one
for i in range(pos.shape[0]):
# pull out the attribute data
attr = haps[i]
# split on attribute value and compute entropies for each split
split_entropy = 0
for v in 0, 1, 2:
split = target_attr[attr == v]
if split.shape[0] == 0:
continue
split_freqs = np.bincount(split, minlength=len(classes)) / split.shape[0]
split_entropy += (split.shape[0] / n_haps) * scipy.stats.entropy(split_freqs)
# compute and store gain
gain[i] = target_entropy - split_entropy
return gain, pos, haps, target_attr
In [3]:
def plot_information_gain(start=None, stop=None, ax=None):
fig = None
if ax is None:
fig, ax = plt.subplots(figsize=(10, 3))
sns.despine(ax=ax, offset=5)
gain, pos, _, _ = compute_information_gain()
ax.plot(pos, gain, marker='o', linestyle=' ', mfc='none', mec='k', markersize=2)
ax.set_xlabel('Position (bp)')
ax.set_ylabel('Information gain')
ax.set_xlim(start, stop)
ax.set_ylim(bottom=0)
if fig:
fig.tight_layout()
plot_information_gain()
In [4]:
gene_labels
Out[4]:
In [5]:
sns.set_style('white')
sns.set_style('ticks')
In [6]:
def fig_information_gain(start=int(1.3e6), stop=int(3.5e6)):
# setup figure
fig = plt.figure(figsize=(8, 3), dpi=150)
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[6, 1])
# plot information gain
ax = fig.add_subplot(gs[0])
sns.despine(ax=ax, offset=5, bottom=True)
plot_information_gain(start, stop, ax)
ax.axvline(region_vgsc.start, zorder=-20, color='#aaaaaa', linestyle='--')
ax.axvline(region_vgsc.end, zorder=-20, color='#aaaaaa', linestyle='--')
ax.set_xticks([])
ax.set_xlabel('')
# plot genes
ax = fig.add_subplot(gs[1])
sns.despine(ax=ax, offset=5)
plot_genes(phase1_ar3.genome, phase1_ar3.geneset_agamp42_fn, chrom='2L', start=start, stop=stop, ax=ax,
label=True, labels={'AGAP004707': 'Vgsc'}, label_unnamed=False)
ax.set_xlim(start, stop)
ax.set_xlabel('Chromosome 2L position (bp)')
ax.set_ylabel('Genes', rotation=0)
fig.suptitle('a', fontweight='bold', x=0, y=1)
fig.tight_layout()
fig.savefig('../artwork/info_gain.png', dpi=150, bbox_inches='tight')
fig_information_gain()
In [7]:
import sklearn.tree
In [8]:
@functools.lru_cache(maxsize=None)
def eval_trees(start, stop, max_depths=tuple(range(2, 11)), min_samples_leaf=5, criterion='entropy', n_splits=10, random_state=42):
# setup data
gain, pos, haps, target = compute_information_gain()
loc = pos.locate_range(start, stop)
data = haps[loc].T
# setup cross-validation
skf = sklearn.model_selection.StratifiedKFold(n_splits=n_splits, random_state=random_state)
# setup outputs
scores = []
n_features = []
depths = []
# interate with increasing maximum depth
for max_depth in max_depths:
# setup the classifier
clf = sklearn.tree.DecisionTreeClassifier(criterion=criterion, max_depth=max_depth, min_samples_leaf=min_samples_leaf, random_state=random_state)
# do cross-validation
for train_index, test_index in skf.split(data, target):
# split the data
data_train, data_test = data[train_index], data[test_index]
target_train, target_test = target[train_index], target[test_index]
# fit the model
clf.fit(data_train, target_train)
# score the model
scores.append(clf.score(data_test, target_test))
# store depth and number of features
depths.append(max_depth)
n_features.append(np.count_nonzero(clf.feature_importances_))
assert np.count_nonzero(clf.feature_importances_) == len(set(clf.tree_.feature[clf.tree_.feature >= 0]))
scores = np.array(scores)
n_features = np.array(n_features)
depths = np.array(depths)
return scores, n_features, depths
In [9]:
def repeat_eval_trees(start, stop, max_depths=tuple(range(2, 11)), min_samples_leaf=5, criterion='entropy', n_splits=10, n_reps=10):
scores = []
n_features = []
depths = []
for i in range(n_reps):
s, f, d = eval_trees(start, stop, max_depths=max_depths, min_samples_leaf=min_samples_leaf, criterion=criterion, n_splits=n_splits, random_state=i)
scores.extend(s)
n_features.extend(f)
depths.extend(d)
scores = np.array(scores)
n_features = np.array(n_features)
depths = np.array(depths)
return scores, n_features, depths
In [10]:
#sns.set_style('darkgrid')
In [11]:
def plot_cv_score(buffer, ax=None, **kwargs):
if ax is None:
fig, ax = plt.subplots()
scores, n_features, depths = repeat_eval_trees(start=region_vgsc.start - buffer, stop=region_vgsc.end + buffer, **kwargs)
ax.plot(n_features, scores, marker='o', mfc='none', mec='k', linestyle=' ', markersize=4)
ax.set_xlabel('No. SNPs in decision tree')
ax.set_ylabel('Cross-validation score')
ax.set_xlim(0, 30)
ax.set_ylim(top=1)
In [12]:
plot_cv_score(20000, criterion='entropy')
In [13]:
plot_cv_score(20000, criterion='gini')
In [14]:
def fig_cv_score(buffer=20000):
fig, axs = plt.subplots(1, 2, sharey=True, figsize=(8, 3), dpi=150)
ax = axs[0]
plot_cv_score(buffer, criterion='entropy', ax=ax)
ax.set_title('LD3')
ax.grid(axis='both')
ax = axs[1]
plot_cv_score(buffer, criterion='gini', ax=ax)
ax.set_title('CART')
ax.set_ylabel('')
ax.grid(axis='both')
fig.suptitle('b', fontweight='bold', x=0, y=1)
fig.tight_layout()
fig.savefig('../artwork/tree_cv.png', bbox_inches='tight', dpi=150)
fig_cv_score()
In [ ]: