So the convergence of gradient descent depends on the starting point $\large x_0$ and the learning rate $\large \eta$.
In [ ]:
from time import sleep
import numpy as np
from ipywidgets import *
import bqplot.pyplot as plt
from bqplot import Toolbar
In [ ]:
def f(x):
    """Objective function: exp(-x) * sin(5x). Works on scalars and arrays."""
    return np.exp(-x) * np.sin(5 * x)


def df(x):
    """Derivative of ``f``: exp(-x) * (5 cos(5x) - sin(5x)).

    Obtained by the product rule:
    d/dx[e^-x sin 5x] = -e^-x sin 5x + 5 e^-x cos 5x.
    """
    return np.exp(-x) * (5 * np.cos(5 * x) - np.sin(5 * x))
In [ ]:
# Sample the objective densely on [0.5, 2.5] to draw the curve.
x = np.linspace(start=0.5, stop=2.5, num=500)
y = f(x)
In [ ]:
def update_sol_path(x, y):
    """Show the iterates (x, y) as both the path line and the point markers.

    Each mark is updated inside ``hold_sync`` so that its x and y values
    are sent to the front end together.
    """
    for mark in (sol_path, sol_points):
        with mark.hold_sync():
            mark.x = x
            mark.y = y
In [ ]:
def gradient_descent(x0, f, df, eta=.1, tol=1e-6, num_iters=10, delay=.5):
    """Run 1-D gradient descent from ``x0``, animating each step.

    After every update the current iterate is written to ``sol_lbl`` and
    the whole path so far is pushed to the plot via ``update_sol_path``.

    Parameters
    ----------
    x0 : float
        Starting point.
    f : callable
        Objective function; evaluated on each iterate for plotting.
    df : callable
        Derivative of ``f``.
    eta : float
        Learning rate (step size).
    tol : float
        Stop early once two consecutive iterates differ by less than this.
    num_iters : int
        Maximum number of update steps.
    delay : float
        Seconds to pause after each step so the animation is visible.

    Returns
    -------
    list of float
        Every iterate, starting with ``x0``.
    """
    path = [x0]
    for _ in range(num_iters):
        x_prev = path[-1]
        x_curr = x_prev - eta * df(x_prev)
        path.append(x_curr)
        sol_lbl.value = sol_lbl_tmpl.format(x_curr)
        sleep(delay)
        update_sol_path(path, [f(p) for p in path])
        if np.abs(x_curr - x_prev) < tol:
            break
    return path
In [ ]:
# Input controls: starting point, learning rate, and the action buttons.
txt_layout = Layout(width='150px')
x0_box = FloatText(value=2.4, description='x0', layout=txt_layout)
eta_box = FloatText(
    value=.1,
    description='Learning Rate',
    style={'description_width': 'initial'},
    layout=txt_layout,
)
go_btn = Button(
    description='GO', button_style='success', layout=Layout(width='50px')
)
reset_btn = Button(
    description='Reset', button_style='success', layout=Layout(width='100px')
)

# Label showing the latest iterate, formatted with sol_lbl_tmpl.
sol_lbl_tmpl = 'x = {:.4f}'
sol_lbl = Label()

# Figure with the objective curve, the (initially empty) descent path,
# and a marker at each iterate. The toolbar is kept on the figure object
# so it can be laid out under the figure later.
fig_layout = Layout(width='720px', height='500px')
fig = plt.figure(layout=fig_layout, title='Gradient Descent', display_toolbar=True)
fig.pyplot = Toolbar(figure=fig)
curve = plt.plot(x, y, colors=['dodgerblue'], stroke_width=2)
sol_path = plt.plot([], [], colors=['#ccc'], opacities=[.7])
sol_points = plt.plot([], [], 'mo', default_size=20)
def optimize():
    """Run an animated gradient descent using the current widget values.

    Reads the starting point from ``x0_box`` and the learning rate from
    ``eta_box``. The figure's marks are deliberately left alone: the curve,
    the path and the points must all stay on the figure.
    """
    gradient_descent(x0_box.value, f, df, eta=eta_box.value)
def reset():
    """Restore the default axis ranges and clear the solution path and label."""
    # Undo any pan/zoom done with the toolbar; batch each scale's update.
    x_scale = curve.scales['x']
    y_scale = curve.scales['y']
    with x_scale.hold_sync():
        x_scale.min = .4
        x_scale.max = 2.5
    with y_scale.hold_sync():
        y_scale.min = -.5
        y_scale.max = .4
    # Clear path and points through the same helper used while animating,
    # so x and y are always sent together.
    update_sol_path([], [])
    sol_lbl.value = ''
# Wire the buttons to their actions; the clicked-button argument is unused.
go_btn.on_click(lambda _btn: optimize())
reset_btn.on_click(lambda _btn: reset())

# Figure (with its toolbar underneath) on the left, controls on the right.
# As the cell's last expression, this HBox is what the notebook displays.
final_fig = VBox(
    [fig, fig.pyplot],
    layout=Layout(overflow_x='hidden'),
)
HBox([
    final_fig,
    VBox([x0_box, eta_box, go_btn, reset_btn, sol_lbl]),
])
In [ ]:
In [ ]: