KL Divergence


In [1]:
import os  # NOTE(review): os appears unused in this notebook chunk — confirm before removing
import warnings
# NOTE(review): global suppression of ALL warnings. This is what hides the
# RuntimeWarnings (divide-by-zero / invalid log) raised inside kl_divergence's
# np.where expression below; consider a narrower filter so real issues surface.
warnings.filterwarnings('ignore')

In [2]:
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (4,4) # Use compact 4x4-inch figures (smaller than matplotlib's 6.4x4.8 default)
plt.style.use('fivethirtyeight')

In [3]:
import numpy as np
from scipy.stats import norm
from matplotlib import pyplot as plt  # NOTE(review): duplicate — pyplot was already imported as plt in an earlier cell
import tensorflow as tf  # NOTE(review): tensorflow appears unused in this notebook chunk — confirm before removing
import seaborn as sns
sns.set()  # applies seaborn's default theme, overriding the 'fivethirtyeight' style set earlier

In [4]:
def kl_divergence(p, q):
    """Discrete KL divergence KL(P||Q) = sum_i p_i * log(p_i / q_i).

    Terms where ``p_i == 0`` contribute 0, following the convention
    ``lim x->0+ of x*log(x) = 0``.

    Parameters
    ----------
    p, q : array_like
        Probability weights of the same shape. ``q`` must be nonzero
        wherever ``p`` is nonzero, otherwise the result is +inf.

    Returns
    -------
    float
        The KL divergence of P from Q.

    Notes
    -----
    Unlike ``np.where(p != 0, p * np.log(p / q), 0)``, this computes the
    log only on the nonzero-``p`` entries, so it never evaluates ``log(0)``
    in a discarded branch and does not rely on warnings being suppressed.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p != 0  # only these entries contribute to the sum
    terms = np.zeros_like(p)
    terms[mask] = p[mask] * np.log(p[mask] / q[mask])
    return np.sum(terms)

In [5]:
# Compare two Gaussians with the same scale (sigma = 2) but different means,
# and report KL(P||Q) in the figure title.
x = np.arange(-10, 10, 0.001)  # evaluation grid, step dx = 0.001
p = norm.pdf(x, 0, 2)          # P = N(mean=0, sd=2) sampled on the grid
q = norm.pdf(x, 1, 2)          # Q = N(mean=1, sd=2) sampled on the grid
print(p.shape)
print(q.shape)
# NOTE(review): kl_divergence sums p*log(p/q) over the grid WITHOUT the dx
# factor, so the reported number is ~1/dx (here 1000x) the continuous KL
# integral; the closed form for these two Gaussians is
# (mu_p - mu_q)^2 / (2*sigma^2) = 0.125. Multiply by dx if the integral
# approximation is what's intended.
kl = kl_divergence(p, q)
plt.title('KL(P||Q) = {}'.format(kl))
plt.plot(x, p, label='P = N(0, 2)')
plt.plot(x, q, c='red', label='Q = N(1, 2)')
plt.legend()
plt.show()  # render the figure and suppress the bare [<Line2D>] repr output


(20000,)
(20000,)
Out[5]:
[<matplotlib.lines.Line2D at 0x7fb85e68abe0>]

In [6]:
# NOTE(review): redundant — `kl` was already computed from the same p and q
# in the previous cell; consider deleting this cell.
kl = kl_divergence(p, q)

Reference