In [1]:
import tensorflow as tf
from simulator import Sim
import numpy as np
import rospy
import pickle

In [3]:
rospy.init_node('trainer', anonymous=True)
sim = Sim()
data = sim.states
label = sim.actions

def randomize(dataset, labels):
    permutation = np.random.permutation(labels.shape[0])
    shuffled_dataset = dataset[permutation,:,:]
    shuffled_labels = labels[permutation]
    return shuffled_dataset, shuffled_labels

s_data, s_label = randomize(data,label)


train_data = s_data[:,:].reshape(-1,12).astype(np.float32)
train_label = s_label[:800,:].reshape(-1,3).astype(np.float32)
test_data = s_data[800:,:].reshape(-1,12).astype(np.float32)
test_label = s_label[800:,:].reshape(-1,3).astype(np.float32)

train_data = train_data[:,9:]
test_data = test_data[:,9:]

In [37]:
sim_data = pickle.load(open('./data/sim_data.p', 'rb'))
data = np.array(sim_data['poses'])
label = np.array(sim_data['actions'])
s_data, s_label = randomize(data,label)


train_data = s_data[:6500,:].reshape(-1,12).astype(np.float32)
train_label = s_label[:6500,:].reshape(-1,3).astype(np.float32)
valid_data = s_data[6500:7500,:].reshape(-1,12).astype(np.float32)
valid_label = s_label[6500:7500,:].reshape(-1,3).astype(np.float32)
test_data = s_data[7500:,:].reshape(-1,12).astype(np.float32)
test_label = s_label[7500:,:].reshape(-1,3).astype(np.float32)

train_data = train_data[:,9:]
valid_data = valid_data[:,9:]
test_data = test_data[:,9:]
train_label = (train_label - 10.0)/20.0
valid_label = (valid_label - 10.0)/20.0
test_label = (test_label - 10.0)/20

input_num = 3

In [38]:
graph = tf.Graph()
alpha = 0.01
hid_num1 = 500
hid_num2 = 100
input_num = 3

with graph.as_default():
    tf_train_dataset = tf.placeholder(tf.float32, shape=(None, input_num))
    tf_train_labels = tf.placeholder(tf.float32, shape=(None, 3))
    tf_valid_dataset = tf.constant(valid_data)
    tf_valid_labels = tf.constant(valid_label)
    tf_test_dataset = tf.constant(test_data)
    tf_test_labels = tf.constant(test_label)
    
    F1_weights = tf.Variable(tf.truncated_normal([input_num,hid_num1], stddev=1.0))
    F1_biases = tf.Variable(tf.constant(1.0, shape=[hid_num1]))
    
    F2_weights = tf.Variable(tf.truncated_normal([hid_num1, hid_num2], stddev=1.0))
    F2_biases = tf.Variable(tf.constant(1.0, shape=[hid_num2]))
    
    F3_weights = tf.Variable(tf.truncated_normal([hid_num2,3], stddev=1.0))
    F3_biases = tf.Variable(tf.constant(1.0, shape=[1]))
    
    def model(data):
        fc = tf.matmul(data, F1_weights)
        hidden = tf.nn.relu(fc + F1_biases)
        
        fc = tf.matmul(hidden, F2_weights)
        hidden = tf.nn.sigmoid(fc + F2_biases)
        
        fc = tf.matmul(hidden, F3_weights)
        output = tf.nn.tanh(fc + F3_biases)
    
        return output
    
    train_pred = model(tf_train_dataset)
    reg_loss = alpha * (tf.nn.l2_loss(F1_weights) + tf.nn.l2_loss(F2_weights) + tf.nn.l2_loss(F3_weights))
    loss = tf.losses.mean_squared_error(labels=tf_train_labels, predictions=train_pred)
    loss1 = loss + reg_loss
    optimizer = tf.train.AdamOptimizer(0.0001).minimize(loss1)
    #optimizer = tf.train.RMSPropOptimizer(0.001).minimize(loss)
    #optimizer = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
    
    valid_pred = model(tf_valid_dataset)
    valid_loss = tf.losses.mean_squared_error(labels=tf_valid_labels, predictions=valid_pred)
    test_pred = model(tf_test_dataset)
    test_loss = tf.losses.mean_squared_error(labels=tf_test_labels, predictions=test_pred)

In [40]:
num_steps = 6000
batch_size = 10
config = tf.ConfigProto()
config.log_device_placement = True
with tf.Session(graph=graph, config = config) as session:
    tf.global_variables_initializer().run()
    print('Initialized')
    for step in range(num_steps):
        offset = (step * batch_size) % (train_label.shape[0] - batch_size)
        batch_data = train_data[offset:(offset + batch_size), :]
        batch_labels = train_label[offset:(offset + batch_size), :]
        feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}
        _, l, vl = session.run([optimizer, loss, valid_loss], feed_dict=feed_dict)
        if (step % 2000 == 0):
            print('Minibatch training loss at step %d: %f' % (step, l))
            print('Minibatch validation loss at step %d: %f' % (step, vl))
            print('--------------------------------------')
    print('Test loss: %.3f' % test_loss.eval())
    test_rslt = test_pred.eval()
    i_test = 0
    while(i_test!=''):
        try:
            i_test = input("Input an index of test image (or Enter to quit): ")
            label = test_label[int(i_test),:]*10+10
            rslt = test_rslt[i_test,:]*10+10
            print label
            print rslt
        except:
            break


Initialized
Minibatch training loss at step 0: 1.131981
Minibatch validation loss at step 0: 1.085316
--------------------------------------
Minibatch training loss at step 2000: 0.028547
Minibatch validation loss at step 2000: 0.022426
--------------------------------------
Minibatch training loss at step 4000: 0.022844
Minibatch validation loss at step 4000: 0.020583
--------------------------------------
Test loss: 0.025
Input an index of test image (or Enter to quit): 1
[ 9.5  7.5  5. ]
[ 11.42263889   9.50590992   8.58586884]
Input an index of test image (or Enter to quit): 2
[ 10.5   8.   11. ]
[ 10.97403717   8.35712051  10.73413754]
Input an index of test image (or Enter to quit): 4
[  5.   13.5   6. ]
[  7.37871599  14.13130379   8.03944206]
Input an index of test image (or Enter to quit): 6
[  5.5   8.   13. ]
[  7.47579956   8.79560852  12.64152336]
Input an index of test image (or Enter to quit): 4
[  5.   13.5   6. ]
[  7.37871599  14.13130379   8.03944206]
Input an index of test image (or Enter to quit): 5
[ 7.   9.   5.5]
[  9.74113464  10.91098881   8.87540245]
Input an index of test image (or Enter to quit): 6
[  5.5   8.   13. ]
[  7.47579956   8.79560852  12.64152336]
Input an index of test image (or Enter to quit): 7
[ 12.    7.    6.5]
[ 12.87612057   8.51551437   8.30750465]
Input an index of test image (or Enter to quit): 2
[ 10.5   8.   11. ]
[ 10.97403717   8.35712051  10.73413754]
Input an index of test image (or Enter to quit): 1
[ 9.5  7.5  5. ]
[ 11.42263889   9.50590992   8.58586884]
Input an index of test image (or Enter to quit): 5
[ 7.   9.   5.5]
[  9.74113464  10.91098881   8.87540245]
Input an index of test image (or Enter to quit): 47
[ 10.5  12.5  13.5]
[  8.90600204  11.09647465  11.377738  ]
Input an index of test image (or Enter to quit): 54
[ 12.  11.   5.]
[ 12.2841177   11.24488163   7.11414862]
Input an index of test image (or Enter to quit): 34
[ 13.5  13.   13. ]
[ 11.55771255  10.83611774   9.9939003 ]
Input an index of test image (or Enter to quit): 23
[ 12.5  14.5   9.5]
[ 10.87737083  13.50594139   7.555583  ]
Input an index of test image (or Enter to quit): 76
[  9.5  11.5   9.5]
[  9.71635342  11.57206535   9.3998661 ]
Input an index of test image (or Enter to quit):