log-sum-exp trick


In [6]:
%matplotlib inline

In [7]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Reference point: exp(0) is 1.0, the largest value exp takes for non-positive inputs.
np.exp(0)


Out[2]:
1.0

In [5]:
# exp(-1000) is far smaller than the smallest positive float64, so it underflows to exactly 0.0.
np.exp(-1000)


Out[5]:
0.0

In [12]:
# Sample exp(x) on a fine grid over [-100, 0) with step 0.01.
xs = np.arange(-100, 0, 0.01)
ls = np.exp(xs)

In [13]:
# Plot the sampled exp(x) values against x.
# NOTE(review): the figure has no title or axis labels, and a trailing ';' would
# hide the Line2D repr from the cell output.
plt.plot(xs, ls)


Out[13]:
[<matplotlib.lines.Line2D at 0x10ba44cf8>]

In [112]:
def log_sum_exp(xs):
    """Naive log-sum-exp: evaluate ``log(sum(exp(xs)))`` directly.

    No shift is applied before exponentiating, so inputs that are very
    negative underflow to 0 inside ``exp`` and the result becomes ``-inf``.
    This is the unstabilised baseline for comparison with the shifted version.

    Parameters
    ----------
    xs : array-like
        Values to combine; converted with ``np.array``.

    Returns
    -------
    The logarithm of the sum of the exponentials of all elements of ``xs``.
    """
    values = np.array(xs)
    exponentials = np.exp(values)
    total = np.sum(exponentials)
    return np.log(total)

def log_sum_exp_trick(xs, f=np.max):
    """Numerically stable ``log(sum(exp(xs)))`` via the shift identity.

    Uses ``log(sum(exp(xs))) == a + log(sum(exp(xs - a)))`` with ``a = f(xs)``.
    With the default ``f=np.max`` every shifted exponent is ``<= 0``, so ``exp``
    cannot overflow and the largest term contributes exactly ``exp(0) == 1``.

    Parameters
    ----------
    xs : array-like
        Values to combine.
    f : callable, optional
        Reduces ``xs`` to the shift ``a``. Defaults to ``np.max``. Other
        reductions (e.g. ``np.min``) are accepted, but the shifted exponents
        can then be large and positive, so ``exp`` may overflow.

    Returns
    -------
    ``log(sum(exp(xs)))``. Empty input yields ``-inf`` (the log of an empty
    sum), and infinite inputs propagate as ``-inf`` / ``inf`` rather than ``nan``.
    """
    xs = np.asarray(xs)
    if xs.size == 0:
        # Reductions such as np.max raise on empty input; the empty sum is 0
        # and log(0) is -inf, which is also what the unshifted formula gives.
        return -np.inf
    a = f(xs)
    # A non-finite shift (all -inf, or any +inf) would make ``xs - a`` evaluate
    # inf - inf = nan. Shift by 0 in that case so the infinities propagate.
    a = np.where(np.isfinite(a), a, 0.0)
    return a + np.log(np.sum(np.exp(xs - a)))

In [64]:
# Moderate inputs: the naive formula works here; the expected value is -10 + log(10).
log_sum_exp([-10] * 10)


Out[64]:
-7.697414907005954

In [65]:
# Same input through the shifted version; it should agree with the naive result.
log_sum_exp_trick([-10] * 10)


Out[65]:
-7.697414907005954

In [67]:
# Extreme inputs: every exp(-1e10) underflows to 0, so the naive version takes log(0) = -inf.
log_sum_exp([-1e10] * 10)


/Users/amir.ziai/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:3: RuntimeWarning: divide by zero encountered in log
  app.launch_new_instance()
Out[67]:
-inf

In [69]:
# Shifting by the max makes every exponent 0, so the result is -1e10 + log(10).
log_sum_exp_trick([-1e10] * 10, f=np.max)


Out[69]:
-9999999997.697414

In [70]:
# All elements are equal, so min == max and shifting by the min gives the same answer.
log_sum_exp_trick([-1e10] * 10, f=np.min)


Out[70]:
-9999999997.697414

In [101]:
# Sanity check: the power expression and the float literal compare equal.
10 ** -10 == 1e-10


Out[101]:
True

In [88]:
# Ten hugely negative terms plus a single term of 0.999.
# NOTE(review): x is rebound in the following cell, so results that read x depend on cell order.
x = [-1e10] * 10 + [0.999]

In [98]:
# Two equal terms of -1e5: exp(-1e5) underflows, but the shifted version returns -1e5 + log(2).
x = [-1e5] * 2
log_sum_exp_trick(x)


Out[98]:
-99999.30685281944

In [99]:
# NOTE(review): same call on the same x as the previous cell; it returns the same value.
log_sum_exp_trick(x)


Out[99]:
-99999.30685281944

In [92]:
# Ten terms of -1e10 underflow to 0 in exp, so the sum is just exp(0.999) and the
# naive formula returns ~0.999. The input is spelled out inline because `x` is
# rebound to a different list earlier in the notebook; relying on it would make
# this cell return -inf on a top-to-bottom run.
log_sum_exp([-1e10] * 10 + [0.999])


Out[92]:
0.9990000000000001

In [90]:
# Shifting by the min is the wrong choice when values differ: the largest shifted
# exponent is about +1e10, exp overflows, and the result is inf.
log_sum_exp_trick([-1e10] * 10 + [-0.999], f=np.min)


/Users/amir.ziai/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:8: RuntimeWarning: overflow encountered in exp
Out[90]:
inf

In [91]:
# Shifting by the max keeps every exponent <= 0: the tiny terms underflow to 0
# harmlessly and the result is the dominant value, -0.999.
log_sum_exp_trick([-1e10] * 10 + [-0.999], f=np.max)


Out[91]:
-0.999

In [119]:
# Two very negative values of different magnitude: both exps underflow to 0, so
# the naive version returns -inf. (This rebinds xs, previously the plotting grid.)
xs = [-1e10, -1e8]
log_sum_exp(xs)


/Users/amir.ziai/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:3: RuntimeWarning: divide by zero encountered in log
  app.launch_new_instance()
Out[119]:
-inf

In [120]:
# The shifted version returns the max, -1e8: the other term contributes
# exp(-9.9e9), which underflows to 0, leaving log(1) = 0 after the shift.
log_sum_exp_trick(xs)


Out[120]:
-100000000.0