In [1]:
import pandas as pd
import numpy as np
import yaml
%matplotlib inline
In [2]:
# Load the experiment hyperparameters from param.yaml.
with open("param.yaml", "r") as file:
    # yaml.load without an explicit Loader is deprecated (PyYAML >= 5.1) and can
    # instantiate arbitrary Python objects from the file. safe_load restricts
    # parsing to plain tags (dicts, lists, scalars), which is all this config
    # needs, and accepts the file object directly — no need for .read().
    param = yaml.safe_load(file)
param
Out[2]:
{'forget_bias': 1.0,
'learning_rate': 0.1,
'length_of_sequences': 40,
'num_of_hidden_nodes': 2,
'num_of_input_nodes': 1,
'num_of_output_nodes': 1,
'num_of_prediction_epochs': 100,
'num_of_training_epochs': 2000,
'optimizer': 'GradientDescentOptimizer',
'seed': 0,
'size_of_mini_batch': 100,
'train_data_path': '../train_data/normal.npy'}
In [3]:
# Training data from the path configured in param.yaml ("../train_data/normal.npy").
# Rows appear to be (current value, next value) pairs sampled from a sine wave —
# NOTE(review): confirm against the script that generates this file.
train = np.load(param["train_data_path"])
train
Out[3]:
array([[ 0.00000000e+00, 1.25333234e-01],
[ 1.25333234e-01, 2.48689887e-01],
[ 2.48689887e-01, 3.68124553e-01],
...,
[ -3.68124553e-01, -2.48689887e-01],
[ -2.48689887e-01, -1.25333234e-01],
[ -1.25333234e-01, 3.92877345e-15]])
In [4]:
# Seed sequence fed to the model before prediction starts: 40 values, which
# matches length_of_sequences in param.yaml (presumably one full input window —
# TODO confirm against the training script).
initial = np.load("initial.npy")
initial
Out[4]:
array([ 0.00000000e+00, 1.25333234e-01, 2.48689887e-01,
3.68124553e-01, 4.81753674e-01, 5.87785252e-01,
6.84547106e-01, 7.70513243e-01, 8.44327926e-01,
9.04827052e-01, 9.51056516e-01, 9.82287251e-01,
9.98026728e-01, 9.98026728e-01, 9.82287251e-01,
9.51056516e-01, 9.04827052e-01, 8.44327926e-01,
7.70513243e-01, 6.84547106e-01, 5.87785252e-01,
4.81753674e-01, 3.68124553e-01, 2.48689887e-01,
1.25333234e-01, -3.21624530e-16, -1.25333234e-01,
-2.48689887e-01, -3.68124553e-01, -4.81753674e-01,
-5.87785252e-01, -6.84547106e-01, -7.70513243e-01,
-8.44327926e-01, -9.04827052e-01, -9.51056516e-01,
-9.82287251e-01, -9.98026728e-01, -9.98026728e-01,
-9.82287251e-01])
In [5]:
# Model predictions produced after the seed sequence: 100 values, which matches
# num_of_prediction_epochs in param.yaml.
output = np.load("output.npy")
output
Out[5]:
array([-0.95332503, -0.92011762, -0.87067926, -0.80673462, -0.72664982,
-0.6295045 , -0.51608354, -0.39027759, -0.25941375, -0.13179824,
-0.01229022, 0.09983977, 0.2087644 , 0.31856743, 0.43100587,
0.54469228, 0.65530688, 0.75680256, 0.8433274 , 0.91096622,
0.95839024, 0.98634923, 0.99665821, 0.99132085, 0.97201979,
0.939928 , 0.89572382, 0.83971882, 0.77204663, 0.69287592,
0.6026001 , 0.50195622, 0.39205647, 0.27436021, 0.15064213,
0.02298426, -0.10623097, -0.23435897, -0.35858494, -0.47607589,
-0.58418685, -0.68068433, -0.7639299 , -0.83297348, -0.88752097,
-0.92779565, -0.9543364 , -0.96778202, -0.96867907, -0.95732439,
-0.93364966, -0.89714622, -0.84685129, -0.78144413, -0.69957149,
-0.60057408, -0.48571104, -0.35944456, -0.22935978, -0.10326896,
0.01498237, 0.12683848, 0.23645565, 0.34736943, 0.46055913,
0.57385713, 0.68240333, 0.7800858 , 0.86150599, 0.92347634,
0.96532542, 0.98818612, 0.99398392, 0.98465836, 0.9617523 ,
0.92629033, 0.87883079, 0.81961417, 0.7487635 , 0.66649449,
0.57329082, 0.46999818, 0.35783315, 0.23834869, 0.11340808,
-0.01481996, -0.14387387, -0.27105039, -0.39351258, -0.50845891,
-0.61333954, -0.70607579, -0.78522331, -0.85002929, -0.90036577,
-0.93656743, -0.95922077, -0.96895397, -0.9662534 , -0.95131803])
In [6]:
# Training-loss history. Per Out[6]: column 0 is the epoch number (every 10th
# epoch up to num_of_training_epochs = 2000), column 1 is the loss at that epoch.
losses = np.load("losses.npy")
losses
Out[6]:
array([[ 1.00000000e+01, 4.73526448e-01],
[ 2.00000000e+01, 5.08257747e-01],
[ 3.00000000e+01, 4.54759717e-01],
[ 4.00000000e+01, 2.80228168e-01],
[ 5.00000000e+01, 1.65138155e-01],
[ 6.00000000e+01, 1.09463006e-01],
[ 7.00000000e+01, 7.65380561e-02],
[ 8.00000000e+01, 6.26570210e-02],
[ 9.00000000e+01, 4.47940640e-02],
[ 1.00000000e+02, 4.65185940e-02],
[ 1.10000000e+02, 3.54321599e-02],
[ 1.20000000e+02, 2.85149198e-02],
[ 1.30000000e+02, 2.11832710e-02],
[ 1.40000000e+02, 1.79777406e-02],
[ 1.50000000e+02, 1.71740204e-02],
[ 1.60000000e+02, 1.81729309e-02],
[ 1.70000000e+02, 1.42980386e-02],
[ 1.80000000e+02, 1.55191096e-02],
[ 1.90000000e+02, 1.29750585e-02],
[ 2.00000000e+02, 1.27227064e-02],
[ 2.10000000e+02, 1.07249832e-02],
[ 2.20000000e+02, 1.07771801e-02],
[ 2.30000000e+02, 9.32370313e-03],
[ 2.40000000e+02, 9.34475567e-03],
[ 2.50000000e+02, 7.00758444e-03],
[ 2.60000000e+02, 6.87805656e-03],
[ 2.70000000e+02, 5.27604204e-03],
[ 2.80000000e+02, 5.85581316e-03],
[ 2.90000000e+02, 4.36704187e-03],
[ 3.00000000e+02, 3.69102159e-03],
[ 3.10000000e+02, 5.06135682e-03],
[ 3.20000000e+02, 4.99257492e-03],
[ 3.30000000e+02, 3.32856528e-03],
[ 3.40000000e+02, 2.86839548e-02],
[ 3.50000000e+02, 3.13794264e-03],
[ 3.60000000e+02, 2.57683941e-03],
[ 3.70000000e+02, 2.84193526e-03],
[ 3.80000000e+02, 2.57769506e-03],
[ 3.90000000e+02, 2.43673287e-03],
[ 4.00000000e+02, 1.95735763e-03],
[ 4.10000000e+02, 2.08921172e-03],
[ 4.20000000e+02, 3.67373903e-03],
[ 4.30000000e+02, 9.19662975e-03],
[ 4.40000000e+02, 2.20143120e-03],
[ 4.50000000e+02, 1.79757690e-03],
[ 4.60000000e+02, 1.60659384e-03],
[ 4.70000000e+02, 1.76684896e-03],
[ 4.80000000e+02, 1.26494805e-03],
[ 4.90000000e+02, 1.30318431e-03],
[ 5.00000000e+02, 1.33646699e-03],
[ 5.10000000e+02, 1.31752167e-03],
[ 5.20000000e+02, 1.07782008e-03],
[ 5.30000000e+02, 4.75899782e-03],
[ 5.40000000e+02, 1.34602515e-03],
[ 5.50000000e+02, 1.27844769e-03],
[ 5.60000000e+02, 1.41843292e-03],
[ 5.70000000e+02, 9.64831386e-04],
[ 5.80000000e+02, 1.16178463e-03],
[ 5.90000000e+02, 1.22023595e-03],
[ 6.00000000e+02, 1.06213009e-03],
[ 6.10000000e+02, 1.07174413e-03],
[ 6.20000000e+02, 9.28206078e-04],
[ 6.30000000e+02, 8.45505507e-04],
[ 6.40000000e+02, 9.45625361e-04],
[ 6.50000000e+02, 1.34726497e-03],
[ 6.60000000e+02, 8.42845067e-04],
[ 6.70000000e+02, 7.23399047e-04],
[ 6.80000000e+02, 7.92542414e-04],
[ 6.90000000e+02, 7.57670496e-04],
[ 7.00000000e+02, 7.35711248e-04],
[ 7.10000000e+02, 7.27594190e-04],
[ 7.20000000e+02, 9.24280903e-04],
[ 7.30000000e+02, 1.18303846e-03],
[ 7.40000000e+02, 7.04787439e-04],
[ 7.50000000e+02, 9.24750580e-04],
[ 7.60000000e+02, 6.57544355e-04],
[ 7.70000000e+02, 9.27980698e-04],
[ 7.80000000e+02, 6.52156770e-04],
[ 7.90000000e+02, 6.19596511e-04],
[ 8.00000000e+02, 6.30084483e-04],
[ 8.10000000e+02, 5.63483918e-04],
[ 8.20000000e+02, 6.06661895e-04],
[ 8.30000000e+02, 6.27219502e-04],
[ 8.40000000e+02, 5.34282241e-04],
[ 8.50000000e+02, 5.03387477e-04],
[ 8.60000000e+02, 5.50669967e-04],
[ 8.70000000e+02, 5.09230536e-04],
[ 8.80000000e+02, 5.22422371e-04],
[ 8.90000000e+02, 4.84370219e-04],
[ 9.00000000e+02, 4.98556998e-04],
[ 9.10000000e+02, 5.48293116e-04],
[ 9.20000000e+02, 5.09995909e-04],
[ 9.30000000e+02, 4.25449980e-04],
[ 9.40000000e+02, 4.50736057e-04],
[ 9.50000000e+02, 4.75562614e-04],
[ 9.60000000e+02, 3.98678472e-04],
[ 9.70000000e+02, 6.71610585e-04],
[ 9.80000000e+02, 4.37908253e-04],
[ 9.90000000e+02, 4.95554705e-04],
[ 1.00000000e+03, 3.85223917e-04],
[ 1.01000000e+03, 5.43306407e-04],
[ 1.02000000e+03, 3.48218542e-04],
[ 1.03000000e+03, 5.09103178e-04],
[ 1.04000000e+03, 3.79028235e-04],
[ 1.05000000e+03, 3.90479719e-04],
[ 1.06000000e+03, 4.93360334e-04],
[ 1.07000000e+03, 4.40368720e-04],
[ 1.08000000e+03, 3.44827888e-04],
[ 1.09000000e+03, 3.67852888e-04],
[ 1.10000000e+03, 3.39852355e-04],
[ 1.11000000e+03, 3.68691457e-04],
[ 1.12000000e+03, 3.66393971e-04],
[ 1.13000000e+03, 3.53295647e-04],
[ 1.14000000e+03, 3.25827277e-04],
[ 1.15000000e+03, 3.81391845e-04],
[ 1.16000000e+03, 4.33518668e-04],
[ 1.17000000e+03, 4.65435238e-04],
[ 1.18000000e+03, 2.79254280e-04],
[ 1.19000000e+03, 3.14876001e-04],
[ 1.20000000e+03, 3.62697465e-04],
[ 1.21000000e+03, 3.58972407e-04],
[ 1.22000000e+03, 3.81029997e-04],
[ 1.23000000e+03, 3.61439510e-04],
[ 1.24000000e+03, 2.95001693e-04],
[ 1.25000000e+03, 2.54618208e-04],
[ 1.26000000e+03, 3.91870708e-04],
[ 1.27000000e+03, 3.26853013e-04],
[ 1.28000000e+03, 3.10488598e-04],
[ 1.29000000e+03, 2.82349909e-04],
[ 1.30000000e+03, 2.80228793e-04],
[ 1.31000000e+03, 3.09809431e-04],
[ 1.32000000e+03, 3.49995593e-04],
[ 1.33000000e+03, 2.66510091e-04],
[ 1.34000000e+03, 2.51806137e-04],
[ 1.35000000e+03, 2.60608416e-04],
[ 1.36000000e+03, 2.88757612e-04],
[ 1.37000000e+03, 3.11634998e-04],
[ 1.38000000e+03, 2.89236574e-04],
[ 1.39000000e+03, 2.72273726e-04],
[ 1.40000000e+03, 3.28727096e-04],
[ 1.41000000e+03, 2.59889581e-04],
[ 1.42000000e+03, 2.86595605e-04],
[ 1.43000000e+03, 2.51565041e-04],
[ 1.44000000e+03, 2.39228160e-04],
[ 1.45000000e+03, 2.55520601e-04],
[ 1.46000000e+03, 2.50665529e-04],
[ 1.47000000e+03, 3.05900641e-04],
[ 1.48000000e+03, 2.27265205e-04],
[ 1.49000000e+03, 2.41819784e-04],
[ 1.50000000e+03, 2.43078539e-04],
[ 1.51000000e+03, 2.93111574e-04],
[ 1.52000000e+03, 2.63434660e-04],
[ 1.53000000e+03, 2.56292871e-04],
[ 1.54000000e+03, 2.57723266e-04],
[ 1.55000000e+03, 2.13120089e-04],
[ 1.56000000e+03, 2.29957514e-04],
[ 1.57000000e+03, 2.94378435e-04],
[ 1.58000000e+03, 2.37885528e-04],
[ 1.59000000e+03, 2.30167512e-04],
[ 1.60000000e+03, 2.64378701e-04],
[ 1.61000000e+03, 2.52351747e-04],
[ 1.62000000e+03, 2.24911753e-04],
[ 1.63000000e+03, 2.57736188e-04],
[ 1.64000000e+03, 2.24463292e-04],
[ 1.65000000e+03, 2.65713403e-04],
[ 1.66000000e+03, 2.42124515e-04],
[ 1.67000000e+03, 2.50803801e-04],
[ 1.68000000e+03, 2.17329391e-04],
[ 1.69000000e+03, 1.91169907e-04],
[ 1.70000000e+03, 2.39249115e-04],
[ 1.71000000e+03, 2.12851694e-04],
[ 1.72000000e+03, 2.25069976e-04],
[ 1.73000000e+03, 2.18229165e-04],
[ 1.74000000e+03, 2.29177545e-04],
[ 1.75000000e+03, 2.71543307e-04],
[ 1.76000000e+03, 2.19050969e-04],
[ 1.77000000e+03, 2.27583718e-04],
[ 1.78000000e+03, 2.36822438e-04],
[ 1.79000000e+03, 2.28106714e-04],
[ 1.80000000e+03, 2.27066877e-04],
[ 1.81000000e+03, 2.14227737e-04],
[ 1.82000000e+03, 2.16090906e-04],
[ 1.83000000e+03, 2.09024176e-04],
[ 1.84000000e+03, 2.21643873e-04],
[ 1.85000000e+03, 1.96708497e-04],
[ 1.86000000e+03, 1.78733200e-04],
[ 1.87000000e+03, 1.87470840e-04],
[ 1.88000000e+03, 1.90660707e-04],
[ 1.89000000e+03, 1.95501038e-04],
[ 1.90000000e+03, 2.32313207e-04],
[ 1.91000000e+03, 1.79974886e-04],
[ 1.92000000e+03, 2.05500313e-04],
[ 1.93000000e+03, 1.85861587e-04],
[ 1.94000000e+03, 2.30946796e-04],
[ 1.95000000e+03, 1.69568957e-04],
[ 1.96000000e+03, 1.68543440e-04],
[ 1.97000000e+03, 2.20824964e-04],
[ 1.98000000e+03, 2.18802015e-04],
[ 1.99000000e+03, 2.10983664e-04],
[ 2.00000000e+03, 1.91029903e-04]])
In [7]:
# Stitch the three series into one frame aligned on the time-step index:
#   train   — ground-truth signal over the whole window (steps 0 .. len(initial)+len(output)-1)
#   initial — seed sequence fed to the model            (steps 0 .. len(initial)-1)
#   output  — model predictions                         (steps len(initial) ..)
train_df = pd.DataFrame(train[:len(initial) + len(output), 0], columns=["train"])
initial_df = pd.DataFrame(initial, columns=["initial"])
output_df = pd.DataFrame(output, columns=["output"],
                         index=range(len(initial), len(initial) + len(output)))
# axis=1 aligns the frames column-wise on their shared index, giving one row per
# time step. The previous axis=0 (default) stacking produced a duplicated,
# non-monotonic index with NaN-padded columns — it happened to render the same,
# but left `merged` in a misleading shape for any further analysis.
merged = pd.concat([train_df, initial_df, output_df], axis=1)
merged.plot(figsize=(15, 5), grid=True, style=["-", "-", "k--"]);
Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x110519940>
In [8]:
# Training-loss curve on a log scale, so late-stage improvements remain visible.
losses_df = pd.DataFrame(losses, columns=["epoch", "loss"])
losses_df.plot(x="epoch", y="loss", logy=True, grid=True, figsize=(15, 5))
Out[8]:
<matplotlib.axes._subplots.AxesSubplot at 0x1104ed588>
In [ ]:
Content source: the nayutaya/tensorflow-rnn-sin repository on GitHub.
Similar notebooks: (none listed)