In [1]:
import pandas as pd
import numpy as np
import yaml
%matplotlib inline

In [2]:
# Load the run configuration from param.yaml.
# safe_load only builds plain Python types (dict, list, str, numbers), so a
# tampered config file cannot instantiate arbitrary objects. It also avoids
# the Loader-less yaml.load() call, which newer PyYAML releases reject.
with open("param.yaml", "r") as file:
    param = yaml.safe_load(file)
param


Out[2]:
{'forget_bias': 1.0,
 'learning_rate': 0.1,
 'length_of_sequences': 50,
 'num_of_hidden_nodes': 3,
 'num_of_input_nodes': 1,
 'num_of_output_nodes': 1,
 'num_of_prediction_epochs': 100,
 'num_of_training_epochs': 2000,
 'optimizer': 'GradientDescentOptimizer',
 'seed': 0,
 'size_of_mini_batch': 100,
 'train_data_path': '../train_data/normal.npy'}

In [3]:
# Read the array stored at the path configured under "train_data_path"
# and display it as the cell result.
train = np.load(file=param["train_data_path"])
train


Out[3]:
array([[  0.00000000e+00,   1.25333234e-01],
       [  1.25333234e-01,   2.48689887e-01],
       [  2.48689887e-01,   3.68124553e-01],
       ..., 
       [ -3.68124553e-01,  -2.48689887e-01],
       [ -2.48689887e-01,  -1.25333234e-01],
       [ -1.25333234e-01,   3.92877345e-15]])

In [4]:
# Read the array saved in initial.npy and display it as the cell result.
initial = np.load(file="initial.npy")
initial


Out[4]:
array([  0.00000000e+00,   1.25333234e-01,   2.48689887e-01,
         3.68124553e-01,   4.81753674e-01,   5.87785252e-01,
         6.84547106e-01,   7.70513243e-01,   8.44327926e-01,
         9.04827052e-01,   9.51056516e-01,   9.82287251e-01,
         9.98026728e-01,   9.98026728e-01,   9.82287251e-01,
         9.51056516e-01,   9.04827052e-01,   8.44327926e-01,
         7.70513243e-01,   6.84547106e-01,   5.87785252e-01,
         4.81753674e-01,   3.68124553e-01,   2.48689887e-01,
         1.25333234e-01,  -3.21624530e-16,  -1.25333234e-01,
        -2.48689887e-01,  -3.68124553e-01,  -4.81753674e-01,
        -5.87785252e-01,  -6.84547106e-01,  -7.70513243e-01,
        -8.44327926e-01,  -9.04827052e-01,  -9.51056516e-01,
        -9.82287251e-01,  -9.98026728e-01,  -9.98026728e-01,
        -9.82287251e-01,  -9.51056516e-01,  -9.04827052e-01,
        -8.44327926e-01,  -7.70513243e-01,  -6.84547106e-01,
        -5.87785252e-01,  -4.81753674e-01,  -3.68124553e-01,
        -2.48689887e-01,  -1.25333234e-01])

In [5]:
# Read the array saved in output.npy and display it as the cell result.
output = np.load(file="output.npy")
output


Out[5]:
array([ 0.00988895,  0.14695144,  0.25904918,  0.37243086,  0.48757735,
        0.60086477,  0.70625579,  0.79767245,  0.87118709,  0.92568642,
        0.96206832,  0.98212868,  0.98782188,  0.98087126,  0.96256214,
        0.93362796,  0.89419365,  0.84375602,  0.78121251,  0.70501488,
        0.61359841,  0.5063135 ,  0.38486832,  0.25445098,  0.1227048 ,
       -0.00394048, -0.12332624, -0.23694739, -0.34727857, -0.45588797,
       -0.56281322, -0.66638774, -0.76301855, -0.84742582, -0.91429216,
       -0.96066731, -0.98690897, -0.99553448, -0.98949629, -0.97108847,
       -0.94154423, -0.90096766, -0.84836352, -0.78172845, -0.69837976,
       -0.5959689 , -0.47468704, -0.3400445 , -0.20278543, -0.07297406,
        0.04647714,  0.15935385,  0.27105466,  0.38469699,  0.49980545,
        0.61241013,  0.71640521,  0.80589384,  0.87726289,  0.929663  ,
        0.96412772,  0.98248988,  0.98668706,  0.97839779,  0.95884544,
        0.9286924 ,  0.88798875,  0.83615476,  0.77202064,  0.69400114,
        0.60057926,  0.49131578,  0.36833155,  0.23726918,  0.1058723 ,
       -0.01986998, -0.13839036, -0.25146723, -0.36152959, -0.4699595 ,
       -0.57658154, -0.67949384, -0.77485186, -0.85724586, -0.92153937,
       -0.96520978, -0.98899168, -0.99554962, -0.98779291, -0.96788913,
       -0.93692487, -0.89486128, -0.84056956, -0.77193618, -0.68624902,
       -0.5813356 , -0.4579193 , -0.32228211, -0.18550214, -0.05699062])

In [6]:
# Read the array saved in losses.npy and display it as the cell result.
losses = np.load(file="losses.npy")
losses


Out[6]:
array([[  1.00000000e+01,   5.17640173e-01],
       [  2.00000000e+01,   4.58749175e-01],
       [  3.00000000e+01,   2.87575126e-01],
       [  4.00000000e+01,   1.02874897e-01],
       [  5.00000000e+01,   2.68941447e-02],
       [  6.00000000e+01,   9.21770558e-03],
       [  7.00000000e+01,   4.33860486e-03],
       [  8.00000000e+01,   2.47932505e-03],
       [  9.00000000e+01,   1.76582334e-03],
       [  1.00000000e+02,   1.08325377e-03],
       [  1.10000000e+02,   1.31148845e-03],
       [  1.20000000e+02,   1.11469300e-03],
       [  1.30000000e+02,   1.12147781e-03],
       [  1.40000000e+02,   9.03500710e-04],
       [  1.50000000e+02,   1.31986197e-03],
       [  1.60000000e+02,   1.06893585e-03],
       [  1.70000000e+02,   1.05251546e-03],
       [  1.80000000e+02,   9.63881146e-04],
       [  1.90000000e+02,   9.65702813e-04],
       [  2.00000000e+02,   7.77539390e-04],
       [  2.10000000e+02,   7.09654414e-04],
       [  2.20000000e+02,   8.98576691e-04],
       [  2.30000000e+02,   8.85663030e-04],
       [  2.40000000e+02,   8.54851620e-04],
       [  2.50000000e+02,   8.33328231e-04],
       [  2.60000000e+02,   7.32842600e-04],
       [  2.70000000e+02,   7.19416596e-04],
       [  2.80000000e+02,   7.29121617e-04],
       [  2.90000000e+02,   6.74145238e-04],
       [  3.00000000e+02,   6.20858918e-04],
       [  3.10000000e+02,   5.73282829e-04],
       [  3.20000000e+02,   5.51229285e-04],
       [  3.30000000e+02,   5.21205424e-04],
       [  3.40000000e+02,   5.80608146e-04],
       [  3.50000000e+02,   6.63045212e-04],
       [  3.60000000e+02,   7.47237180e-04],
       [  3.70000000e+02,   5.58376545e-04],
       [  3.80000000e+02,   5.32474194e-04],
       [  3.90000000e+02,   5.64160408e-04],
       [  4.00000000e+02,   5.88307390e-04],
       [  4.10000000e+02,   4.75764507e-04],
       [  4.20000000e+02,   5.20279049e-04],
       [  4.30000000e+02,   5.85597591e-04],
       [  4.40000000e+02,   5.14558167e-04],
       [  4.50000000e+02,   4.70157625e-04],
       [  4.60000000e+02,   5.42318274e-04],
       [  4.70000000e+02,   5.69655909e-04],
       [  4.80000000e+02,   5.79453947e-04],
       [  4.90000000e+02,   5.01807837e-04],
       [  5.00000000e+02,   5.61874593e-04],
       [  5.10000000e+02,   5.62841014e-04],
       [  5.20000000e+02,   5.05150994e-04],
       [  5.30000000e+02,   5.13631734e-04],
       [  5.40000000e+02,   4.58313007e-04],
       [  5.50000000e+02,   3.60261503e-04],
       [  5.60000000e+02,   3.63296713e-04],
       [  5.70000000e+02,   4.56286449e-04],
       [  5.80000000e+02,   3.75610893e-04],
       [  5.90000000e+02,   3.88536952e-04],
       [  6.00000000e+02,   4.58248600e-04],
       [  6.10000000e+02,   4.35279158e-04],
       [  6.20000000e+02,   4.54050343e-04],
       [  6.30000000e+02,   4.81469266e-04],
       [  6.40000000e+02,   4.09183878e-04],
       [  6.50000000e+02,   4.57609305e-04],
       [  6.60000000e+02,   4.26695275e-04],
       [  6.70000000e+02,   4.12801048e-04],
       [  6.80000000e+02,   3.89356603e-04],
       [  6.90000000e+02,   3.01652413e-04],
       [  7.00000000e+02,   3.94860952e-04],
       [  7.10000000e+02,   2.84416717e-04],
       [  7.20000000e+02,   3.87385575e-04],
       [  7.30000000e+02,   3.98543634e-04],
       [  7.40000000e+02,   3.84653424e-04],
       [  7.50000000e+02,   3.93883209e-04],
       [  7.60000000e+02,   3.71801318e-04],
       [  7.70000000e+02,   3.77119170e-04],
       [  7.80000000e+02,   3.71019385e-04],
       [  7.90000000e+02,   3.92213085e-04],
       [  8.00000000e+02,   3.17707425e-04],
       [  8.10000000e+02,   3.30518553e-04],
       [  8.20000000e+02,   3.60570702e-04],
       [  8.30000000e+02,   3.10018833e-04],
       [  8.40000000e+02,   3.23202315e-04],
       [  8.50000000e+02,   3.22268053e-04],
       [  8.60000000e+02,   3.66954482e-04],
       [  8.70000000e+02,   3.10977630e-04],
       [  8.80000000e+02,   3.42415617e-04],
       [  8.90000000e+02,   3.03473702e-04],
       [  9.00000000e+02,   3.05668538e-04],
       [  9.10000000e+02,   2.83227273e-04],
       [  9.20000000e+02,   3.17177997e-04],
       [  9.30000000e+02,   2.98463594e-04],
       [  9.40000000e+02,   2.56721425e-04],
       [  9.50000000e+02,   3.05421301e-04],
       [  9.60000000e+02,   2.67835043e-04],
       [  9.70000000e+02,   2.97779508e-04],
       [  9.80000000e+02,   2.94320867e-04],
       [  9.90000000e+02,   3.03071487e-04],
       [  1.00000000e+03,   3.29837203e-04],
       [  1.01000000e+03,   3.03327251e-04],
       [  1.02000000e+03,   2.94231577e-04],
       [  1.03000000e+03,   2.78369291e-04],
       [  1.04000000e+03,   2.41354704e-04],
       [  1.05000000e+03,   2.62441201e-04],
       [  1.06000000e+03,   3.11763928e-04],
       [  1.07000000e+03,   2.63200549e-04],
       [  1.08000000e+03,   2.70960474e-04],
       [  1.09000000e+03,   2.44562805e-04],
       [  1.10000000e+03,   2.51854857e-04],
       [  1.11000000e+03,   3.00497399e-04],
       [  1.12000000e+03,   2.12032668e-04],
       [  1.13000000e+03,   2.80644657e-04],
       [  1.14000000e+03,   2.24027390e-04],
       [  1.15000000e+03,   2.45179632e-04],
       [  1.16000000e+03,   2.36973618e-04],
       [  1.17000000e+03,   2.58092419e-04],
       [  1.18000000e+03,   2.21217095e-04],
       [  1.19000000e+03,   2.43080329e-04],
       [  1.20000000e+03,   2.26773802e-04],
       [  1.21000000e+03,   2.48659519e-04],
       [  1.22000000e+03,   2.41746413e-04],
       [  1.23000000e+03,   2.41732079e-04],
       [  1.24000000e+03,   2.47240736e-04],
       [  1.25000000e+03,   2.26257311e-04],
       [  1.26000000e+03,   2.41022164e-04],
       [  1.27000000e+03,   2.12992323e-04],
       [  1.28000000e+03,   2.06907163e-04],
       [  1.29000000e+03,   2.36989275e-04],
       [  1.30000000e+03,   2.01667062e-04],
       [  1.31000000e+03,   2.10282014e-04],
       [  1.32000000e+03,   2.10499202e-04],
       [  1.33000000e+03,   1.88592181e-04],
       [  1.34000000e+03,   1.86043180e-04],
       [  1.35000000e+03,   2.34638224e-04],
       [  1.36000000e+03,   1.64584737e-04],
       [  1.37000000e+03,   2.14510219e-04],
       [  1.38000000e+03,   2.02244337e-04],
       [  1.39000000e+03,   2.35263215e-04],
       [  1.40000000e+03,   1.85667319e-04],
       [  1.41000000e+03,   2.24300442e-04],
       [  1.42000000e+03,   1.99049231e-04],
       [  1.43000000e+03,   1.84221717e-04],
       [  1.44000000e+03,   2.41212416e-04],
       [  1.45000000e+03,   2.00760594e-04],
       [  1.46000000e+03,   1.91481930e-04],
       [  1.47000000e+03,   1.88055623e-04],
       [  1.48000000e+03,   1.88597987e-04],
       [  1.49000000e+03,   1.72172746e-04],
       [  1.50000000e+03,   1.63977689e-04],
       [  1.51000000e+03,   1.85298952e-04],
       [  1.52000000e+03,   2.10915474e-04],
       [  1.53000000e+03,   2.28724486e-04],
       [  1.54000000e+03,   2.07080186e-04],
       [  1.55000000e+03,   1.75856767e-04],
       [  1.56000000e+03,   1.88581398e-04],
       [  1.57000000e+03,   1.77511800e-04],
       [  1.58000000e+03,   1.75164503e-04],
       [  1.59000000e+03,   1.54294947e-04],
       [  1.60000000e+03,   1.96938752e-04],
       [  1.61000000e+03,   2.11087376e-04],
       [  1.62000000e+03,   1.81542462e-04],
       [  1.63000000e+03,   1.70167827e-04],
       [  1.64000000e+03,   2.06581142e-04],
       [  1.65000000e+03,   1.56468639e-04],
       [  1.66000000e+03,   1.41733530e-04],
       [  1.67000000e+03,   1.78146220e-04],
       [  1.68000000e+03,   1.84586082e-04],
       [  1.69000000e+03,   1.73724184e-04],
       [  1.70000000e+03,   2.03003699e-04],
       [  1.71000000e+03,   1.77829046e-04],
       [  1.72000000e+03,   1.54185356e-04],
       [  1.73000000e+03,   1.94821376e-04],
       [  1.74000000e+03,   1.76068468e-04],
       [  1.75000000e+03,   1.53422850e-04],
       [  1.76000000e+03,   1.49866522e-04],
       [  1.77000000e+03,   1.82454532e-04],
       [  1.78000000e+03,   1.56082038e-04],
       [  1.79000000e+03,   1.81741401e-04],
       [  1.80000000e+03,   1.73784880e-04],
       [  1.81000000e+03,   1.40640099e-04],
       [  1.82000000e+03,   1.47329265e-04],
       [  1.83000000e+03,   1.47285915e-04],
       [  1.84000000e+03,   1.58819821e-04],
       [  1.85000000e+03,   1.37373892e-04],
       [  1.86000000e+03,   1.49361615e-04],
       [  1.87000000e+03,   1.67206381e-04],
       [  1.88000000e+03,   1.39467709e-04],
       [  1.89000000e+03,   1.39992699e-04],
       [  1.90000000e+03,   1.64010562e-04],
       [  1.91000000e+03,   1.79535171e-04],
       [  1.92000000e+03,   1.44858321e-04],
       [  1.93000000e+03,   1.57604794e-04],
       [  1.94000000e+03,   1.50522101e-04],
       [  1.95000000e+03,   1.70085696e-04],
       [  1.96000000e+03,   1.66418977e-04],
       [  1.97000000e+03,   1.57556613e-04],
       [  1.98000000e+03,   1.44291349e-04],
       [  1.99000000e+03,   1.47331768e-04],
       [  2.00000000e+03,   1.27852240e-04]])

In [7]:
# Put the first column of `train`, the `initial` array and the `output` array
# on one shared integer index so they can be drawn on a single axis.
n_points = len(initial) + len(output)
train_df = pd.DataFrame(train[:n_points, 0], columns=["train"])
initial_df = pd.DataFrame(initial, columns=["initial"])
# `output` is indexed to start right after the last `initial` sample.
output_df = pd.DataFrame(output, columns=["output"], index=range(len(initial), n_points))
# axis=1 aligns the three frames on their index (one row per position).
# The default axis=0 would stack them row-wise into a mostly-NaN frame with a
# duplicated index, and the resulting column order would depend on the pandas
# version's `sort` default.
merged = pd.concat([train_df, initial_df, output_df], axis=1)
# Key the line styles by column name so each style stays attached to its
# series regardless of column ordering.
merged.plot(figsize=(15, 5), grid=True, style={"train": "-", "initial": "-", "output": "k--"})


Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x109bc8940>

In [8]:
# Label the two columns of `losses` as "epoch" and "loss", then draw loss
# against epoch with a log-scaled y axis.
losses_df = pd.DataFrame(data=losses, columns=["epoch", "loss"])
losses_df.plot(x="epoch", logy=True, grid=True, figsize=(15, 5))


Out[8]:
<matplotlib.axes._subplots.AxesSubplot at 0x109b9d748>

In [ ]: