In [5]:
import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
from collections import Counter
import collections
import random
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import Bio
from Bio import SeqIO
import os
import concurrent.futures
import functools
from functools import partial
import math
import threading
import time
from random import shuffle
import pickle
import tempfile
import ntpath
import os.path

# k-mer size to use
k = 9

#
# NOTE!!!!!!!!!!!!!!!!
#
# We could reduce the problem space by taking the reverse complement and adding a
# bit to indicate strand, but not really: revcomp just doubles the space back up again.
#
# Also -- Build a recurrent network to predict sequences that come after a given kmer?
# Look at word2vec, dna2vec, bag of words, skip-gram
#

# Problem space
space = 5 ** k

def partition(n, step, coll):
    # Yield successive windows of length n, advancing by step; any trailing
    # partial window is dropped.
    for i in range(0, len(coll), step):
        if (i+n > len(coll)):
            break
        yield coll[i:i+n]
        
def get_kmers(k):
    return lambda sequence: partition(k, k, sequence)
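
# Quick illustrative check of the two helpers above (string is hypothetical):
# get_kmers(3) yields non-overlapping 3-mers and drops the trailing partial window.
assert list(get_kmers(3)("ABCDEFGHIJ")) == ["ABC", "DEF", "GHI"]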

def convert_nt(c):
    return {"N": 0, "A": 1, "C": 2, "T": 3, "G": 4}.get(c, 0)

def convert_nt_complement(c):
    return {"N": 0, "A": 3, "C": 4, "T": 1, "G": 2}.get(c, 0)

def convert_kmer_to_int(kmer):
    return int(''.join(str(x) for x in (map(convert_nt, kmer))), 5)

def convert_kmer_to_int_complement(kmer):
    return int(''.join(str(x) for x in reversed(list(map(convert_nt_complement, kmer)))), 5)

def convert_base5(n):
    return {"0": "N", "1": "A", "2": "C", "3": "T", "4": "G"}.get(n,"N")

def convert_to_kmer(kmer):
    # Note: np.base_repr drops leading zeros, so k-mers that start with "N"
    # come back shorter than k.
    return ''.join(map(convert_base5, str(np.base_repr(kmer, 5))))
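
# Illustrative round trip of the base-5 encoding (the example k-mer is made up).
# Per the caveat above, a k-mer starting with "N" would not round-trip, because
# its leading zero digit is dropped by np.base_repr.
_example = "ACTGACTGA"  # length k = 9
assert convert_to_kmer(convert_kmer_to_int(_example)) == _example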

# Not using sparse tensors anymore.
   
tf.logging.set_verbosity(tf.logging.INFO)

# Get all kmers, in order, with a sliding window of k (but sliding 1bp for each iteration up to k)
# Also get the reverse for all (note: the code below reverses the sequence; it does not
# reverse-complement it, even though convert_kmer_to_int_complement is available above)

def kmer_processor(seq,offset):
    return list(map(convert_kmer_to_int, get_kmers(k)(seq[offset:])))

def get_kmers_from_seq(sequence):
    kmers_from_seq = list()

    kp = functools.partial(kmer_processor, sequence)
    
    for i in map(kp, range(0,k)):
        kmers_from_seq.append(i)

    rev = sequence[::-1]
    kpr = functools.partial(kmer_processor, rev)
    
    for i in map(kpr, range(0,k)):
        kmers_from_seq.append(i)
            
#    for i in range(0,k):
#        kmers_from_seq.append(kmer_processor(sequence,i))
#    for i in range(0,k):
#        kmers_from_seq.append(kmer_processor(rev, i))
    return kmers_from_seq
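
# Sketch of the return shape on a toy 120 bp sequence (hypothetical input):
# get_kmers_from_seq yields 2*k frames (k offsets over the forward sequence plus
# k offsets over the reversed sequence), each a list of integer-encoded,
# non-overlapping k-mers.
_frames = get_kmers_from_seq("ACGT" * 30)
assert len(_frames) == 2 * k             # 18 frames for k = 9
assert len(_frames[0]) == (4 * 30) // k  # 13 non-overlapping 9-mers per frame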

data = list()

def load_fasta(filename):
    data = dict()
    file_base_name = ntpath.basename(filename)
    picklefilename = file_base_name + ".picklepickle"
    if os.path.isfile(picklefilename):
        print("Loading from pickle")
        data = pickle.load(open(picklefilename, "rb"))
    else:
        print("Pickle not found, generating k-mers: " + picklefilename)
        for seq_record in SeqIO.parse(filename, "fasta"):
            data.update({seq_record.id:
                         get_kmers_from_seq(seq_record.seq.upper())})
        pickle.dump(data, open(picklefilename, "wb"))
    return(data)
        
def get_kmers_from_file(filename):
    kmer_list = list()
    for seq_record in SeqIO.parse(filename, "fasta"):
        kmer_list.extend(get_kmers_from_seq(seq_record.seq.upper()))
    return set([item for sublist in kmer_list for item in sublist])

all_kmers = set()

# Still slow; the thread pool below only helps a little since the k-mer extraction is CPU-bound.

def find_all_kmers(directory):
    kmer_master_list = list()
    files = [directory + "/" + f for f in os.listdir(directory)]
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        for i in executor.map(get_kmers_from_file, files):
            kmer_master_list.extend(list(i))
            kmer_master_list = list(set(kmer_master_list))
            print("Total unique kmers: " + str(len(set(kmer_master_list))))
    return set(kmer_master_list)
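
# Plausible one-off step (an assumption, not shown in the original run) that would
# produce the "all_kmers.p" pickle reloaded in later cells:
# all_kmers = find_all_kmers("training-files/")
# pickle.dump(list(all_kmers), open("all_kmers.p", "wb"))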

def get_categories(directory):
    data = list()
    files = os.listdir(directory)
    for filename in files:
        for seq_record in SeqIO.parse(directory + "/" + filename, "fasta"):
            data.append(seq_record.id)
    data = sorted(list(set(data)))
    return(data)

In [ ]:
# Note: training_file_generator is defined in a later cell (In [7]); run that cell first.
filegen = training_file_generator("training-files/")
training_data = load_fasta(filegen())

In [ ]:
# Because this was run at work on a smaller sample of files....
# with open("all_kmers_subset.txt", "w") as f:
#     for s in all_kmers:
#         f.write(str(s) +"\n")

# Because this was run at work on a smaller sample of files....
all_kmers = list()
# with open("all_kmers_subset.txt", "r") as f:
#     for line in f:
#         all_kmers.append(int(line.strip()))

all_kmers = pickle.load(open("all_kmers.p", "rb"))

all_kmers = set(all_kmers)
len(all_kmers)
# len(data)

# all_kmers = set([item for sublist in data for item in sublist])
unused_kmers = set(range(0, space)) - all_kmers

kmer_dict = dict()
reverse_kmer_dict = dict()

for a, i in enumerate(all_kmers):
    kmer_dict[i] = a
    reverse_kmer_dict[a] = i
    
kmer_count = len(all_kmers)

[len(all_kmers), len(unused_kmers), space]

Out[19]:
[269132, 1683993, 1953125]

In [7]:
def training_file_generator(directory):
    files = [directory + "/" + f for f in os.listdir(directory)]
    random.shuffle(files)
    def gen():
        nonlocal files
        if (len(files) == 0):
            files = [directory + "/" + f for f in os.listdir(directory)]
            random.shuffle(files)
        return(files.pop())
    return gen

def gen_random_training_data(input_data, window_size):
    rname = random.choice(list(input_data.keys()))
    rdata = random.choice(input_data[rname])
    idx = random.randrange(window_size + 1, len(rdata) - window_size - 1)
    tdata = list()
    for i in range(idx - window_size - 1, idx + window_size):
        if (i < 0): continue
        if (i >= len(rdata)): break
        if type(rdata[idx]) == list: break
        if type(rdata[i]) == list: break
        tdata.append(kmer_dict[rdata[i]])
    return tdata, rname
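
# Sketch of a single example (assumes training_data has been loaded above):
# gen_random_training_data returns a window of 2*window_size + 1 kmer-dictionary
# indices centered on a random position, plus the FASTA record id used as the label.
# window, replicon_id = gen_random_training_data(training_data, 7)
# len(window)   -> 15
# replicon_id   -> e.g. 'Main', 'pSymA', or 'pSymB'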

# The current state is, each training batch is from a single FASTA file (strain, usually)
# This can be ok, as long as training batch is a large number
# Need to speed up reading of FASTA files though, maybe pyfaidx or something?

# Define the one-hot dictionary...
# Note: only the last, hard-coded mapping below is actually used; the first two
# definitions are overwritten. (replicons_list is built in a later cell.)

oh = dict()
a = 0
for i in replicons_list:
    oh[i] = tf.one_hot(a, len(replicons_list))
    a += 1

oh = dict()
a = 0
for i in replicons_list:
    oh[i] = a
    a += 1

oh = dict()
oh['Main'] = [1.0, 0.0, 0.0]
oh['pSymA'] = [0.0, 1.0, 0.0]
oh['pSymB'] = [0.0, 0.0, 1.0]


def generate_training_batch(data, batch_size, window_size):
    training_batch_data = list()
    while len(training_batch_data) < batch_size:
         training_batch_data.append(gen_random_training_data(data, 
                                                             window_size))
    return training_batch_data

def train_input_fn():
    rdata = generate_training_batch(training_data, 1, window_size)[0]
    return rdata[0], oh[rdata[1]]
    # return {"train_input": rdata[0]}, oh[rdata[1]]

In [ ]:
# filegen = training_file_generator("training-files/")

# training_data = load_fasta(filegen())
# training_data_backup = training_data

# len(training_data)

# gen_random_training_data(training_data, 7)

# generate_training_batch(training_data, 5, 7)

# random.choice(list(input_data.keys()))

train_input_fn()

In [ ]:
valid_examples[0]

In [ ]:
filegen = training_file_generator("training-files/")

# training_data = load_fasta(filegen())

batch_size = 1024
embedding_size = 128
window_size = 7

validation_set = generate_training_batch(training_data, 10000, window_size)
# validation_kmers = list(set([i[0] for i in validation_set]))
# del validation_set

# We pick a random validation set to sample nearest neighbors from. (Unlike the
# word2vec tutorial this comment was copied from, the examples here are drawn at
# random rather than restricted to low-ID, high-frequency words.)
valid_size = 1024
valid_examples = [i[0] for i in validation_set]
num_sampled = 256
del validation_set

learning_rate = 0.1

# Network Parameters
n_hidden_1 = 500
n_hidden_2 = 50

In [29]:
# https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/neural_network.py


Out[29]:
<tf.Tensor 'Softmax:0' shape=(1, 3) dtype=float32>

In [15]:
graph = tf.Graph()

batch_size = 1

with graph.as_default():
    # Load embedding (note: the kmers variable, embedding_placeholder, and
    # embedding_init below are left over from an earlier approach and are never
    # run; the embedding_lookup further down uses the numpy `embeddings` array
    # directly)
    kmers = tf.Variable(tf.constant(0.0, shape=[kmer_count, 128]),
                       trainable=False, name="kmers")

    # embeddings = np.load("embeddings_200000.npy")

    embedding_placeholder = tf.placeholder(tf.float32, [kmer_count, 128])
    embedding_init = kmers.assign(embeddings)

    # Input data.
    # Take 1 kmer and the 7 on each side of it
    # So for k=9, we are testing 135bp
    # So n = 15
    train_input = tf.placeholder(tf.int32, shape=[batch_size, 15])
    # Labels are fed as one-hot float vectors (see the `oh` dict above)
    train_label = tf.placeholder(tf.float32, shape=[batch_size, 3])
#    train_label_r = tf.reshape(train_label, [-1])
#    labels = tf.one_hot(train_label, len(replicons_list), dtype=tf.int32)

 
    kmer_input = tf.nn.embedding_lookup(embeddings, train_input)
    kmer_input_r = tf.reshape(kmer_input, [batch_size, -1]) # Flatten

    # valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
    
    # replicons = tf.placeholder(tf.int16, shape=[None, 1])

    l1 = tf.layers.dense(kmer_input_r, n_hidden_1)
    l2 = tf.layers.dense(l1, n_hidden_2)
    logits = tf.layers.dense(l2, len(replicons_list))
    
    pred_classes = tf.argmax(logits, axis=1)
    pred_prob = tf.nn.softmax(logits)
    
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits, labels = train_label))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    train = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    
    acc = tf.metrics.accuracy(labels = tf.argmax(train_label, 1), predictions=pred_classes)
    
    init = tf.global_variables_initializer()

In [22]:
train_input_fn()


Out[22]:
([239238,
  169260,
  242754,
  215856,
  60843,
  132808,
  249405,
  129105,
  112336,
  19860,
  139171,
  66831,
  213222,
  255357,
  21515],
 [0.0, 0.0, 1.0])

In [23]:
with tf.Session(graph=graph, config=tf.ConfigProto(log_device_placement=True)) as session:
    init.run()
    num_steps = 5000
    for step in xrange(num_steps):
      myt, label = train_input_fn()
      train.run(feed_dict={train_input: [myt], 
                           train_label: [label]})
    
#    print(session.run([acc], feed_dict={train_input: [myt], train_label: [label]}))
    
    for step in xrange(10):
        myt, label = train_input_fn()
        classification = session.run([pred_classes, pred_prob], feed_dict={train_input: [myt]})
        print(classification, " ", label)


[array([2], dtype=int64), array([[ 0.1505875 ,  0.2130429 ,  0.63636953]], dtype=float32)]   [0.0, 1.0, 0.0]
[array([2], dtype=int64), array([[ 0.17296444,  0.41227347,  0.41476214]], dtype=float32)]   [0.0, 1.0, 0.0]
[array([2], dtype=int64), array([[ 0.20260364,  0.31371108,  0.48368534]], dtype=float32)]   [1.0, 0.0, 0.0]
[array([2], dtype=int64), array([[ 0.33472094,  0.28635615,  0.37892291]], dtype=float32)]   [1.0, 0.0, 0.0]
[array([2], dtype=int64), array([[ 0.10926115,  0.29860041,  0.59213847]], dtype=float32)]   [0.0, 0.0, 1.0]
[array([2], dtype=int64), array([[ 0.21947457,  0.25672138,  0.52380401]], dtype=float32)]   [0.0, 1.0, 0.0]
[array([2], dtype=int64), array([[ 0.1778408 ,  0.30563816,  0.51652104]], dtype=float32)]   [1.0, 0.0, 0.0]
[array([2], dtype=int64), array([[ 0.11665278,  0.31180292,  0.57154435]], dtype=float32)]   [1.0, 0.0, 0.0]
[array([2], dtype=int64), array([[ 0.16575523,  0.3371191 ,  0.49712569]], dtype=float32)]   [1.0, 0.0, 0.0]
[array([2], dtype=int64), array([[ 0.18270759,  0.1957254 ,  0.62156695]], dtype=float32)]   [0.0, 1.0, 0.0]


In [ ]:
num_steps = 5

executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
future = executor.submit(load_fasta, filegen())

tdata = list()
tdata = future.result()
print("tdata length: ", str(len(tdata)))

with tf.Session(graph=graph, config=tf.ConfigProto(log_device_placement=True)) as session:
  # We must initialize all variables before we use them.
  init.run()
  print('Initialized')

  average_loss = 0
  for step in xrange(num_steps):
    
    if step % 15000 == 0: # Change files every 15k steps
        print("Loading new file at step: ", step)
        # Start loading the next file, so it has time to finish while the neural net does its training
        tdata = future.result()
        future = executor.submit(load_fasta, filegen())
        
    if step == 5:
        print("Reached step 5!")
        
    if len(tdata) == 0:
        print("Using short-circuit load-fasta at step: ", step)
        tdata = load_fasta(filegen()) # Emergency short-circuit here....
        
    batch_data = generate_training_batch(tdata, batch_size, window_size)
    feed_dict = {train_input: [x[0] for x in batch_data],
                 train_label: [oh[x[1]] for x in batch_data]}

    # We perform one update step by evaluating the train op (including it
    # in the list of returned values for session.run())
    _, loss_val = session.run([train, loss], feed_dict=feed_dict)
    average_loss += loss_val

    # Print status every 10k steps
    if step % 10000 == 0:
        if step > 0:
            average_loss /= 10000
            # The average loss is an estimate of the loss over the last 10000 batches.
        print('Average loss at step ', step, ': ', average_loss)
        average_loss = 0
    
    # Save every 50k steps
#    if step % 100000 == 0:
#        print("Saving model at step: ", step)
#        saver.save(session, './replicon-model', global_step=step)
#        print("Saved model at step: ", step)

        
#    if step % 20000 == 0:
#        sim = similarity.eval()
#        accuracy = 0
#        for i in range(0, 100):
#            rand_kmer = random.choice(list(validation_dict.keys()))
#            top_k = 10
#            nearest = (-sim[rand_kmer, :]).argsort()[1:top_k + 1]

In [19]:
kmer_dict


Out[19]:
{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 6: 5,
 7: 6,
 8: 7,
 9: 8,
 11: 9,
 12: 10,
 ...}

In [2]:
def model_fn(features, labels, mode):
    # NOTE: carried over from the raw-graph version above; this model_fn ignores
    # the `features` and `labels` arguments and builds its own placeholders, so
    # it is not actually used. The canned DNNClassifier further down is used instead.
    kmers = tf.Variable(tf.constant(0.0, shape=[kmer_count, 128]),
                        trainable=False, name="kmers")

    embedding_placeholder = tf.placeholder(tf.float32, [kmer_count, 128])
    embedding_init = kmers.assign(embeddings)

    train_input = tf.placeholder(tf.int32, shape=[batch_size, 15])
    train_label = tf.placeholder(tf.float32, shape=[batch_size, 3])

    kmer_input = tf.nn.embedding_lookup(embeddings, train_input)
    kmer_input_r = tf.reshape(kmer_input, [batch_size, -1])
    
    l1 = tf.layers.dense(kmer_input_r, n_hidden_1)
    l2 = tf.layers.dense(l1, n_hidden_2)
    logits = tf.layers.dense(l2, len(replicons_list))
    
    pred_classes = tf.argmax(logits, axis=1)
    pred_prob = tf.nn.softmax(logits)
    
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits, labels = train_label))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    train = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    acc = tf.metrics.accuracy(labels = tf.argmax(train_label, 1), predictions=pred_classes)

    estim_specs = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=pred_classes,
            loss=loss,
            train_op=train,
            eval_metric_ops={'accuracy': acc})

    return estim_specs
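
# Hypothetical wiring of this custom model_fn into an Estimator; in practice the
# canned DNNClassifier defined further down is what actually gets trained.
# custom_nn = tf.estimator.Estimator(model_fn=model_fn)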

In [29]:
tbatch = generate_training_batch(training_data, 1, window_size)
len(tbatch[0][0])


Out[29]:
15

In [3]:
embeddings = np.load("embeddings_200000.npy")

def get_kmer_embeddings(kmers):
    a = list()  # becomes a flat vector of length 128 * len(kmers)
    for kmer_id in kmers:  # avoid shadowing the global k-mer size `k`
        a.append(embeddings[kmer_id])
    return np.hstack(a)

# training_data = load_fasta(filegen())
# get_kmer_embeddings(tbatch[0][0])

# tf.convert_to_tensor([tf.convert_to_tensor(get_kmer_embeddings(tbatch[0][0]))])
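
# Each window of 15 k-mers maps to a flat 15 * 128 = 1920-dimensional vector,
# which matches the numeric_column shape=[1920] used for the DNNClassifier below.
# (Assumes tbatch from the cell above is still in scope.)
# get_kmer_embeddings(tbatch[0][0]).shape   -> (1920,)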

In [8]:
batch_size = 1024
embedding_size = 128
window_size = 7

replicons_list = get_categories("training-files/")

filegen = training_file_generator("training-files/")

repdict = dict()
a = 0
for i in replicons_list:
    repdict[i] = a
    a += 1

def train_input_fn(data):
    tbatch = generate_training_batch(data, 1, window_size)
    dat = {'x': tf.convert_to_tensor([tf.convert_to_tensor(get_kmer_embeddings(tbatch[0][0]))])}
    lab = tf.convert_to_tensor([repdict[tbatch[0][1]]])
    return dat, lab

def test_input_fn(data):  # currently identical to train_input_fn
    tbatch = generate_training_batch(data, 1, window_size)
    dat = {'x': tf.convert_to_tensor([tf.convert_to_tensor(get_kmer_embeddings(tbatch[0][0]))])}
    lab = tf.convert_to_tensor([repdict[tbatch[0][1]]])
    return dat, lab

In [9]:
# Because this was run at work on a smaller sample of files....
# with open("all_kmers_subset.txt", "w") as f:
#     for s in all_kmers:
#         f.write(str(s) +"\n")

# Because this was run at work on a smaller sample of files....
all_kmers = list()
# with open("all_kmers_subset.txt", "r") as f:
#     for line in f:
#         all_kmers.append(int(line.strip()))

all_kmers = pickle.load(open("all_kmers.p", "rb"))

all_kmers = set(all_kmers)
len(all_kmers)
# len(data)

# all_kmers = set([item for sublist in data for item in sublist])
unused_kmers = set(range(0, space)) - all_kmers

kmer_dict = dict()
reverse_kmer_dict = dict()

for a, i in enumerate(all_kmers):
    kmer_dict[i] = a
    reverse_kmer_dict[a] = i
    
kmer_count = len(all_kmers)

[len(all_kmers), len(unused_kmers), space]


Out[9]:
[269132, 1683993, 1953125]

In [10]:
feature_columns = [tf.feature_column.numeric_column("x", shape=[1920])]
nn = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                hidden_units = [5000,500,100],
                                activation_fn=tf.nn.relu,
                                dropout=0.2,
                                n_classes=len(replicons_list),
                                optimizer="Adam")


INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: C:\Users\Joey\AppData\Local\Temp\tmpuaplusri
INFO:tensorflow:Using config: {'_model_dir': 'C:\\Users\\Joey\\AppData\\Local\\Temp\\tmpuaplusri', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100}

In [ ]:
tf.logging.set_verbosity(tf.logging.ERROR)

filegen = training_file_generator("training-files/")

for i in xrange(100):
    training_data = load_fasta(filegen())
    tfn = functools.partial(train_input_fn, training_data)
    nn.train(input_fn=tfn, steps=10000)


Loading from pickle

In [ ]:
# training_data = load_fasta(filegen())
filegen = training_file_generator("training-files/")
tf.logging.set_verbosity(tf.logging.ERROR)
tfn = functools.partial(train_input_fn, training_data)
accuracy_score = nn.evaluate(input_fn=tfn, steps=100)['accuracy']
print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

In [135]:
tfn = functools.partial(train_input_fn, training_data)
tf.Session.run(tfn()[1])  # TypeError: run() is called on the tf.Session class, not a Session instance


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-135-0eef4d8744d3> in <module>()
      1 tfn = functools.partial(train_input_fn, training_data)
----> 2 tf.Session.run(tfn()[1])

TypeError: run() missing 1 required positional argument: 'fetches'

In [ ]: