In [1]:
import os
import time
import numpy
import tensorflow as tf
import input_data
import vw_c3d_newnetwork
import vw_c3d_tools
import math
import numpy as np

BATCH_SIZE = 16
gpu_num = 1
#MAX_STEPS = 10000
NUM_FRAMES_PER_CLIP = 16
N_CLASSE = 9
CROP_SIZE = 112
CHANNELS = 3
MAX_EPOCHS = 100
INIT_LEARNINGRATE = 3e-3

train_log_dir = './logs//train//'
val_log_dir = './logs//val//'
model_dir = './models/'
model_filename = './models/'
is_finetune = False

def placeholder_inputs(batch_size):
    images_placeholder = tf.placeholder(tf.float32, shape=(BATCH_SIZE,
                                                           NUM_FRAMES_PER_CLIP,
                                                           CROP_SIZE,
                                                           CROP_SIZE,
                                                           CHANNELS))
    labels_placeholder = tf.placeholder(tf.int64, shape=[BATCH_SIZE, N_CLASSE])
    return images_placeholder, labels_placeholder

def train():
    with tf.Graph().as_default():
        #建立读取数据的pipeline,在训练过程中迭代器可以自动读取数据
        train_images_batch, train_labels_batch, _, _, _, train_total_num = input_data.read_clip_and_label(
                                                      filename='list/train.list',
                                                      batch_size=BATCH_SIZE * gpu_num,
                                                      num_frames_per_clip=NUM_FRAMES_PER_CLIP,
                                                      crop_size=CROP_SIZE,
                                                      shuffle=True
                                                      )
        val_images_batch, val_labels_batch, _, _, _, val_total_num= input_data.read_clip_and_label(
                                                      filename='list/test.list',
                                                      batch_size=BATCH_SIZE * gpu_num,
                                                      num_frames_per_clip=NUM_FRAMES_PER_CLIP,
                                                      crop_size=CROP_SIZE,
                                                      shuffle=True
                                                      )
 

        #定义学习率
        global_step = tf.Variable(0, name='global_step', trainable=False) 
        initial_learning_rate = INIT_LEARNINGRATE       #初始学习率
        learning_rate_decay_rate = 0.90                 #学习率衰减率
        step_of_epoch = math.floor(train_total_num / BATCH_SIZE)    #迭代完一次所有样本需要的步数
        learning_rate_decay_steps = step_of_epoch       #学习率衰减一次所需要的步数
        learning_rate = tf.train.exponential_decay(initial_learning_rate,
                    global_step, 
                    learning_rate_decay_steps,
                    learning_rate_decay_rate,staircase=True)              
        
        #将数据映射成tenor形式,并进行one-hot编码
        train_images_batch = tf.cast(train_images_batch,dtype=tf.float32)
        train_labels_batch = tf.one_hot(train_labels_batch, depth= N_CLASSE)
        train_labels_batch = tf.cast(train_labels_batch,dtype=tf.int32)
        val_images_batch = tf.cast(val_images_batch,dtype=tf.float32)
        val_labels_batch = tf.one_hot(val_labels_batch, depth= N_CLASSE)
        val_labels_batch = tf.cast(val_labels_batch,dtype=tf.int32)
        
        #构建logits、loss、accuracy、train_op
        logits = vw_c3d_newnetwork.C3D_MODEL(train_images_batch,N_CLASSE)
        loss = vw_c3d_tools.loss(logits, train_labels_batch)
        accuracy = vw_c3d_tools.accuracy(logits,train_labels_batch)
        train_op = vw_c3d_tools.optimize(loss,learning_rate,global_step)
        
        #构建数据和标签的place_holder 
        images_placeholder, labels_placeholder = placeholder_inputs(BATCH_SIZE * gpu_num)
        
        #准备存储模型
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        saver = tf.train.Saver(tf.global_variables())
        
        #设置绘图的op
        summary_op = tf.summary.merge_all()

        #初始化所有变量,并开启会话
        init = tf.global_variables_initializer()
        sess = tf.Session()
        sess.run(init)
        if is_finetune:
            saver.restore(sess, model_filename)
        
        #tensoboard绘图的writer
        tra_summary_writer = tf.summary.FileWriter(train_log_dir, sess.graph)
        val_summary_writer = tf.summary.FileWriter(val_log_dir, sess.graph)
        
        #打印一些基本信息
        MAX_STEPS = int(MAX_EPOCHS * step_of_epoch)
        print("MAX_STEPS: ",MAX_STEPS)
        print("MAX_EPOCHS: ", MAX_EPOCHS)
        print("step_of_epoch", step_of_epoch)
        print("Train samples: ", train_total_num)
        print("Val samples", val_total_num)
        
        for step in range(MAX_STEPS):
            start_time = time.time()
            tra_images,tra_labels = sess.run([train_images_batch,train_labels_batch])
            _, train_loss, train_accuracy = sess.run([train_op,loss,accuracy], feed_dict ={
                                                          images_placeholder: tra_images,
                                                          labels_placeholder: tra_labels
                                                          })
            duration = time.time() - start_time
            print('Step %d: %.3f sec' % (step, duration))
            if step % 10 == 0 or (step + 1) == MAX_STEPS:
                print ('Step: %d, train_loss: %.4f, train_accuracy: %.4f%%' % (step, train_loss, train_accuracy))
                summary_str = sess.run(summary_op)
                tra_summary_writer.add_summary(summary_str, step)
            
            if step % 50 == 0 or (step + 1) == MAX_STEPS:
                val_images,val_labels=sess.run([val_images_batch,val_labels_batch])
                val_loss, val_acc = sess.run([loss, accuracy], feed_dict={
                                                      images_placeholder: val_images,
                                                      labels_placeholder: val_labels})
                summary_str = sess.run(summary_op)
                val_summary_writer.add_summary(summary_str, step)
                print('**  Step %d, val_loss = %.2f, val_accuracy = %.2f%%  **' %(step, val_loss, val_acc))
            if step % (step_of_epoch*3) == 0 or (step + 1) == MAX_STEPS:
                checkpoint_path = os.path.join(model_dir, 'model')
                saver.save(sess, checkpoint_path, global_step=step)
        print("Done!")

In [2]:
train()


MAX_STEPS:  7800
MAX_EPOCHS:  100
step_of_epoch 78
Train samples:  1260
Val samples 538
Step 0: 2.527 sec
Step: 0, train_loss: 5.5642, train_accuracy: 6.2500%
**  Step 0, val_loss = 5.21, val_accuracy = 6.25%  **
Step 1: 1.124 sec
Step 2: 1.128 sec
Step 3: 0.981 sec
Step 4: 1.006 sec
Step 5: 1.001 sec
Step 6: 1.067 sec
Step 7: 1.045 sec
Step 8: 0.991 sec
Step 9: 0.992 sec
Step 10: 0.992 sec
Step: 10, train_loss: 6.7070, train_accuracy: 43.7500%
Step 11: 1.022 sec
Step 12: 1.018 sec
Step 13: 1.153 sec
Step 14: 1.017 sec
Step 15: 1.027 sec
Step 16: 1.078 sec
Step 17: 1.040 sec
Step 18: 1.016 sec
Step 19: 1.127 sec
Step 20: 1.019 sec
Step: 20, train_loss: 7.7721, train_accuracy: 56.2500%
Step 21: 1.057 sec
Step 22: 1.014 sec
Step 23: 1.191 sec
Step 24: 1.133 sec
Step 25: 1.193 sec
Step 26: 1.004 sec
Step 27: 1.158 sec
Step 28: 1.082 sec
Step 29: 0.993 sec
Step 30: 1.053 sec
Step: 30, train_loss: 7.6158, train_accuracy: 75.0000%
Step 31: 1.061 sec
Step 32: 1.031 sec
Step 33: 1.197 sec
Step 34: 0.979 sec
Step 35: 1.036 sec
Step 36: 0.996 sec
Step 37: 0.993 sec
Step 38: 1.017 sec
Step 39: 0.993 sec
Step 40: 1.011 sec
Step: 40, train_loss: 7.1073, train_accuracy: 75.0000%
Step 41: 1.025 sec
Step 42: 1.011 sec
Step 43: 1.038 sec
Step 44: 0.999 sec
Step 45: 1.210 sec
Step 46: 1.004 sec
Step 47: 1.062 sec
Step 48: 1.157 sec
Step 49: 0.993 sec
Step 50: 1.091 sec
Step: 50, train_loss: 6.6127, train_accuracy: 87.5000%
**  Step 50, val_loss = 6.59, val_accuracy = 87.50%  **
Step 51: 1.071 sec
Step 52: 1.193 sec
Step 53: 1.106 sec
Step 54: 1.016 sec
Step 55: 1.081 sec
Step 56: 1.042 sec
Step 57: 1.012 sec
Step 58: 1.141 sec
Step 59: 1.119 sec
Step 60: 1.124 sec
Step: 60, train_loss: 6.5098, train_accuracy: 81.2500%
Step 61: 1.092 sec
Step 62: 1.072 sec
Step 63: 1.151 sec
Step 64: 1.168 sec
Step 65: 1.115 sec
Step 66: 1.166 sec
Step 67: 1.020 sec
Step 68: 1.171 sec
Step 69: 1.013 sec
Step 70: 1.006 sec
Step: 70, train_loss: 6.1801, train_accuracy: 87.5000%
Step 71: 1.058 sec
Step 72: 1.112 sec
Step 73: 1.132 sec
Step 74: 1.185 sec
Step 75: 1.114 sec
Step 76: 0.996 sec
Step 77: 1.003 sec
Step 78: 1.000 sec
Step 79: 1.015 sec
Step 80: 1.096 sec
Step: 80, train_loss: 6.0882, train_accuracy: 81.2500%
Step 81: 1.031 sec
Step 82: 1.096 sec
Step 83: 1.024 sec
Step 84: 1.089 sec
Step 85: 1.144 sec
Step 86: 1.060 sec
Step 87: 0.989 sec
Step 88: 0.996 sec
Step 89: 1.169 sec
Step 90: 1.145 sec
Step: 90, train_loss: 5.9098, train_accuracy: 87.5000%
Step 91: 1.055 sec
Step 92: 0.998 sec
Step 93: 1.173 sec
Step 94: 1.101 sec
Step 95: 1.125 sec
Step 96: 1.028 sec
Step 97: 1.019 sec
Step 98: 1.195 sec
Step 99: 1.071 sec
Step 100: 1.095 sec
Step: 100, train_loss: 5.6489, train_accuracy: 87.5000%
**  Step 100, val_loss = 5.61, val_accuracy = 87.50%  **
Step 101: 1.073 sec
Step 102: 1.032 sec
Step 103: 1.131 sec
Step 104: 1.093 sec
Step 105: 1.168 sec
Step 106: 1.146 sec
Step 107: 1.004 sec
Step 108: 1.105 sec
Step 109: 0.984 sec
Step 110: 1.035 sec
Step: 110, train_loss: 5.9850, train_accuracy: 75.0000%
Step 111: 1.043 sec
Step 112: 0.998 sec
Step 113: 1.001 sec
Step 114: 1.051 sec
Step 115: 1.015 sec
Step 116: 1.047 sec
Step 117: 0.987 sec
Step 118: 0.986 sec
Step 119: 1.085 sec
Step 120: 1.062 sec
Step: 120, train_loss: 5.3266, train_accuracy: 87.5000%
Step 121: 1.123 sec
Step 122: 1.272 sec
Step 123: 1.021 sec
Step 124: 1.081 sec
Step 125: 1.175 sec
Step 126: 1.068 sec
Step 127: 1.761 sec
Step 128: 1.867 sec
Step 129: 1.036 sec
Step 130: 1.220 sec
Step: 130, train_loss: 5.3866, train_accuracy: 81.2500%
Step 131: 1.042 sec
Step 132: 1.275 sec
Step 133: 1.235 sec
Step 134: 1.454 sec
Step 135: 1.064 sec
Step 136: 1.388 sec
Step 137: 1.106 sec
Step 138: 1.271 sec
Step 139: 1.062 sec
Step 140: 1.607 sec
Step: 140, train_loss: 5.0385, train_accuracy: 87.5000%
Step 141: 1.201 sec
Step 142: 1.028 sec
Step 143: 1.019 sec
Step 144: 1.118 sec
Step 145: 1.363 sec
Step 146: 1.217 sec
Step 147: 1.020 sec
Step 148: 1.024 sec
Step 149: 1.064 sec
Step 150: 1.103 sec
Step: 150, train_loss: 4.9513, train_accuracy: 87.5000%
**  Step 150, val_loss = 4.92, val_accuracy = 87.50%  **
Step 151: 1.126 sec
Step 152: 1.155 sec
Step 153: 1.416 sec
Step 154: 1.062 sec
Step 155: 1.100 sec
Step 156: 1.079 sec
Step 157: 1.599 sec
Step 158: 1.020 sec
Step 159: 1.025 sec
Step 160: 1.100 sec
Step: 160, train_loss: 4.7998, train_accuracy: 87.5000%
Step 161: 1.155 sec
Step 162: 1.064 sec
Step 163: 1.293 sec
Step 164: 1.222 sec
Step 165: 1.083 sec
Step 166: 1.466 sec
Step 167: 1.076 sec
Step 168: 1.033 sec
Step 169: 1.018 sec
Step 170: 1.258 sec
Step: 170, train_loss: 4.8054, train_accuracy: 87.5000%
Step 171: 1.524 sec
Step 172: 1.086 sec
Step 173: 1.113 sec
Step 174: 1.640 sec
Step 175: 1.243 sec
Step 176: 1.049 sec
Step 177: 1.024 sec
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-2-93fd337a0d5c> in <module>()
----> 1 train()

<ipython-input-1-f45193fe7af3> in train()
    113             _, train_loss, train_accuracy = sess.run([train_op,loss,accuracy], feed_dict ={
    114                                                           images_placeholder: tra_images,
--> 115                                                           labels_placeholder: tra_labels
    116                                                           })
    117             duration = time.time() - start_time

/home/wjj/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/wjj/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    963     if final_fetches or final_targets:
    964       results = self._do_run(handle, final_targets, final_fetches,
--> 965                              feed_dict_string, options, run_metadata)
    966     else:
    967       results = []

/home/wjj/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1013     if handle is None:
   1014       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1015                            target_list, options, run_metadata)
   1016     else:
   1017       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/wjj/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1020   def _do_call(self, fn, *args):
   1021     try:
-> 1022       return fn(*args)
   1023     except errors.OpError as e:
   1024       message = compat.as_text(e.message)

/home/wjj/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1002         return tf_session.TF_Run(session, options,
   1003                                  feed_dict, fetch_list, target_list,
-> 1004                                  status, run_metadata)
   1005 
   1006     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

In [ ]: