In [1]:
import pandas as pd
import numpy as np
import yaml
%matplotlib inline
In [2]:
with open("param.yaml", "r") as file:
    param = yaml.safe_load(file)  # yaml.load() without an explicit Loader is deprecated/unsafe in PyYAML >= 5.1
param
Out[2]:
{'forget_bias': 1.0,
'learning_rate': 0.1,
'length_of_sequences': 50,
'num_of_hidden_nodes': 2,
'num_of_input_nodes': 1,
'num_of_output_nodes': 1,
'num_of_prediction_epochs': 100,
'num_of_training_epochs': 3000,
'optimizer': 'GradientDescentOptimizer',
'seed': 0,
'size_of_mini_batch': 100,
'train_data_path': '../train_data/normal.npy'}
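These hyperparameters are consumed by the training script in the repository, not by this notebook. For orientation only, a minimal sketch of how they might plug into a TensorFlow 1.x LSTM graph (an assumption inferred from the parameter names; the actual model definition lives in nayutaya/tensorflow-rnn-sin):

import tensorflow as tf  # assuming TensorFlow 1.x, contemporary with this notebook

# Hypothetical sketch, not the repository's code:
cell = tf.nn.rnn_cell.BasicLSTMCell(
    param["num_of_hidden_nodes"], forget_bias=param["forget_bias"])
optimizer_class = getattr(tf.train, param["optimizer"])  # e.g. tf.train.GradientDescentOptimizer
optimizer = optimizer_class(param["learning_rate"])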
In [3]:
train = np.load(param["train_data_path"])
train
Out[3]:
array([[ 0.00000000e+00, 1.25333234e-01],
[ 1.25333234e-01, 2.48689887e-01],
[ 2.48689887e-01, 3.68124553e-01],
...,
[ -3.68124553e-01, -2.48689887e-01],
[ -2.48689887e-01, -1.25333234e-01],
[ -1.25333234e-01, 3.92877345e-15]])
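The training data is a two-column array of consecutive (x_t, x_{t+1}) pairs taken from a sine wave whose period equals length_of_sequences (50 samples). A hedged reconstruction of how such a file could be generated; the real sample count and generator live in the repository's train_data scripts, so the length below is an assumption:

# Hypothetical sketch of building a normal.npy-style file.
num_samples = 5000  # assumed; defined by the repo's generator
steps = np.arange(num_samples + 1)
wave = np.sin(2.0 * np.pi * steps / param["length_of_sequences"])
pairs = np.stack([wave[:-1], wave[1:]], axis=1)  # each row: current value, next value
# np.save("../train_data/normal.npy", pairs)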
In [4]:
initial = np.load("initial.npy")
initial
Out[4]:
array([ 0.00000000e+00, 1.25333234e-01, 2.48689887e-01,
3.68124553e-01, 4.81753674e-01, 5.87785252e-01,
6.84547106e-01, 7.70513243e-01, 8.44327926e-01,
9.04827052e-01, 9.51056516e-01, 9.82287251e-01,
9.98026728e-01, 9.98026728e-01, 9.82287251e-01,
9.51056516e-01, 9.04827052e-01, 8.44327926e-01,
7.70513243e-01, 6.84547106e-01, 5.87785252e-01,
4.81753674e-01, 3.68124553e-01, 2.48689887e-01,
1.25333234e-01, -3.21624530e-16, -1.25333234e-01,
-2.48689887e-01, -3.68124553e-01, -4.81753674e-01,
-5.87785252e-01, -6.84547106e-01, -7.70513243e-01,
-8.44327926e-01, -9.04827052e-01, -9.51056516e-01,
-9.82287251e-01, -9.98026728e-01, -9.98026728e-01,
-9.82287251e-01, -9.51056516e-01, -9.04827052e-01,
-8.44327926e-01, -7.70513243e-01, -6.84547106e-01,
-5.87785252e-01, -4.81753674e-01, -3.68124553e-01,
-2.48689887e-01, -1.25333234e-01])
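initial holds exactly one period (50 samples) of the same wave; it is the seed sequence fed to the trained RNN before it predicts on its own. A quick sanity check, under the assumption that the seed is simply the first period of the training data:

assert len(initial) == param["length_of_sequences"]
assert np.allclose(initial, train[:len(initial), 0])  # seed == first period of the training wave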
In [5]:
output = np.load("output.npy")
output
Out[5]:
array([-0.01163007, 0.11507505, 0.24011663, 0.3597368 , 0.47010452,
0.56773204, 0.64985275, 0.71461999, 0.76108688, 0.78901577,
0.79862201, 0.79034865, 0.76473933, 0.72243309, 0.6642651 ,
0.59141982, 0.5055337 , 0.40864086, 0.3029201 , 0.19036371,
0.07259029, -0.04903597, -0.17303729, -0.2974163 , -0.41942739,
-0.53565294, -0.6423946 , -0.73625678, -0.81470472, -0.87637764,
-0.92107219, -0.94947195, -0.96278471, -0.96241534, -0.94973975,
-0.92597461, -0.89212275, -0.84896779, -0.79709655, -0.73693907,
-0.6688205 , -0.59302258, -0.50985676, -0.41975048, -0.32334602,
-0.2216067 , -0.11591803, -0.00815967, 0.0992815 , 0.20358637,
0.3016789 , 0.39047939, 0.46718192, 0.52948153, 0.57569683,
0.60477555, 0.61621398, 0.60995072, 0.58629096, 0.5458895 ,
0.48978436, 0.41942894, 0.33665282, 0.24351186, 0.14207268,
0.03425908, -0.07812446, -0.19318894, -0.30869773, -0.42190981,
-0.52964294, -0.62857759, -0.71570927, -0.78877628, -0.84650588,
-0.88860476, -0.91555351, -0.92831635, -0.92806906, -0.91600031,
-0.89318985, -0.86055166, -0.81882119, -0.76857066, -0.71024317,
-0.64419866, -0.57077193, -0.49034059, -0.40340501, -0.31067899,
-0.21318346, -0.11233169, -0.0099849 , 0.0915475 , 0.18958548,
0.28123939, 0.36363411, 0.43414801, 0.49060851, 0.53140152])
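output contains num_of_prediction_epochs (100) free-run predictions, i.e. after the seed sequence the model is fed back its own output. Not part of the original notebook, but a simple way to quantify drift, assuming the ground truth just continues the same sine wave:

t = np.arange(len(initial), len(initial) + len(output))
truth = np.sin(2.0 * np.pi * t / param["length_of_sequences"])
rmse = np.sqrt(np.mean((output - truth) ** 2))  # overall error of the free-run prediction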
In [6]:
losses = np.load("losses.npy")
losses
Out[6]:
array([[ 1.00000000e+01, 5.21649063e-01],
[ 2.00000000e+01, 4.98259187e-01],
[ 3.00000000e+01, 5.12061834e-01],
[ 4.00000000e+01, 5.08497775e-01],
[ 5.00000000e+01, 3.76984179e-01],
[ 6.00000000e+01, 2.86902398e-01],
[ 7.00000000e+01, 2.23010182e-01],
[ 8.00000000e+01, 1.33334279e-01],
[ 9.00000000e+01, 1.09339476e-01],
[ 1.00000000e+02, 7.91212395e-02],
[ 1.10000000e+02, 6.07259162e-02],
[ 1.20000000e+02, 4.49932814e-02],
[ 1.30000000e+02, 3.70537601e-02],
[ 1.40000000e+02, 3.04589123e-02],
[ 1.50000000e+02, 2.42561046e-02],
[ 1.60000000e+02, 1.85057689e-02],
[ 1.70000000e+02, 1.69574283e-02],
[ 1.80000000e+02, 1.47507051e-02],
[ 1.90000000e+02, 1.35333249e-02],
[ 2.00000000e+02, 1.27079748e-02],
[ 2.10000000e+02, 1.12369396e-02],
[ 2.20000000e+02, 9.53455735e-03],
[ 2.30000000e+02, 1.04534468e-02],
[ 2.40000000e+02, 1.02619873e-02],
[ 2.50000000e+02, 9.32639278e-03],
[ 2.60000000e+02, 7.47229438e-03],
[ 2.70000000e+02, 9.06517543e-03],
[ 2.80000000e+02, 8.19949247e-03],
[ 2.90000000e+02, 8.28112382e-03],
[ 3.00000000e+02, 7.72635685e-03],
[ 3.10000000e+02, 7.05221435e-03],
[ 3.20000000e+02, 7.11524952e-03],
[ 3.30000000e+02, 7.60497572e-03],
[ 3.40000000e+02, 6.73286011e-03],
[ 3.50000000e+02, 6.27701543e-03],
[ 3.60000000e+02, 5.88546973e-03],
[ 3.70000000e+02, 6.70004915e-03],
[ 3.80000000e+02, 6.44200901e-03],
[ 3.90000000e+02, 5.44016296e-03],
[ 4.00000000e+02, 5.65867312e-03],
[ 4.10000000e+02, 4.56129666e-03],
[ 4.20000000e+02, 4.52047028e-03],
[ 4.30000000e+02, 4.51974012e-03],
[ 4.40000000e+02, 4.70956787e-03],
[ 4.50000000e+02, 4.77563404e-03],
[ 4.60000000e+02, 4.65578865e-03],
[ 4.70000000e+02, 4.77874139e-03],
[ 4.80000000e+02, 4.48078616e-03],
[ 4.90000000e+02, 4.22816677e-03],
[ 5.00000000e+02, 4.62924642e-03],
[ 5.10000000e+02, 3.80018540e-03],
[ 5.20000000e+02, 4.21846611e-03],
[ 5.30000000e+02, 3.91929084e-03],
[ 5.40000000e+02, 3.66806425e-03],
[ 5.50000000e+02, 3.34765459e-03],
[ 5.60000000e+02, 3.80461570e-03],
[ 5.70000000e+02, 3.85011546e-03],
[ 5.80000000e+02, 3.60693643e-03],
[ 5.90000000e+02, 3.15557350e-03],
[ 6.00000000e+02, 3.32210003e-03],
[ 6.10000000e+02, 3.28767416e-03],
[ 6.20000000e+02, 2.88289227e-03],
[ 6.30000000e+02, 2.84908828e-03],
[ 6.40000000e+02, 3.31581617e-03],
[ 6.50000000e+02, 2.80602695e-03],
[ 6.60000000e+02, 2.76946160e-03],
[ 6.70000000e+02, 2.71495990e-03],
[ 6.80000000e+02, 2.52547860e-03],
[ 6.90000000e+02, 2.60824198e-03],
[ 7.00000000e+02, 2.79463362e-03],
[ 7.10000000e+02, 2.58109998e-03],
[ 7.20000000e+02, 2.76109576e-03],
[ 7.30000000e+02, 2.31844047e-03],
[ 7.40000000e+02, 2.73273978e-03],
[ 7.50000000e+02, 2.58772913e-03],
[ 7.60000000e+02, 2.28331704e-03],
[ 7.70000000e+02, 2.69136811e-03],
[ 7.80000000e+02, 2.56716134e-03],
[ 7.90000000e+02, 2.48147198e-03],
[ 8.00000000e+02, 2.31331261e-03],
[ 8.10000000e+02, 2.15493771e-03],
[ 8.20000000e+02, 2.33376143e-03],
[ 8.30000000e+02, 2.01874319e-03],
[ 8.40000000e+02, 2.04768451e-03],
[ 8.50000000e+02, 2.07830686e-03],
[ 8.60000000e+02, 2.19512125e-03],
[ 8.70000000e+02, 2.42997194e-03],
[ 8.80000000e+02, 2.05315789e-03],
[ 8.90000000e+02, 2.27883598e-03],
[ 9.00000000e+02, 2.05022260e-03],
[ 9.10000000e+02, 1.96056394e-03],
[ 9.20000000e+02, 1.85251818e-03],
[ 9.30000000e+02, 1.73176068e-03],
[ 9.40000000e+02, 2.03073025e-03],
[ 9.50000000e+02, 2.23392597e-03],
[ 9.60000000e+02, 2.01467564e-03],
[ 9.70000000e+02, 1.90058933e-03],
[ 9.80000000e+02, 1.91394507e-03],
[ 9.90000000e+02, 1.67136267e-03],
[ 1.00000000e+03, 1.60842645e-03],
[ 1.01000000e+03, 1.46739825e-03],
[ 1.02000000e+03, 1.32955506e-03],
[ 1.03000000e+03, 1.61685015e-03],
[ 1.04000000e+03, 1.65561889e-03],
[ 1.05000000e+03, 1.60828943e-03],
[ 1.06000000e+03, 1.53161574e-03],
[ 1.07000000e+03, 1.59101770e-03],
[ 1.08000000e+03, 1.75807404e-03],
[ 1.09000000e+03, 1.51512772e-03],
[ 1.10000000e+03, 1.30687328e-03],
[ 1.11000000e+03, 1.39617291e-03],
[ 1.12000000e+03, 1.42504636e-03],
[ 1.13000000e+03, 1.33695570e-03],
[ 1.14000000e+03, 1.28335739e-03],
[ 1.15000000e+03, 1.73960568e-03],
[ 1.16000000e+03, 1.31126936e-03],
[ 1.17000000e+03, 1.56872405e-03],
[ 1.18000000e+03, 1.34267274e-03],
[ 1.19000000e+03, 1.39847491e-03],
[ 1.20000000e+03, 1.22323842e-03],
[ 1.21000000e+03, 1.13822997e-03],
[ 1.22000000e+03, 1.42938178e-03],
[ 1.23000000e+03, 1.43208459e-03],
[ 1.24000000e+03, 1.34460453e-03],
[ 1.25000000e+03, 1.88105193e-03],
[ 1.26000000e+03, 1.28378114e-03],
[ 1.27000000e+03, 1.45654904e-03],
[ 1.28000000e+03, 1.29129412e-03],
[ 1.29000000e+03, 1.52279844e-03],
[ 1.30000000e+03, 1.22636242e-03],
[ 1.31000000e+03, 1.21804758e-03],
[ 1.32000000e+03, 1.22829690e-03],
[ 1.33000000e+03, 9.27536166e-04],
[ 1.34000000e+03, 1.07712043e-03],
[ 1.35000000e+03, 1.22907746e-03],
[ 1.36000000e+03, 8.90024356e-04],
[ 1.37000000e+03, 1.03102927e-03],
[ 1.38000000e+03, 1.85414741e-03],
[ 1.39000000e+03, 1.05111743e-03],
[ 1.40000000e+03, 1.10729435e-03],
[ 1.41000000e+03, 1.04203122e-03],
[ 1.42000000e+03, 1.09857519e-03],
[ 1.43000000e+03, 1.39139174e-03],
[ 1.44000000e+03, 1.22202327e-03],
[ 1.45000000e+03, 9.88043612e-04],
[ 1.46000000e+03, 1.04574603e-03],
[ 1.47000000e+03, 9.85451043e-04],
[ 1.48000000e+03, 1.11699896e-03],
[ 1.49000000e+03, 8.67764291e-04],
[ 1.50000000e+03, 1.55183405e-03],
[ 1.51000000e+03, 9.26512352e-04],
[ 1.52000000e+03, 9.01811349e-04],
[ 1.53000000e+03, 8.56888713e-04],
[ 1.54000000e+03, 8.60738743e-04],
[ 1.55000000e+03, 1.33402413e-03],
[ 1.56000000e+03, 8.42967711e-04],
[ 1.57000000e+03, 1.12599775e-03],
[ 1.58000000e+03, 1.07662170e-03],
[ 1.59000000e+03, 8.95051111e-04],
[ 1.60000000e+03, 1.13469642e-03],
[ 1.61000000e+03, 9.23193758e-04],
[ 1.62000000e+03, 4.80679609e-03],
[ 1.63000000e+03, 6.39381399e-03],
[ 1.64000000e+03, 1.28123106e-03],
[ 1.65000000e+03, 8.66188726e-04],
[ 1.66000000e+03, 7.80873874e-04],
[ 1.67000000e+03, 8.33957631e-04],
[ 1.68000000e+03, 8.85847781e-04],
[ 1.69000000e+03, 8.98709462e-04],
[ 1.70000000e+03, 2.03564088e-03],
[ 1.71000000e+03, 1.06045767e-03],
[ 1.72000000e+03, 1.03221158e-03],
[ 1.73000000e+03, 2.90339463e-03],
[ 1.74000000e+03, 1.69916451e-03],
[ 1.75000000e+03, 7.87432713e-04],
[ 1.76000000e+03, 8.65494774e-04],
[ 1.77000000e+03, 7.52690423e-04],
[ 1.78000000e+03, 8.87037604e-04],
[ 1.79000000e+03, 6.85668783e-04],
[ 1.80000000e+03, 1.57764961e-03],
[ 1.81000000e+03, 1.33126695e-03],
[ 1.82000000e+03, 7.51425745e-04],
[ 1.83000000e+03, 1.42862799e-03],
[ 1.84000000e+03, 9.23742482e-04],
[ 1.85000000e+03, 1.11614575e-03],
[ 1.86000000e+03, 6.29607763e-04],
[ 1.87000000e+03, 7.32148415e-04],
[ 1.88000000e+03, 8.04407988e-04],
[ 1.89000000e+03, 7.26758619e-04],
[ 1.90000000e+03, 5.76120103e-04],
[ 1.91000000e+03, 5.07089484e-04],
[ 1.92000000e+03, 6.78267388e-04],
[ 1.93000000e+03, 6.63028506e-04],
[ 1.94000000e+03, 9.89668886e-04],
[ 1.95000000e+03, 7.62208458e-03],
[ 1.96000000e+03, 1.62866653e-03],
[ 1.97000000e+03, 1.11818663e-03],
[ 1.98000000e+03, 9.28766211e-04],
[ 1.99000000e+03, 7.85084092e-04],
[ 2.00000000e+03, 5.86601673e-04],
[ 2.01000000e+03, 5.18271117e-04],
[ 2.02000000e+03, 5.46841242e-04],
[ 2.03000000e+03, 5.96937956e-04],
[ 2.04000000e+03, 6.74773764e-04],
[ 2.05000000e+03, 7.64080731e-04],
[ 2.06000000e+03, 4.85375553e-04],
[ 2.07000000e+03, 5.50887955e-04],
[ 2.08000000e+03, 6.74259150e-04],
[ 2.09000000e+03, 1.61940453e-03],
[ 2.10000000e+03, 9.94797098e-04],
[ 2.11000000e+03, 8.42062058e-04],
[ 2.12000000e+03, 6.65519328e-04],
[ 2.13000000e+03, 1.27558562e-03],
[ 2.14000000e+03, 7.82914460e-04],
[ 2.15000000e+03, 6.54802890e-04],
[ 2.16000000e+03, 4.77460708e-04],
[ 2.17000000e+03, 4.48838546e-04],
[ 2.18000000e+03, 5.09934092e-04],
[ 2.19000000e+03, 1.19736861e-03],
[ 2.20000000e+03, 2.25670612e-03],
[ 2.21000000e+03, 1.21137057e-03],
[ 2.22000000e+03, 1.19532913e-03],
[ 2.23000000e+03, 6.92838104e-04],
[ 2.24000000e+03, 5.73544763e-04],
[ 2.25000000e+03, 5.02661103e-04],
[ 2.26000000e+03, 8.36043851e-04],
[ 2.27000000e+03, 3.61822662e-03],
[ 2.28000000e+03, 5.91823366e-04],
[ 2.29000000e+03, 5.25379321e-03],
[ 2.30000000e+03, 9.56484582e-04],
[ 2.31000000e+03, 4.29267355e-04],
[ 2.32000000e+03, 4.86311852e-04],
[ 2.33000000e+03, 1.08714704e-03],
[ 2.34000000e+03, 7.63970485e-04],
[ 2.35000000e+03, 4.33687819e-04],
[ 2.36000000e+03, 6.46872446e-04],
[ 2.37000000e+03, 5.11613558e-04],
[ 2.38000000e+03, 4.60672076e-04],
[ 2.39000000e+03, 9.29278263e-04],
[ 2.40000000e+03, 1.03490287e-03],
[ 2.41000000e+03, 2.36895331e-03],
[ 2.42000000e+03, 9.99339274e-04],
[ 2.43000000e+03, 4.04851598e-04],
[ 2.44000000e+03, 4.45433194e-04],
[ 2.45000000e+03, 4.33330453e-04],
[ 2.46000000e+03, 4.73813154e-04],
[ 2.47000000e+03, 3.66322114e-04],
[ 2.48000000e+03, 5.46103867e-04],
[ 2.49000000e+03, 4.20714496e-04],
[ 2.50000000e+03, 3.57265089e-04],
[ 2.51000000e+03, 8.84733628e-04],
[ 2.52000000e+03, 1.16093038e-03],
[ 2.53000000e+03, 1.10826846e-02],
[ 2.54000000e+03, 5.52506244e-04],
[ 2.55000000e+03, 3.56514647e-04],
[ 2.56000000e+03, 3.97591357e-04],
[ 2.57000000e+03, 3.52032192e-04],
[ 2.58000000e+03, 5.36011066e-04],
[ 2.59000000e+03, 3.50622169e-04],
[ 2.60000000e+03, 4.20680386e-04],
[ 2.61000000e+03, 6.49139867e-04],
[ 2.62000000e+03, 7.48177641e-04],
[ 2.63000000e+03, 7.72942847e-04],
[ 2.64000000e+03, 1.59547431e-03],
[ 2.65000000e+03, 1.42585742e-03],
[ 2.66000000e+03, 1.62872893e-03],
[ 2.67000000e+03, 5.75655838e-04],
[ 2.68000000e+03, 3.59694648e-04],
[ 2.69000000e+03, 3.37270903e-04],
[ 2.70000000e+03, 3.90696892e-04],
[ 2.71000000e+03, 3.95083043e-04],
[ 2.72000000e+03, 3.42289888e-04],
[ 2.73000000e+03, 4.12322755e-04],
[ 2.74000000e+03, 6.39108999e-04],
[ 2.75000000e+03, 3.79647157e-04],
[ 2.76000000e+03, 9.35855205e-04],
[ 2.77000000e+03, 8.24473507e-04],
[ 2.78000000e+03, 3.76925513e-04],
[ 2.79000000e+03, 4.23456106e-04],
[ 2.80000000e+03, 6.08165166e-04],
[ 2.81000000e+03, 2.65013136e-04],
[ 2.82000000e+03, 2.94867670e-04],
[ 2.83000000e+03, 3.18241422e-04],
[ 2.84000000e+03, 6.09116803e-04],
[ 2.85000000e+03, 1.00229483e-03],
[ 2.86000000e+03, 8.60595726e-04],
[ 2.87000000e+03, 1.75239798e-03],
[ 2.88000000e+03, 1.75377261e-03],
[ 2.89000000e+03, 2.74792721e-04],
[ 2.90000000e+03, 3.66954220e-04],
[ 2.91000000e+03, 3.54643620e-04],
[ 2.92000000e+03, 3.29624861e-04],
[ 2.93000000e+03, 3.26202222e-04],
[ 2.94000000e+03, 3.35671677e-04],
[ 2.95000000e+03, 2.80442648e-04],
[ 2.96000000e+03, 2.74985970e-04],
[ 2.97000000e+03, 3.10174742e-04],
[ 2.98000000e+03, 4.30501823e-04],
[ 2.99000000e+03, 8.73467303e-04],
[ 3.00000000e+03, 1.00922876e-03]])
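The loss is logged every 10 epochs up to num_of_training_epochs (3000): column 0 is the epoch, column 1 the training loss. A couple of summary figures that are easy to pull out before plotting (not part of the original notebook):

final_loss = losses[-1, 1]                          # loss at epoch 3000
best_loss = losses[:, 1].min()                      # lowest logged loss
best_epoch = int(losses[losses[:, 1].argmin(), 0])  # epoch at which it occurred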
In [7]:
# Ground truth: the first 150 samples of the training wave (seed + prediction window).
train_df = pd.DataFrame(train[:len(initial) + len(output), 0], columns=["train"])
# Seed sequence fed to the trained RNN (samples 0-49).
initial_df = pd.DataFrame(initial, columns=["initial"])
# Free-run predictions, indexed so they continue where the seed ends (samples 50-149).
output_df = pd.DataFrame(output, columns=["output"], index=range(len(initial), len(initial) + len(output)))
merged = pd.concat([train_df, initial_df, output_df])
merged.plot(figsize=(15, 5), grid=True, style=["-", "-", "k--"])
Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x106674940>
[inline figure not preserved in this export: line plot of the train, initial, and output series over sample index]
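Concatenating along rows (the default axis=0) works here because plot() uses the index as the x-axis: the non-shared columns are filled with NaN and each series still spans its own index range. An equivalent, arguably clearer construction would merge column-wise (a hypothetical alternative, not the notebook's code):

merged = pd.concat([train_df, initial_df, output_df], axis=1)
merged.plot(figsize=(15, 5), grid=True, style=["-", "-", "k--"])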
In [8]:
# Training loss (logged every 10 epochs), plotted against epoch on a logarithmic y-axis.
losses_df = pd.DataFrame(losses, columns=["epoch", "loss"])
losses_df.plot(figsize=(15, 5), grid=True, logy=True, x="epoch")
Out[8]:
<matplotlib.axes._subplots.AxesSubplot at 0x106648860>
[inline figure not preserved in this export: training loss vs. epoch on a logarithmic y-axis]
In [ ]:
Content source: nayutaya/tensorflow-rnn-sin