In [ ]:
%matplotlib inline
'''
grad_descent.py
Use gradient descent to find the minimum value of a
single-variable function. This also checks that the equation
f'(x)=0 has a solution and plots the intermediate
points traversed.
'''
from sympy import Derivative, Symbol, sympify, solve
import matplotlib.pyplot as plt
def grad_descent(x0, f1x, x, epsilon=1e-6, step_size=1e-4,
                 max_iterations=None):
    """Run fixed-step gradient descent on a single-variable function.

    Args:
        x0: Starting value of the variable.
        f1x: SymPy expression for the first derivative of the function.
        x: The SymPy symbol that ``f1x`` is expressed in.
        epsilon: Iteration stops once two successive points differ by no
            more than this amount.
        step_size: Multiplier applied to the derivative at each step.
        max_iterations: Optional cap on the number of update steps. The
            default ``None`` means no cap, so the loop only ends when the
            ``epsilon`` criterion is met.

    Returns:
        ``None`` when ``solve(f1x)`` finds no solution of ``f1x = 0``
        (a message is printed in that case). Otherwise a tuple
        ``(x_new, X_traversed)`` holding the last point computed and the
        list of intermediate points visited before it.
    """
    # check if f1x=0 has a solution
    if not solve(f1x):
        print('Cannot continue, solution for {0}=0 does not exist'.format(f1x))
        return None

    def step(point):
        # One descent update: move against the derivative at ``point``.
        return point - step_size*f1x.subs({x: point}).evalf()

    x_old = x0
    x_new = step(x_old)
    # list to store the X values traversed
    X_traversed = []
    iterations = 0
    while abs(x_old - x_new) > epsilon:
        # Without a cap, a diverging iteration would never terminate.
        if max_iterations is not None and iterations >= max_iterations:
            break
        X_traversed.append(x_new)
        x_old = x_new
        x_new = step(x_old)
        iterations += 1
    return x_new, X_traversed
def frange(start, final, interval):
    """Return evenly spaced numbers from ``start`` up to, excluding, ``final``.

    Each value is computed as ``start + i * interval`` rather than by
    repeated addition, so rounding error does not build up and produce
    a spurious extra element just below ``final``.

    Args:
        start: First value of the sequence.
        final: Exclusive upper bound.
        interval: Positive spacing between consecutive values.

    Returns:
        A list of the generated numbers; empty when ``start >= final``.

    Raises:
        ValueError: If ``interval`` is not positive (the sequence would
            never reach ``final``).
    """
    if interval <= 0:
        raise ValueError('interval must be positive')
    numbers = []
    index = 0
    value = start
    while value < final:
        numbers.append(value)
        index += 1
        # Recompute from the origin instead of accumulating float error.
        value = start + index * interval
    return numbers
def create_plot(X_traversed, f, var):
    """Plot the function together with the points visited during descent.

    Args:
        X_traversed: Values of the variable visited by the descent.
        f: SymPy expression of the function being minimised.
        var: The SymPy symbol that ``f`` is expressed in.
    """
    def evaluate(point):
        # Value of f with ``var`` replaced by ``point``.
        return f.subs({var: point})

    # Graph of the function itself, sampled on [-1, 1) in steps of 0.01.
    grid = frange(-1, 1, 0.01)
    curve = list(map(evaluate, grid))
    plt.plot(grid, curve, 'bo')

    # Function values at each of the intermediate points traversed.
    visited = list(map(evaluate, X_traversed))
    plt.plot(X_traversed, visited, 'r.')

    plt.legend(['Function', 'Intermediate points'], loc='best')
    plt.show()
if __name__ == '__main__':
    # SympifyError is not among the names imported at the top of the file;
    # without it the ``except`` clause below would raise NameError.
    from sympy.core.sympify import SympifyError

    f = input('Enter a function in one variable: ')
    var = input('Enter the variable to differentiate with respect to: ')
    var0 = float(input('Enter the initial value of the variable: '))
    try:
        f = sympify(f)
    except SympifyError:
        print('Invalid function entered')
    else:
        var = Symbol(var)
        d = Derivative(f, var).doit()
        # grad_descent returns None (not a tuple) when f'(x)=0 has no
        # solution, so check the result before unpacking it.
        result = grad_descent(var0, d, var)
        if result is not None:
            var_min, X_traversed = result
            print('{0}: {1}'.format(var.name, var_min))
            print('Minimum value: {0}'.format(f.subs({var:var_min})))
            create_plot(X_traversed, f, var)