In [1]:
from physlearn.NeuralNet.NeuralNet import NeuralNet
import numpy
import tensorflow as tf
from math import pi
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Training set: 30 evenly spaced abscissas over one full period [0, 2*pi],
# kept as a single-row matrix of shape (1, 30), and the matching sine values.
x_train = numpy.linspace(0, 2 * pi, 30)[numpy.newaxis, :]
y_train = numpy.sin(x_train)

In [3]:
# Dense evaluation grid: 1000 evenly spaced abscissas over the same interval
# [0, 2*pi], shape (1, 1000), with the exact sine values as reference targets.
x_cv = numpy.linspace(0, 2 * pi, 1000)[numpy.newaxis, :]
y_cv = numpy.sin(x_cv)

In [4]:
# Show the 30 training samples as red crosses. Axis labels and a title are set
# first so the figure is readable on its own; plt.plot stays the last
# expression so the cell's displayed value is unchanged.
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title('Training samples')
plt.plot(x_train[0], y_train[0], 'x', color='red')


Out[4]:
[<matplotlib.lines.Line2D at 0x7fadb7495160>]

In [5]:
# Create the network object; its layers are added in the following cell.
# NOTE(review): the two positional arguments (-1, 1) presumably bound the range
# used for initial weights — confirm against NeuralNet.__init__.
# NOTE(review): no random seed is set anywhere in the visible cells, so the
# recorded outputs below may not reproduce exactly — confirm how NeuralNet
# initializes its parameters.
net = NeuralNet(-1, 1)

In [6]:
# Architecture: an input layer of size 1, one layer added with size 10 and
# tf.sigmoid, and an output layer of size 1. The parameter dump later in the
# notebook shows the matching array shapes (10x1, 10x1, 1x10, 1x1).
net.add_input_layer(1)
net.add(10, tf.sigmoid)
net.add_output_layer(1)

In [7]:
# Finalize the network once all layers are declared, before any train/run call.
# NOTE(review): presumably this builds the underlying TensorFlow graph — confirm.
net.compile()

In [8]:
# Train on the 30 samples; the progress bar reports 3000 iterations, matching
# the fifth argument. NOTE(review): the meaning of 'prediction', 5 and 0.1 is
# not visible here (possibly mode, batch/population size and learning rate) —
# confirm against NeuralNet.train.
cost_list = net.train('prediction', x_train, y_train, 5, 3000, 0.1)
# Plot the reciprocal of every recorded cost, so a falling cost shows up as a
# rising curve. Labels and title make the figure readable on its own; plt.plot
# stays the last expression so the cell's displayed value is unchanged.
plt.xlabel('index in cost_list')
plt.ylabel('1 / cost')
plt.title('Inverse training cost')
plt.plot([cost ** (-1) for cost in cost_list])


100%|██████████| 3000/3000 [00:01<00:00, 1959.27it/s]
Out[8]:
[<matplotlib.lines.Line2D at 0x7fadb3c03320>]

In [9]:
# Cost of the trained network evaluated on the training samples themselves
# (not on the dense x_cv grid).
net.calculate_cost(x_train, y_train)


Out[9]:
0.06241226434886385

In [10]:
# Evaluate the trained network on the dense 1000-point grid and overlay its
# output on the training samples, to see how it behaves between them.
y_pred = net.run(x_cv)
# Labels and title are set before plotting so the figure stands alone while
# the final plt.plot call remains the cell's displayed value.
plt.xlabel('x')
plt.ylabel('y')
plt.title('Training samples (red x) vs. network output (line)')
plt.plot(x_train[0], y_train[0], 'x', color='red')
plt.plot(x_cv[0], y_pred[0])


Out[10]:
[<matplotlib.lines.Line2D at 0x7fadb1a3c518>]

In [11]:
# Export the network's parameters. The result is indexable: element 0 is
# displayed in the next cell and element 2 is used afterwards.
# NOTE(review): the full layout of the returned sequence is not visible here —
# confirm against NeuralNet.unroll_matrixes.
res = net.unroll_matrixes()

In [12]:
# Display element 0: a list with one tuple of two arrays per layer (shapes
# 10x1 and 10x1, then 1x10 and 1x1 in the recorded output).
# NOTE(review): each pair is presumably (weights, biases) — confirm.
res[0]


Out[12]:
[(array([[ 0.86298793],
         [ 1.56153218],
         [ 1.46124137],
         [ 1.26388663],
         [-0.79724256],
         [ 2.05805715],
         [ 0.70013601],
         [ 0.86209423],
         [ 1.23698465],
         [-0.26954899]]), array([[ 0.45230029],
         [ 0.0773675 ],
         [ 0.06125173],
         [-2.54524229],
         [-0.76043917],
         [-1.9324716 ],
         [ 0.20241035],
         [-0.22628199],
         [-2.38871301],
         [-0.09356024]])),
 (array([[-0.8177541 ,  1.35590382,  0.87282782, -2.21011461,  0.6381469 ,
           2.70699917, -0.17095081, -0.36841195, -1.7560022 , -0.84169342]]),
  array([[-0.16331892]]))]

In [13]:
# NOTE(review): index 2 presumably selects the flattened parameter vector;
# res[1] is never inspected in this notebook — confirm the return layout of
# NeuralNet.unroll_matrixes.
unroll_vector = res[2]

In [14]:
# Rebuild the per-layer arrays from the vector. The recorded output is
# identical to the earlier res[0] display, so the unroll/roll round trip
# reproduced the same parameters in this run.
net.roll_matrixes(unroll_vector)


Out[14]:
[(array([[ 0.86298793],
         [ 1.56153218],
         [ 1.46124137],
         [ 1.26388663],
         [-0.79724256],
         [ 2.05805715],
         [ 0.70013601],
         [ 0.86209423],
         [ 1.23698465],
         [-0.26954899]]), array([[ 0.45230029],
         [ 0.0773675 ],
         [ 0.06125173],
         [-2.54524229],
         [-0.76043917],
         [-1.9324716 ],
         [ 0.20241035],
         [-0.22628199],
         [-2.38871301],
         [-0.09356024]])),
 (array([[-0.8177541 ,  1.35590382,  0.87282782, -2.21011461,  0.6381469 ,
           2.70699917, -0.17095081, -0.36841195, -1.7560022 , -0.84169342]]),
  array([[-0.16331892]]))]