In [1]:
import pandas as pd
import numpy as np
import yaml
%matplotlib inline
In [2]:
# Load the hyperparameters for the sine-wave RNN from the YAML config.
# safe_load parses plain YAML only; the original yaml.load(...) with no
# explicit Loader performs unsafe arbitrary-object deserialization and the
# Loader-less default was removed in PyYAML >= 5.1.
with open("param.yaml", "r") as file:
    param = yaml.safe_load(file)
param
Out[2]:
{'forget_bias': 1.0,
'learning_rate': 0.1,
'length_of_sequences': 50,
'num_of_hidden_nodes': 3,
'num_of_input_nodes': 1,
'num_of_output_nodes': 1,
'num_of_prediction_epochs': 100,
'num_of_training_epochs': 2000,
'optimizer': 'GradientDescentOptimizer',
'seed': 0,
'size_of_mini_batch': 100,
'train_data_path': '../train_data/normal.npy'}
In [3]:
# Load the training data from the path given in the config
# ("../train_data/normal.npy"). Out[3] shows an (N, 2) float array of
# consecutive sine samples — presumably (input, target) pairs; confirm
# against the script that generated the file.
train = np.load(param["train_data_path"])
train
Out[3]:
array([[ 0.00000000e+00, 1.25333234e-01],
[ 1.25333234e-01, 2.48689887e-01],
[ 2.48689887e-01, 3.68124553e-01],
...,
[ -3.68124553e-01, -2.48689887e-01],
[ -2.48689887e-01, -1.25333234e-01],
[ -1.25333234e-01, 3.92877345e-15]])
In [4]:
# Load the seed sequence used to prime prediction: Out[4] shows 50 samples
# of one sine period (matching length_of_sequences in the config).
initial = np.load("initial.npy")
initial
Out[4]:
array([ 0.00000000e+00, 1.25333234e-01, 2.48689887e-01,
3.68124553e-01, 4.81753674e-01, 5.87785252e-01,
6.84547106e-01, 7.70513243e-01, 8.44327926e-01,
9.04827052e-01, 9.51056516e-01, 9.82287251e-01,
9.98026728e-01, 9.98026728e-01, 9.82287251e-01,
9.51056516e-01, 9.04827052e-01, 8.44327926e-01,
7.70513243e-01, 6.84547106e-01, 5.87785252e-01,
4.81753674e-01, 3.68124553e-01, 2.48689887e-01,
1.25333234e-01, -3.21624530e-16, -1.25333234e-01,
-2.48689887e-01, -3.68124553e-01, -4.81753674e-01,
-5.87785252e-01, -6.84547106e-01, -7.70513243e-01,
-8.44327926e-01, -9.04827052e-01, -9.51056516e-01,
-9.82287251e-01, -9.98026728e-01, -9.98026728e-01,
-9.82287251e-01, -9.51056516e-01, -9.04827052e-01,
-8.44327926e-01, -7.70513243e-01, -6.84547106e-01,
-5.87785252e-01, -4.81753674e-01, -3.68124553e-01,
-2.48689887e-01, -1.25333234e-01])
In [5]:
# Load the predicted continuation of the wave: Out[5] shows 100 samples
# (matching num_of_prediction_epochs) — presumably the RNN's output after
# being primed with `initial`; confirm against the training script.
output = np.load("output.npy")
output
Out[5]:
array([ 0.00988895, 0.14695144, 0.25904918, 0.37243086, 0.48757735,
0.60086477, 0.70625579, 0.79767245, 0.87118709, 0.92568642,
0.96206832, 0.98212868, 0.98782188, 0.98087126, 0.96256214,
0.93362796, 0.89419365, 0.84375602, 0.78121251, 0.70501488,
0.61359841, 0.5063135 , 0.38486832, 0.25445098, 0.1227048 ,
-0.00394048, -0.12332624, -0.23694739, -0.34727857, -0.45588797,
-0.56281322, -0.66638774, -0.76301855, -0.84742582, -0.91429216,
-0.96066731, -0.98690897, -0.99553448, -0.98949629, -0.97108847,
-0.94154423, -0.90096766, -0.84836352, -0.78172845, -0.69837976,
-0.5959689 , -0.47468704, -0.3400445 , -0.20278543, -0.07297406,
0.04647714, 0.15935385, 0.27105466, 0.38469699, 0.49980545,
0.61241013, 0.71640521, 0.80589384, 0.87726289, 0.929663 ,
0.96412772, 0.98248988, 0.98668706, 0.97839779, 0.95884544,
0.9286924 , 0.88798875, 0.83615476, 0.77202064, 0.69400114,
0.60057926, 0.49131578, 0.36833155, 0.23726918, 0.1058723 ,
-0.01986998, -0.13839036, -0.25146723, -0.36152959, -0.4699595 ,
-0.57658154, -0.67949384, -0.77485186, -0.85724586, -0.92153937,
-0.96520978, -0.98899168, -0.99554962, -0.98779291, -0.96788913,
-0.93692487, -0.89486128, -0.84056956, -0.77193618, -0.68624902,
-0.5813356 , -0.4579193 , -0.32228211, -0.18550214, -0.05699062])
In [6]:
# Load the recorded training-loss history: Out[6] shows rows of
# (epoch, loss), sampled every 10 epochs up to 2000.
losses = np.load("losses.npy")
losses
Out[6]:
array([[ 1.00000000e+01, 5.17640173e-01],
[ 2.00000000e+01, 4.58749175e-01],
[ 3.00000000e+01, 2.87575126e-01],
[ 4.00000000e+01, 1.02874897e-01],
[ 5.00000000e+01, 2.68941447e-02],
[ 6.00000000e+01, 9.21770558e-03],
[ 7.00000000e+01, 4.33860486e-03],
[ 8.00000000e+01, 2.47932505e-03],
[ 9.00000000e+01, 1.76582334e-03],
[ 1.00000000e+02, 1.08325377e-03],
[ 1.10000000e+02, 1.31148845e-03],
[ 1.20000000e+02, 1.11469300e-03],
[ 1.30000000e+02, 1.12147781e-03],
[ 1.40000000e+02, 9.03500710e-04],
[ 1.50000000e+02, 1.31986197e-03],
[ 1.60000000e+02, 1.06893585e-03],
[ 1.70000000e+02, 1.05251546e-03],
[ 1.80000000e+02, 9.63881146e-04],
[ 1.90000000e+02, 9.65702813e-04],
[ 2.00000000e+02, 7.77539390e-04],
[ 2.10000000e+02, 7.09654414e-04],
[ 2.20000000e+02, 8.98576691e-04],
[ 2.30000000e+02, 8.85663030e-04],
[ 2.40000000e+02, 8.54851620e-04],
[ 2.50000000e+02, 8.33328231e-04],
[ 2.60000000e+02, 7.32842600e-04],
[ 2.70000000e+02, 7.19416596e-04],
[ 2.80000000e+02, 7.29121617e-04],
[ 2.90000000e+02, 6.74145238e-04],
[ 3.00000000e+02, 6.20858918e-04],
[ 3.10000000e+02, 5.73282829e-04],
[ 3.20000000e+02, 5.51229285e-04],
[ 3.30000000e+02, 5.21205424e-04],
[ 3.40000000e+02, 5.80608146e-04],
[ 3.50000000e+02, 6.63045212e-04],
[ 3.60000000e+02, 7.47237180e-04],
[ 3.70000000e+02, 5.58376545e-04],
[ 3.80000000e+02, 5.32474194e-04],
[ 3.90000000e+02, 5.64160408e-04],
[ 4.00000000e+02, 5.88307390e-04],
[ 4.10000000e+02, 4.75764507e-04],
[ 4.20000000e+02, 5.20279049e-04],
[ 4.30000000e+02, 5.85597591e-04],
[ 4.40000000e+02, 5.14558167e-04],
[ 4.50000000e+02, 4.70157625e-04],
[ 4.60000000e+02, 5.42318274e-04],
[ 4.70000000e+02, 5.69655909e-04],
[ 4.80000000e+02, 5.79453947e-04],
[ 4.90000000e+02, 5.01807837e-04],
[ 5.00000000e+02, 5.61874593e-04],
[ 5.10000000e+02, 5.62841014e-04],
[ 5.20000000e+02, 5.05150994e-04],
[ 5.30000000e+02, 5.13631734e-04],
[ 5.40000000e+02, 4.58313007e-04],
[ 5.50000000e+02, 3.60261503e-04],
[ 5.60000000e+02, 3.63296713e-04],
[ 5.70000000e+02, 4.56286449e-04],
[ 5.80000000e+02, 3.75610893e-04],
[ 5.90000000e+02, 3.88536952e-04],
[ 6.00000000e+02, 4.58248600e-04],
[ 6.10000000e+02, 4.35279158e-04],
[ 6.20000000e+02, 4.54050343e-04],
[ 6.30000000e+02, 4.81469266e-04],
[ 6.40000000e+02, 4.09183878e-04],
[ 6.50000000e+02, 4.57609305e-04],
[ 6.60000000e+02, 4.26695275e-04],
[ 6.70000000e+02, 4.12801048e-04],
[ 6.80000000e+02, 3.89356603e-04],
[ 6.90000000e+02, 3.01652413e-04],
[ 7.00000000e+02, 3.94860952e-04],
[ 7.10000000e+02, 2.84416717e-04],
[ 7.20000000e+02, 3.87385575e-04],
[ 7.30000000e+02, 3.98543634e-04],
[ 7.40000000e+02, 3.84653424e-04],
[ 7.50000000e+02, 3.93883209e-04],
[ 7.60000000e+02, 3.71801318e-04],
[ 7.70000000e+02, 3.77119170e-04],
[ 7.80000000e+02, 3.71019385e-04],
[ 7.90000000e+02, 3.92213085e-04],
[ 8.00000000e+02, 3.17707425e-04],
[ 8.10000000e+02, 3.30518553e-04],
[ 8.20000000e+02, 3.60570702e-04],
[ 8.30000000e+02, 3.10018833e-04],
[ 8.40000000e+02, 3.23202315e-04],
[ 8.50000000e+02, 3.22268053e-04],
[ 8.60000000e+02, 3.66954482e-04],
[ 8.70000000e+02, 3.10977630e-04],
[ 8.80000000e+02, 3.42415617e-04],
[ 8.90000000e+02, 3.03473702e-04],
[ 9.00000000e+02, 3.05668538e-04],
[ 9.10000000e+02, 2.83227273e-04],
[ 9.20000000e+02, 3.17177997e-04],
[ 9.30000000e+02, 2.98463594e-04],
[ 9.40000000e+02, 2.56721425e-04],
[ 9.50000000e+02, 3.05421301e-04],
[ 9.60000000e+02, 2.67835043e-04],
[ 9.70000000e+02, 2.97779508e-04],
[ 9.80000000e+02, 2.94320867e-04],
[ 9.90000000e+02, 3.03071487e-04],
[ 1.00000000e+03, 3.29837203e-04],
[ 1.01000000e+03, 3.03327251e-04],
[ 1.02000000e+03, 2.94231577e-04],
[ 1.03000000e+03, 2.78369291e-04],
[ 1.04000000e+03, 2.41354704e-04],
[ 1.05000000e+03, 2.62441201e-04],
[ 1.06000000e+03, 3.11763928e-04],
[ 1.07000000e+03, 2.63200549e-04],
[ 1.08000000e+03, 2.70960474e-04],
[ 1.09000000e+03, 2.44562805e-04],
[ 1.10000000e+03, 2.51854857e-04],
[ 1.11000000e+03, 3.00497399e-04],
[ 1.12000000e+03, 2.12032668e-04],
[ 1.13000000e+03, 2.80644657e-04],
[ 1.14000000e+03, 2.24027390e-04],
[ 1.15000000e+03, 2.45179632e-04],
[ 1.16000000e+03, 2.36973618e-04],
[ 1.17000000e+03, 2.58092419e-04],
[ 1.18000000e+03, 2.21217095e-04],
[ 1.19000000e+03, 2.43080329e-04],
[ 1.20000000e+03, 2.26773802e-04],
[ 1.21000000e+03, 2.48659519e-04],
[ 1.22000000e+03, 2.41746413e-04],
[ 1.23000000e+03, 2.41732079e-04],
[ 1.24000000e+03, 2.47240736e-04],
[ 1.25000000e+03, 2.26257311e-04],
[ 1.26000000e+03, 2.41022164e-04],
[ 1.27000000e+03, 2.12992323e-04],
[ 1.28000000e+03, 2.06907163e-04],
[ 1.29000000e+03, 2.36989275e-04],
[ 1.30000000e+03, 2.01667062e-04],
[ 1.31000000e+03, 2.10282014e-04],
[ 1.32000000e+03, 2.10499202e-04],
[ 1.33000000e+03, 1.88592181e-04],
[ 1.34000000e+03, 1.86043180e-04],
[ 1.35000000e+03, 2.34638224e-04],
[ 1.36000000e+03, 1.64584737e-04],
[ 1.37000000e+03, 2.14510219e-04],
[ 1.38000000e+03, 2.02244337e-04],
[ 1.39000000e+03, 2.35263215e-04],
[ 1.40000000e+03, 1.85667319e-04],
[ 1.41000000e+03, 2.24300442e-04],
[ 1.42000000e+03, 1.99049231e-04],
[ 1.43000000e+03, 1.84221717e-04],
[ 1.44000000e+03, 2.41212416e-04],
[ 1.45000000e+03, 2.00760594e-04],
[ 1.46000000e+03, 1.91481930e-04],
[ 1.47000000e+03, 1.88055623e-04],
[ 1.48000000e+03, 1.88597987e-04],
[ 1.49000000e+03, 1.72172746e-04],
[ 1.50000000e+03, 1.63977689e-04],
[ 1.51000000e+03, 1.85298952e-04],
[ 1.52000000e+03, 2.10915474e-04],
[ 1.53000000e+03, 2.28724486e-04],
[ 1.54000000e+03, 2.07080186e-04],
[ 1.55000000e+03, 1.75856767e-04],
[ 1.56000000e+03, 1.88581398e-04],
[ 1.57000000e+03, 1.77511800e-04],
[ 1.58000000e+03, 1.75164503e-04],
[ 1.59000000e+03, 1.54294947e-04],
[ 1.60000000e+03, 1.96938752e-04],
[ 1.61000000e+03, 2.11087376e-04],
[ 1.62000000e+03, 1.81542462e-04],
[ 1.63000000e+03, 1.70167827e-04],
[ 1.64000000e+03, 2.06581142e-04],
[ 1.65000000e+03, 1.56468639e-04],
[ 1.66000000e+03, 1.41733530e-04],
[ 1.67000000e+03, 1.78146220e-04],
[ 1.68000000e+03, 1.84586082e-04],
[ 1.69000000e+03, 1.73724184e-04],
[ 1.70000000e+03, 2.03003699e-04],
[ 1.71000000e+03, 1.77829046e-04],
[ 1.72000000e+03, 1.54185356e-04],
[ 1.73000000e+03, 1.94821376e-04],
[ 1.74000000e+03, 1.76068468e-04],
[ 1.75000000e+03, 1.53422850e-04],
[ 1.76000000e+03, 1.49866522e-04],
[ 1.77000000e+03, 1.82454532e-04],
[ 1.78000000e+03, 1.56082038e-04],
[ 1.79000000e+03, 1.81741401e-04],
[ 1.80000000e+03, 1.73784880e-04],
[ 1.81000000e+03, 1.40640099e-04],
[ 1.82000000e+03, 1.47329265e-04],
[ 1.83000000e+03, 1.47285915e-04],
[ 1.84000000e+03, 1.58819821e-04],
[ 1.85000000e+03, 1.37373892e-04],
[ 1.86000000e+03, 1.49361615e-04],
[ 1.87000000e+03, 1.67206381e-04],
[ 1.88000000e+03, 1.39467709e-04],
[ 1.89000000e+03, 1.39992699e-04],
[ 1.90000000e+03, 1.64010562e-04],
[ 1.91000000e+03, 1.79535171e-04],
[ 1.92000000e+03, 1.44858321e-04],
[ 1.93000000e+03, 1.57604794e-04],
[ 1.94000000e+03, 1.50522101e-04],
[ 1.95000000e+03, 1.70085696e-04],
[ 1.96000000e+03, 1.66418977e-04],
[ 1.97000000e+03, 1.57556613e-04],
[ 1.98000000e+03, 1.44291349e-04],
[ 1.99000000e+03, 1.47331768e-04],
[ 2.00000000e+03, 1.27852240e-04]])
In [7]:
# Overlay the ground-truth wave, the seed sequence, and the prediction on a
# single time axis. One column per series, one row per time step.
train_df = pd.DataFrame(train[:len(initial) + len(output), 0], columns=["train"])
initial_df = pd.DataFrame(initial, columns=["initial"])
# The prediction starts where the seed ends, so shift its index accordingly.
output_df = pd.DataFrame(output, columns=["output"], index=range(len(initial), len(initial) + len(output)))
# axis=1 aligns the three series column-wise on their shared index. The
# default (axis=0) stacked them row-wise, yielding a frame with duplicate
# index labels and all-NaN off-column cells; the plot looked the same only
# because NaNs break the lines at the right places.
merged = pd.concat([train_df, initial_df, output_df], axis=1)
merged.plot(figsize=(15, 5), grid=True, style=["-", "-", "k--"])
Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x109bc8940>
In [8]:
# Plot the training-loss history against epoch; a log-scaled y axis makes
# the convergence over four orders of magnitude readable.
loss_history = pd.DataFrame(losses, columns=["epoch", "loss"])
loss_history.plot(x="epoch", logy=True, grid=True, figsize=(15, 5))
Out[8]:
<matplotlib.axes._subplots.AxesSubplot at 0x109b9d748>
In [ ]:
Content source: nayutaya/tensorflow-rnn-sin
Similar notebooks: