In [1]:
# Required imports.
import sys
sys.path.append('../physlearn/')
sys.path.append('../source')
import numpy as np
from numpy import linalg as LA
import tensorflow as tf
from matplotlib import pylab as plt
import numpy.random as rand
from physlearn.NeuralNet.NeuralNet import NeuralNet
from physlearn.Optimizer.NelderMead.NelderMead import NelderMead
import d1_osc
import ann_constructor
import math_util
from lagarisTF import LagarisSolverTF
from visualiser import Visualiser

# --- Problem setup: spatial grid, trial-wavefunction network, cost ---
n_sig = 7   # number of sigmoid units in the trial wavefunction network
a = -5      # left boundary of the spatial interval
b = 5       # right boundary of the spatial interval
m = 200     # number of grid points
# Uniform training grid on [a, b], shaped (1, m) to match the network input.
train_x = np.linspace(a, b, m, endpoint = True).reshape(1, m) 

# Lagaris-style solver: builds the TF graph for the trial wavefunction psi.
lagar = LagarisSolverTF()
lagar.define_psi(n_sig)
net_x = lagar.get_net_x()  # TF tensor for the spatial input of the network
dim = lagar.get_dim()      # number of trainable network parameters
sess = lagar.get_sess()    # TF session owned by the solver

# Harmonic-oscillator potential V(x) = x^2 (in the solver's units).
potential = tf.square(net_x)
lagar.define_H_psi(potential)  # builds H*psi from the potential (presumably -psi'' + V*psi — confirm in LagarisSolverTF)
lagar.define_cost(train_x)

J = lagar.get_J()  # solver's cost as a plain callable of the flat parameter vector

# Nelder-Mead optimizer; initial simplex coordinates drawn from [-2.5, 2.5].
opt_nm = NelderMead(-2.5,2.5)
opt_nm.set_epsilon_and_sd(0.3, 100)

def opt(J, dim, n_it, eps):
    """Minimise cost `J` with the module-level Nelder-Mead optimizer.

    `dim + 1` parameters are optimised (network weights plus one extra,
    the trailing beta); `n_it` caps the iteration count and `eps` is the
    target cost at which the search stops.
    """
    return opt_nm.optimize(J, dim + 1, n_it, eps)

%matplotlib inline


D:\Anaconda\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

In [2]:
# Inspect the last (and only) row of the training grid.
train_x[-1]


Out[2]:
array([-5.        , -4.94974874, -4.89949749, -4.84924623, -4.79899497,
       -4.74874372, -4.69849246, -4.64824121, -4.59798995, -4.54773869,
       -4.49748744, -4.44723618, -4.39698492, -4.34673367, -4.29648241,
       -4.24623116, -4.1959799 , -4.14572864, -4.09547739, -4.04522613,
       -3.99497487, -3.94472362, -3.89447236, -3.84422111, -3.79396985,
       -3.74371859, -3.69346734, -3.64321608, -3.59296482, -3.54271357,
       -3.49246231, -3.44221106, -3.3919598 , -3.34170854, -3.29145729,
       -3.24120603, -3.19095477, -3.14070352, -3.09045226, -3.04020101,
       -2.98994975, -2.93969849, -2.88944724, -2.83919598, -2.78894472,
       -2.73869347, -2.68844221, -2.63819095, -2.5879397 , -2.53768844,
       -2.48743719, -2.43718593, -2.38693467, -2.33668342, -2.28643216,
       -2.2361809 , -2.18592965, -2.13567839, -2.08542714, -2.03517588,
       -1.98492462, -1.93467337, -1.88442211, -1.83417085, -1.7839196 ,
       -1.73366834, -1.68341709, -1.63316583, -1.58291457, -1.53266332,
       -1.48241206, -1.4321608 , -1.38190955, -1.33165829, -1.28140704,
       -1.23115578, -1.18090452, -1.13065327, -1.08040201, -1.03015075,
       -0.9798995 , -0.92964824, -0.87939698, -0.82914573, -0.77889447,
       -0.72864322, -0.67839196, -0.6281407 , -0.57788945, -0.52763819,
       -0.47738693, -0.42713568, -0.37688442, -0.32663317, -0.27638191,
       -0.22613065, -0.1758794 , -0.12562814, -0.07537688, -0.02512563,
        0.02512563,  0.07537688,  0.12562814,  0.1758794 ,  0.22613065,
        0.27638191,  0.32663317,  0.37688442,  0.42713568,  0.47738693,
        0.52763819,  0.57788945,  0.6281407 ,  0.67839196,  0.72864322,
        0.77889447,  0.82914573,  0.87939698,  0.92964824,  0.9798995 ,
        1.03015075,  1.08040201,  1.13065327,  1.18090452,  1.23115578,
        1.28140704,  1.33165829,  1.38190955,  1.4321608 ,  1.48241206,
        1.53266332,  1.58291457,  1.63316583,  1.68341709,  1.73366834,
        1.7839196 ,  1.83417085,  1.88442211,  1.93467337,  1.98492462,
        2.03517588,  2.08542714,  2.13567839,  2.18592965,  2.2361809 ,
        2.28643216,  2.33668342,  2.38693467,  2.43718593,  2.48743719,
        2.53768844,  2.5879397 ,  2.63819095,  2.68844221,  2.73869347,
        2.78894472,  2.83919598,  2.88944724,  2.93969849,  2.98994975,
        3.04020101,  3.09045226,  3.14070352,  3.19095477,  3.24120603,
        3.29145729,  3.34170854,  3.3919598 ,  3.44221106,  3.49246231,
        3.54271357,  3.59296482,  3.64321608,  3.69346734,  3.74371859,
        3.79396985,  3.84422111,  3.89447236,  3.94472362,  3.99497487,
        4.04522613,  4.09547739,  4.14572864,  4.1959799 ,  4.24623116,
        4.29648241,  4.34673367,  4.39698492,  4.44723618,  4.49748744,
        4.54773869,  4.59798995,  4.64824121,  4.69849246,  4.74874372,
        4.79899497,  4.84924623,  4.89949749,  4.94974874,  5.        ])

In [3]:
# --- Hand-built energy cost: Rayleigh quotient + normalisation penalty ---
psi = lagar.get_psi()
H_psi = lagar.get_H_psi()
psi_psi = tf.multiply(psi,psi)      # pointwise psi^2
psi_H_psi = tf.multiply(H_psi,psi)  # pointwise psi * (H psi)
# Riemann-sum approximations of the integrals over [a, b], step (b-a)/m.
norm = tf.reduce_sum(psi_psi)*(b-a)/m
mean_E = tf.reduce_sum(psi_H_psi)*(b-a)/m
eps = mean_E/norm  # Rayleigh quotient: current energy estimate

# Residual of the eigenvalue equation H*psi = eps*psi (normalised), plus a
# penalty holding the wavefunction norm at 1.
cost = tf.reduce_sum(tf.square(H_psi-eps*psi))/norm + tf.square(norm-1)

def run(expr, x, beta):
    """Evaluate TF expression `expr` on grid `x` with trial parameter `beta`."""
    net = lagar.get_net()
    feed = {net.x: x, lagar.beta: beta}
    return net.calc(expr, feed)

In [4]:
# Arbitrary test parameter vector: dim network weights + 1 trailing beta.
aa = np.linspace(a, b, dim+1, endpoint = True)
print(J(aa))

def J_my(params):
    """Evaluate the hand-built `cost` tensor for a flat parameter vector.

    The last element of `params` is the energy parameter beta; the rest
    are the unrolled network weights.
    """
    weights, beta_l = params[0:-1], params[-1]
    lagar.get_net().roll_matrixes(weights)
    return run(cost, train_x, beta_l)

def J_my2(params):
    """Evaluate the solver's own cost tensor for a flat parameter vector.

    Same parameter layout as `J_my`: trailing element is beta, the rest
    are the unrolled network weights.
    """
    weights, beta_l = params[0:-1], params[-1]
    lagar.get_net().roll_matrixes(weights)
    return run(lagar.get_cost(), train_x, beta_l)

# Both hand-built costs should agree with the solver's J at the same point.
print(J_my(aa))
print(J_my2(aa))


12789.859961147637
12789.859961147637
12789.859961147637

In [5]:
#optimisation_result = opt(J_quad, dim, int(35), 1e-6)
# Minimise the solver's cost with Nelder-Mead: up to 4000 iterations,
# stopping early once the cost drops below 1e-6.
optimisation_result = opt(J, dim, int(4e3), 1e-6)
print("J after optimisation: ", J(optimisation_result.x))
print("Информация: ", optimisation_result)


...  2519 (62%) 1179.147 it\s
J after optimisation:  9.829187944957717e-07
Информация:  Is converge: True
Amount of iterations: 2519
Total time: 2.14 s
Reached function value: 9.829187944957717e-07
Reason of break: Minimum cost reached


In [6]:
# Sanity checks on the hand-built graph at beta = 0.5:
# psi_psi must equal psi^2 exactly (all-zeros difference),
# and norm / mean_E / eps / cost must all evaluate to scalars.
#print(run(psi,train_x,0.5))
print(run(psi_psi-tf.square(psi),train_x,0.5))
print(run(psi_psi-tf.square(psi),train_x,0.5).shape)
print(run(norm,train_x,0.5))
print(run(norm,train_x,0.5).shape)

print()

# Energy expectation value (Riemann sum of psi * H psi).
print(run(mean_E,train_x,0.5))
print(run(mean_E,train_x,0.5).shape)

print()

# Rayleigh quotient (energy estimate).
print(run(eps,train_x,0.5))
print(run(eps,train_x,0.5).shape)

print()

# Full cost (residual + normalisation penalty).
print(run(cost,train_x,0.5))
print(run(cost,train_x,0.5).shape)


[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0.]]
(1, 200)
1.0009292714986178
()

1.0009311265443734
()

1.0000018533235149
()

0.0001495983319894637
()

In [7]:
# |psi| on the grid at the optimised trailing parameter
# (optimisation_result.x[-1] — presumably beta; confirm calc_psi's signature).
psi_list = np.abs(lagar.calc_psi(train_x, optimisation_result.x[-1]))

%matplotlib inline
H_psi_list = lagar.calc_H_psi(train_x, optimisation_result.x[-1])
plt.title('Difference between trial function and real:')
plt.plot(train_x[0], psi_list)
#plt.plot(train_xi[0], image_array)
#plt.plot(train_xi[0], d1_osc.wf(0, train_xi))
Visualiser.show_wf(0, train_x)  # overlay the analytic n=0 wavefunction
# Dashed green curve: deviation of the trained psi from the analytic one.
plt.plot(train_x[0], psi_list - d1_osc.wf(0, train_x), 'g--')


Out[7]:
[<matplotlib.lines.Line2D at 0x23727285e80>]

In [8]:
# |H psi| on the grid at the optimised trailing parameter (overwrites the
# signed version computed in the plotting cell above).
H_psi_list = np.abs(lagar.calc_H_psi(train_x, optimisation_result.x[-1]))

In [9]:
# Re-fetch the psi tensor (rebinds the name set earlier; effectively a no-op
# if get_psi() returns the same tensor — confirm in LagarisSolverTF).
psi = lagar.get_psi()

In [10]:
# Display |H psi| over the grid; should mirror |psi| up to the factor eps.
H_psi_list


Out[10]:
array([2.58021905e-06, 3.32221792e-06, 4.26657057e-06, 5.46523964e-06,
       6.98263152e-06, 8.89833714e-06, 1.13104161e-05, 1.43393175e-05,
       1.81325436e-05, 2.28701760e-05, 2.87713994e-05, 3.61021732e-05,
       4.51842169e-05, 5.64054942e-05, 7.02323933e-05, 8.72238224e-05,
       1.08047451e-04, 1.33498346e-04, 1.64520251e-04, 2.02229797e-04,
       2.47943891e-04, 3.03210575e-04, 3.69843614e-04, 4.49961067e-04,
       5.46028072e-04, 6.60904045e-04, 7.97894451e-04, 9.60807237e-04,
       1.15401397e-03, 1.38251561e-03, 1.65201274e-03, 1.96898004e-03,
       2.34074450e-03, 2.77556685e-03, 3.28272546e-03, 3.87260176e-03,
       4.55676606e-03, 5.34806240e-03, 6.26069088e-03, 7.31028564e-03,
       8.51398644e-03, 9.89050167e-03, 1.14601602e-02, 1.32449493e-02,
       1.52685365e-02, 1.75562712e-02, 2.01351638e-02, 2.30338399e-02,
       2.62824643e-02, 2.99126349e-02, 3.39572406e-02, 3.84502833e-02,
       4.34266602e-02, 4.89219052e-02, 5.49718882e-02, 6.16124722e-02,
       6.88791276e-02, 7.68065050e-02, 8.54279690e-02, 9.47750954e-02,
       1.04877135e-01, 1.15760451e-01, 1.27447933e-01, 1.39958399e-01,
       1.53305987e-01, 1.67499552e-01, 1.82542074e-01, 1.98430088e-01,
       2.15153151e-01, 2.32693347e-01, 2.51024856e-01, 2.70113587e-01,
       2.89916892e-01, 3.10383368e-01, 3.31452760e-01, 3.53055970e-01,
       3.75115184e-01, 3.97544122e-01, 4.20248404e-01, 4.43126051e-01,
       4.66068112e-01, 4.88959410e-01, 5.11679409e-01, 5.34103194e-01,
       5.56102551e-01, 5.77547134e-01, 5.98305716e-01, 6.18247488e-01,
       6.37243412e-01, 6.55167587e-01, 6.71898627e-01, 6.87321012e-01,
       7.01326410e-01, 7.13814936e-01, 7.24696326e-01, 7.33891025e-01,
       7.41331145e-01, 7.46961297e-01, 7.50739280e-01, 7.52636600e-01,
       7.52638833e-01, 7.50745808e-01, 7.46971612e-01, 7.41344420e-01,
       7.33906151e-01, 7.24711954e-01, 7.13829528e-01, 7.01338307e-01,
       6.87328492e-01, 6.71899981e-01, 6.55161185e-01, 6.37227769e-01,
       6.18221329e-01, 5.98268027e-01, 5.77497211e-01, 5.56040031e-01,
       5.34028082e-01, 5.11592085e-01, 4.88860627e-01, 4.65958984e-01,
       4.43008024e-01, 4.20123217e-01, 3.97413767e-01, 3.74981845e-01,
       3.52921968e-01, 3.31320494e-01, 3.10255246e-01, 2.89795273e-01,
       2.70000721e-01, 2.50922830e-01, 2.32604040e-01, 2.15078189e-01,
       1.98370815e-01, 1.82499530e-01, 1.67474457e-01, 1.53298739e-01,
       1.39969076e-01, 1.27476306e-01, 1.15806003e-01, 1.04939085e-01,
       9.48524320e-02, 8.55194805e-02, 7.69108171e-02, 6.89947412e-02,
       6.17378007e-02, 5.51052941e-02, 4.90617352e-02, 4.35712777e-02,
       3.85980971e-02, 3.41067304e-02, 3.00623732e-02, 2.64311333e-02,
       2.31802450e-02, 2.02782431e-02, 1.76951005e-02, 1.54023319e-02,
       1.33730655e-02, 1.15820868e-02, 1.00058573e-02, 8.62250998e-03,
       7.41182716e-03, 6.35520087e-03, 5.43558051e-03, 4.63740956e-03,
       3.94655396e-03, 3.35022440e-03, 2.83689447e-03, 2.39621653e-03,
       2.01893657e-03, 1.69680979e-03, 1.42251756e-03, 1.18958698e-03,
       9.92313631e-04, 8.25688006e-04, 6.85326188e-04, 5.67404879e-04,
       4.68601025e-04, 3.86036051e-04, 3.17224687e-04, 2.60028280e-04,
       2.12612444e-04, 1.73408849e-04, 1.41080922e-04, 1.14493216e-04,
       9.26841785e-05, 7.48420708e-05, 6.02837509e-05, 4.84360760e-05,
       3.88196645e-05, 3.10347792e-05, 2.47491056e-05, 1.96872123e-05,
       1.56214990e-05, 1.23644510e-05, 9.76203813e-06, 7.68810966e-06,
       6.03965197e-06, 4.73279174e-06, 3.69943897e-06, 2.88447769e-06])