In this notebook a DNNRegressor is used through TensorFlow's tf.contrib.learn library. The example shows how to generate the feature_columns and how to feed the input using the input_fn argument.
In [1]:
# Used to clear up the workspace.
%reset -f
import numpy as np
import pickle
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import estimator
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
# Load the data.
# NOTE(review): pickle.load can execute arbitrary code, so only use it on a trusted data file.
# The context manager closes the file handle as soon as the pickle has been read.
with open('../data/data-ant.pkl', 'rb') as data_file:
    data = pickle.load(data_file)
observations = data['observations']
actions = data['actions']
# We will only look at the first label column, since multiple regression is not supported for some reason...
actions = actions[:, 0]
# Split the data. An integer test_size is an absolute sample count, so 10 samples are held out for testing.
X_train, X_test, y_train, y_test = train_test_split(observations, actions, test_size=10, random_state=42)
num_train = X_train.shape[0]
num_test = X_test.shape[0]
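The pred_fn and input_fn functions take NumPy arrays as input (the features must be a two-dimensional array because they are sliced column by column; the labels may also be a plain list) and generate the feature columns (pred_fn) or the feature columns together with the labels (input_fn). The feature columns take the form of a dictionary with the column names as keys and a tf.constant of each column as values, while the label is simply a tf.constant of the labels.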
np.newaxis is added to address TensorFlow's warning that the input should be a two-dimensional rather than a one-dimensional tensor.
In [2]:
def pred_fn(X):
    """Build the feature dictionary that is fed to the estimator.

    Each column k of the two-dimensional array X becomes one entry: the key is
    "my_col" followed by k, and the value is a tf.constant holding that column
    with a trailing axis added.
    """
    features = {}
    for col_idx in range(X.shape[1]):
        column = X[:, col_idx]
        # np.newaxis turns the 1-D column into a 2-D (num_rows, 1) array.
        features["my_col" + str(col_idx)] = tf.constant(column[:, np.newaxis])
    return features
def input_fn(X, y):
    """Return the (features, label) pair for the given inputs.

    The features are whatever pred_fn(X) produces; the label is y wrapped in a
    tf.constant.
    """
    return pred_fn(X), tf.constant(y)
In [3]:
# One real-valued feature column per input dimension of X_train.
# The names follow the same "my_col" + index scheme as the keys built by pred_fn.
feature_cols = [
    tf.contrib.layers.real_valued_column("my_col" + str(i))
    for i in range(X_train.shape[1])
]
# This does not work for some reason.
#feature_cols = tf.contrib.learn.infer_real_valued_columns_from_input(X_train)
In [4]:
# Build the regressor with two hidden layers of 100 units each.
regressor = tf.contrib.learn.DNNRegressor(
    hidden_units=[100, 100],
    feature_columns=feature_cols,
)
# Train for 1000 steps on the training split; the trailing semicolon suppresses the cell output.
regressor.fit(steps=1000, input_fn=lambda: input_fn(X_train, y_train));
In [5]:
# Collect the predicted scores for the held-out samples into a list.
pred = list(regressor.predict_scores(input_fn=lambda: pred_fn(X_test)))
# A parenthesised single-argument print behaves the same under Python 2 and Python 3.
print(pred)
print(y_test)
# scikit-learn's signature is mean_squared_error(y_true, y_pred); the value is unchanged because MSE is symmetric.
print(mean_squared_error(y_test, pred))