In [1]:
%load_ext watermark
%watermark -a "Lilian Besson (Naereen)" -i -v -p numpy,matplotlib,scipy,seaborn
In [2]:
import numpy as np
from scipy import optimize as opt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (15, 8)
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context="notebook", style="darkgrid", palette="hls", font="sans-serif", font_scale=1.8)
In [3]:
def objective(x, a):
    """Residual of the equation exp(-a x^2) = x, that is exp(-a x^2) - x.

    For a fixed parameter ``a``, the roots in ``x`` of this function are
    exactly the solutions of the equation studied in this notebook.

    Parameters
    ----------
    x : float or numpy array
        Point(s) where the residual is evaluated.
    a : float
        Parameter of the equation.

    Returns
    -------
    Same kind as ``x``: the value(s) of ``exp(-a * x**2) - x``.
    """
    exponent = -a * x**2
    return np.exp(exponent) - x
First, let's have a look at its plot for some values of $a$:
In [4]:
# Overview on [-2, 2]: one curve of the residual per value of the parameter a.
X = np.linspace(-2, 2, 2000)
for a in (0, -0.1, 0.1, -1, 1):
    curve_label = f"$a={a:.3g}$"
    plt.plot(X, objective(X, a), 'o-', label=curve_label, markevery=50)
plt.legend()
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.title(r"Function $\exp(- a x^2) - x$ for different $a$")
plt.show()
Out[4]:
Out[4]:
Out[4]:
Out[4]:
Out[4]:
Out[4]:
Out[4]:
Out[4]:
Out[4]:
We can see that a solution to $\exp(-a x^2) = x$ has to be positive, as $\exp(-a x^2) > 0$ for any $x,a$. We also check that if $a < 0$, $\exp(-a x^2) - x$ seems to always be positive, but if $a \geq 0$, it seems to have a unique root. Let's zoom a little bit:
In [5]:
# Zoom on [0, 1.5], where the sign changes of the residual are visible.
X = np.linspace(0, 1.5, 2000)
for a in (0, -0.1, 0.1, -1, 1):
    curve_label = f"$a={a:.3g}$"
    plt.plot(X, objective(X, a), 'o-', label=curve_label, markevery=50)
plt.legend()
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.title(r"Function $\exp(- a x^2) - x$ for different $a$")
plt.show()
Out[5]:
Out[5]:
Out[5]:
Out[5]:
Out[5]:
Out[5]:
Out[5]:
Out[5]:
Out[5]:
The curve for $a=-0.1$ seems to stay negative once it has crossed zero, but that's not possible: for $a<0$ and $x\to\infty$, $\exp(-a x^2)$ dominates over $-x$. We can check that it has a second root:
In [6]:
# Wider range [0, 5] (a = -1 is left out of this plot) to reveal the second root for a = -0.1.
X = np.linspace(0, 5, 2000)
for a in (0, -0.1, 0.1, 1):
    curve_label = f"$a={a:.3g}$"
    plt.plot(X, objective(X, a), 'o-', label=curve_label, markevery=50)
plt.legend()
plt.xlabel("$x$")
plt.ylabel("$y$")
plt.title(r"Function $\exp(- a x^2) - x$ for different $a$")
plt.show()
Out[6]:
Out[6]:
Out[6]:
Out[6]:
Out[6]:
Out[6]:
Out[6]:
Out[6]:
We can start by trying `scipy.optimize.root` to solve this equation numerically.
In [7]:
def one_solution(a, x0=0, verb=False):
    """Look for one root of ``objective(., a)`` with ``scipy.optimize.root``.

    Parameters
    ----------
    a : float
        Parameter of the equation exp(-a x^2) = x.
    x0 : float, optional
        Starting point handed to the solver (default 0).
    verb : bool, optional
        If true, print the full result object returned by the solver.

    Returns
    -------
    The ``x`` attribute of the solver result (the root that was found).

    Raises
    ------
    ValueError
        If the solver does not report success.
    """
    result = opt.root(objective, x0, args=(a,))
    if verb:
        print(result)
    # Guard clause: fail loudly when the solver did not converge.
    if not result.success:
        raise ValueError(f"No solution was found for a = {a:.3g} (and starting at x0 = {x0:.3g}).")
    return result.x
Let's check that no solution is found when $a < 0$ is too negative (here $a = -1$).
In [8]:
# The solver is expected to fail for a = -1, and one_solution then raises ValueError.
# Catch and display the error so that "Restart & Run All" keeps going past this cell.
try:
    print(one_solution(-1, verb=True))
except ValueError as err:
    print(f"ValueError: {err}")
It can find a solution, but only one (depending on the starting point $x_0$) and not both:
In [9]:
# a = -0.1, solver started at x0 = 0; verb=True also prints the solver's result object.
one_solution(-0.1, x0=0, verb=True)
Out[9]:
In [10]:
# Same a = -0.1, but started at x0 = 10 to see whether a different root is reached.
one_solution(-0.1, x0=10, verb=True)
Out[10]:
For $a > 0$, the equation seems to have a unique solution:
In [11]:
# Same parameter a = 1, three very different starting points.
# NOTE(review): with IPython's default display settings only the last expression
# of a cell is shown — confirm that all three results are actually visible.
one_solution(1, x0=0)
one_solution(1, x0=-100)
one_solution(1, x0=100)
Out[11]:
Out[11]:
Out[11]:
We can just hack and try different values for $x_0$, expecting to find all the roots.
In [12]:
def solutions(a, x0s=None, tol=1e-10, verb=False):
    """Collect the distinct roots of ``objective(., a)`` found from several starting points.

    The solver ``scipy.optimize.root`` is run once per starting point; every
    successful result is rounded and stored in a set, so that runs converging
    to the same root are merged.

    Parameters
    ----------
    a : float
        Parameter of the equation exp(-a x^2) = x.
    x0s : iterable of float, optional
        Starting points. Defaults to ``[-10, -5, -2, -1, 0, 1, 2, 5, 10]``.
    tol : float, optional
        Only used to choose the number of decimals kept when rounding the
        roots (``1e-10`` keeps 10 decimals). It is not passed to the solver.
    verb : bool, optional
        If true, print a message when no root at all was found.

    Returns
    -------
    set of float
        The rounded roots. Two results that round differently count as two
        roots, so the size of the set can overestimate the true number of roots.
    """
    # round() before int(): guards against a log10 result like 9.999999999999998
    # being truncated to one digit less than intended.
    nbdigits = int(round(np.log10(1. / tol)))
    if x0s is None:
        x0s = [-10, -5, -2, -1, 0, 1, 2, 5, 10]
    sols = set()
    for x0 in x0s:
        sol = opt.root(objective, x0, args=(a,))
        if sol.success:
            # sol.x is a 1-element array: index it explicitly, since calling
            # float() on an array with ndim > 0 is deprecated in recent NumPy.
            approx = np.round(float(sol.x[0]), nbdigits)
            sols.add(approx)
    if verb and not sols:
        # The message no longer quotes a single x0: every starting point failed,
        # and the loop variable does not even exist when x0s is empty.
        print(f"No solution was found for a = {a:.3g} (from any of the starting points).")
    return sols
In [13]:
# a = -10, default starting points.
solutions(-10)
Out[13]:
In [14]:
# a = -0.1, default starting points.
solutions(-0.1)
Out[14]:
In [15]:
# a = 0: the equation reduces to exp(0) = x.
solutions(0)
Out[15]:
In [16]:
# a = 1, default starting points.
solutions(1)
Out[16]:
In [17]:
# a = 2, default starting points.
solutions(2)
Out[17]:
In [18]:
def thresholds(amin=-10, amax=10, delta=0.01):
    """Scan the parameter a and record where the number of detected roots changes.

    The grid is ``np.arange(amin, amax, delta)``; at each grid point the
    number of roots is ``len(solutions(a))``.

    Parameters
    ----------
    amin, amax : float
        Bounds of the scanned range (``amax`` itself is excluded by ``np.arange``).
    delta : float
        Grid step.

    Returns
    -------
    dict
        Maps ``(count_before, count_after)`` to ``(a_before, a_after)``, the two
        consecutive grid points between which the count changed. The dictionary
        is keyed by the transition, so if the same transition happens several
        times only its last occurrence is kept.
    """
    transitions = {}
    last_a = amin
    last_count = len(solutions(last_a))
    for current_a in np.arange(amin, amax, delta):
        current_count = len(solutions(current_a))
        if current_count != last_count:
            transitions[(last_count, current_count)] = (last_a, current_a)
        # Slide the window by one grid point.
        last_a, last_count = current_a, current_count
    return transitions
In [21]:
# Scan a over [-10, 10) with step 0.01 and show where the number of detected roots changes.
thresholds(amin=-10, amax=10, delta=0.01)
Out[21]:
In [22]:
# Same scan restricted to [-8, 1), still with step 0.01.
thresholds(amin=-8, amax=1, delta=0.01)
Out[22]:
I think having $3$ (or more) solutions is a numerical error.
In [23]:
# Wider but coarser scan (step 0.1 over [-100, 100)).
# amin, amax and gap_points stay as notebook-level variables: the call to
# plot_gap_points further down reads all three.
amin = -100
amax = 100
gap_points = thresholds(amin=amin, amax=amax, delta=0.1)
gap_points
Out[23]:
As we will see below, for $a \geq 0$ even having two solutions is nothing but a numerical error: the solution is then unique.
To start with, we can plot the (estimated) number of solutions as a function of $a$, thanks to the `matplotlib.pyplot.hlines` function:
In [24]:
def plot_gap_points(gap_points, amin, amax):
    """Plot the estimated number of solutions as horizontal segments over a.

    Parameters
    ----------
    gap_points : dict
        Maps a transition ``(count_before, count_after)`` to the pair
        ``(a_before, a_after)`` between which it happened (the format
        returned by ``thresholds``). Must not be empty, otherwise the
        ``min``/``max`` calls below raise ``ValueError``.
    amin, amax : float
        Left and right ends of the plotted range of a.

    Returns
    -------
    tuple
        ``(ys, xleft, min_xleft, xright, max_xright)``: the set of counts,
        the sorted a-values at which each count is left (``xleft``) or entered
        (``xright``), and the smallest / largest of those values. Note that
        the first entry of one ``xleft`` list is consumed while plotting.
    """
    # Every count that appears on either side of a transition.
    ys = set()
    for ym, yM in gap_points.keys():
        ys.add(ym)
        ys.add(yM)
    print(ys)

    # xleft[y]: a-values where count y is left; xright[y]: a-values where count y is entered.
    xleft = dict()
    xright = dict()
    for (ym, yM), (xm, xM) in gap_points.items():
        xleft[ym] = xleft.get(ym, []) + [xm]
        xright[yM] = xright.get(yM, []) + [xM]
    for ym, yM in gap_points.keys():
        xleft[ym].sort()
        xright[yM].sort()
    print(xleft)
    print(xright)

    # Flatten with a comprehension instead of the quadratic sum(list_of_lists, []).
    min_xleft = min(x for xs in xleft.values() for x in xs)
    # Bug fix: this used to be computed with min(), contradicting its name.
    max_xright = max(x for xs in xright.values() for x in xs)

    plt.figure()
    for y in ys:
        # A count that is entered but never left extends to the right border.
        if y not in xleft and y in xright:
            for x in xright[y]:
                plt.hlines(y, x, amax)
        # The count that is left first extends from the left border.
        if y in xleft and min_xleft in xleft[y]:
            plt.hlines(y, amin, min_xleft)
            del xleft[y][0]
        #if y in xright and max_xright in xright[y]:
        #    plt.hlines(y, max_xright, amax)
        #    del xright[y][-1]
        if y in xleft and y in xright:
            for xmin, xmax in zip(xleft[y], xright[y]):
                plt.hlines(y, xmin, xmax)
    plt.xlabel("Value of $a$")
    plt.ylabel("Number of solution")
    plt.title(r"Number of solutions to $\exp(- a x^2) = x$, as function of $a$")
    return ys, xleft, min_xleft, xright, max_xright


ys, xleft, min_xleft, xright, max_xright = plot_gap_points(gap_points, amin, amax)
In [25]:
def plot_multivalued_function(X, f, maxnboutput=1, **kwargs):
    """Plot a set-valued function ``f`` as up to ``maxnboutput`` separate curves.

    For each ``x`` in ``X``, the values of ``f(x)`` are sorted increasingly;
    the j-th smallest value goes to the j-th curve. Points where ``f(x)`` has
    fewer than ``maxnboutput`` values are left as NaN in the remaining curves.

    Parameters
    ----------
    X : sequence of float
        Abscissas.
    f : callable
        Maps one abscissa to an iterable of values (for instance a set).
    maxnboutput : int, optional
        Number of curves drawn. If ``f(x)`` returns more values than this,
        only the ``maxnboutput`` smallest are kept (the original code raised
        an ``IndexError`` in that case).
    **kwargs
        Forwarded to every ``plt.plot`` call.
    """
    # One row per curve, NaN wherever a curve has no value.
    Y = np.full((maxnboutput, len(X)), np.nan)
    for i, x in enumerate(X):
        # Truncate so that spurious extra values cannot index past the last row.
        ys = sorted(f(x))[:maxnboutput]
        for j, y in enumerate(ys):
            Y[j, i] = y
    for j in range(maxnboutput):
        plt.plot(X, Y[j], 'o-', **kwargs)
In [26]:
# Numerical solution(s) for a in [-100, 100].
A = np.linspace(-100, 100, 1000)
plot_multivalued_function(A, solutions, maxnboutput=2, markevery=10)
# No plt.legend() here: none of the curves carries a label, so the call only
# produced a warning and an empty legend.
plt.xlabel("Parameter $a$")
plt.ylabel("Solution(s)")
plt.title(r"Solution(s) to $\exp(- a x^2) = x$, as function of $a$")
plt.show()
Out[26]:
Out[26]:
Out[26]:
Out[26]:
In [27]:
# Numerical solution(s) for a in [0, 20], on a finer grid.
A = np.linspace(0, 20, 2000)
plot_multivalued_function(A, solutions, maxnboutput=2, markevery=20)
# No plt.legend() here: none of the curves carries a label, so the call only
# produced a warning and an empty legend.
plt.xlabel("Parameter $a$")
plt.ylabel("Solution(s)")
plt.title(r"Solution(s) to $\exp(- a x^2) = x$, as function of $a$")
plt.show()
Out[27]:
Out[27]:
Out[27]:
Out[27]:
This shows the numerical solution to the equation, and we will check below that the formal solution coincides.
Luckily, we can transform this equation to solve it with the Lambert $W$ function, defined as $W(x) = z \Leftrightarrow x = z \mathrm{e}^{z}$. For more details, please see this page, or this article.
As for (almost) all special functions, we don't need to write it ourselves: it is in scipy! See `scipy.special.lambertw`.
In [28]:
from scipy.special import lambertw
As the only possible solutions satisfy $x>0$, we have $$ \exp(-a x^2) = x \Leftrightarrow \left(\exp(-a x^2)\right)^2= \exp(-2 a x^2) = x^2 \Leftrightarrow 2 a y \exp(2 a y) = 2 a \;\;(\text{with}\;\; y := x^2) \Leftrightarrow \\ u \exp(u) = 2 a \;\;(\text{with}\;\; u := 2 a y) \Leftrightarrow u = W(2a) \Leftrightarrow y = \frac{W(2a)}{2a} \Leftrightarrow x(a) := \sqrt{\frac{W(2a)}{2a}}. $$
And so it is quite easy to compute, for $a > 0$ (the behavior at $0$ is undefined without a more careful study):
In [30]:
def formal_solution(a):
    """Closed-form solution x(a) = sqrt(W(2a) / (2a)) of exp(-a x^2) = x.

    ``W`` is the principal branch of the Lambert W function
    (``scipy.special.lambertw``).

    Parameters
    ----------
    a : float or numpy array
        Parameter(s) of the equation.

    Returns
    -------
    float or numpy array
        ``x(a)``, with the same shape as ``a``. ``lambertw`` always returns
        complex numbers; when every imaginary part is negligible the result is
        converted to real values, otherwise it is left complex. At ``a == 0``
        the formula is 0/0, so the limit value 1 is returned (it is also the
        direct solution of ``exp(0) = x``).
    """
    two_a = 2.0 * np.asarray(a, dtype=float)
    w = lambertw(two_a)
    # np.where evaluates both branches, so silence the 0/0 warning at a == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.where(two_a == 0, 1.0, w / two_a)
    x = np.real_if_close(np.sqrt(ratio))
    # Give a scalar back for a scalar input instead of a 0-d array.
    return x[()] if x.ndim == 0 else x
We can check some values:
In [35]:
# Check on a few values of a that x(a) really solves exp(-a x^2) = x.
for a in [0.5, 1, 2, 3, 4]:
    xa = formal_solution(a)
    # np.exp, not a bare exp: no name "exp" is defined in this notebook.
    assert np.isclose(np.exp(-a * xa**2), xa)
    # lambertw yields complex numbers: keep the real part before formatting.
    print(f"a = {a:.3g} gives x(a) = {float(np.real(xa)):.3g}")
We can try to approximate the solution for small $a$ or large $a$:
For small $a$, $W(2a) \simeq 2a - 4a^2$ so $x(a) \simeq 1 - a$.
For large $a$, we have this bound: $$ \forall x \geq \mathrm{e},{\displaystyle \ln(x)-\ln {\bigl (}\ln(x){\bigr )}+{\frac {\ln {\bigl (}\ln(x){\bigr )}}{2\ln(x)}}\leq W(x)\leq \ln(x)-\ln {\bigl (}\ln(x){\bigr )}+{\frac {e}{e-1}}{\frac {\ln {\bigl (}\ln(x){\bigr )}}{\ln(x)}}} $$
In [36]:
# Euler's number, also used further down to select the range a >= e.
e = np.exp(1)


def upper_bound(a):
    """Upper bound on x(a), obtained from the upper bound on W(2a).

    Uses ``W(z) <= ln(z) - ln(ln(z)) + (e / (e - 1)) * ln(ln(z)) / ln(z)``
    with ``z = 2a``, then applies ``x = sqrt(W(2a) / (2a))``.
    """
    log_2a = np.log(2*a)
    loglog_2a = np.log(log_2a)
    w_upper = log_2a - loglog_2a + (e / (e - 1)) * (loglog_2a / log_2a)
    return np.sqrt(w_upper / (2*a))
def lower_bound(a):
    """Lower bound on x(a), obtained from the lower bound on W(2a).

    Uses ``W(z) >= ln(z) - ln(ln(z)) + ln(ln(z)) / (2 ln(z))`` with
    ``z = 2a``, then applies ``x = sqrt(W(2a) / (2a))``.
    """
    log_2a = np.log(2*a)
    loglog_2a = np.log(log_2a)
    w_lower = log_2a - loglog_2a + loglog_2a / (2 * log_2a)
    return np.sqrt(w_lower / (2*a))
We can plot all this.
In [39]:
# Closed-form solution on [0, 20], with its tangent at 0 and the two bounds for large a.
A = np.linspace(0, 20, 4000)
A1 = A[A <= 0.5]   # range where the tangent at 0 is drawn
A2 = A[A >= 1]     # only needed by the commented-out asymptote below
Ae = A[A >= e]     # range where the two bounds are drawn
# lambertw returns complex values: plot the real part explicitly, otherwise
# matplotlib casts the array itself and emits a ComplexWarning.
plt.plot(A, np.real(formal_solution(A)), label="Solution", markevery=20)
plt.plot(A1, 1 - A1, 'b--', label=r"Tangent at $0$: $x(a) \simeq 1 - a$", markevery=20)
#plt.plot(A2, np.sqrt((np.log(2*A2) - np.log(np.log(2*A2)))/(2*A2)), 'g--', label=r"Asymptote at $+\infty$", markevery=20)
plt.plot(Ae, lower_bound(Ae), 'g--', label=r"Lower-bound for $a \geq e$", markevery=20)
plt.plot(Ae, upper_bound(Ae), 'c--', label=r"Upper-bound for $a \geq e$", markevery=20)
plt.legend()
plt.xlabel("Parameter $a$")
plt.ylabel(r"Solution $x(a) = \sqrt{\frac{W(2a)}{2a}}$")
plt.title(r"Solution to $\exp(- a x^2) = x$, as function of $a$")
plt.show()
Out[39]:
Out[39]:
Out[39]:
Out[39]:
Out[39]:
Out[39]:
Out[39]:
Out[39]:
And voilà.
That's it for today, folks!
See here for other notebooks I wrote, in Python, Julia, OCaml or other languages.