In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

In [2]:
def gini(p):
    """Gini impurity of a binary node whose class-1 fraction is ``p``.

    For two classes, sum_k p_k * (1 - p_k) = p*(1-p) + (1-p)*p = 2*p*(1-p).
    The closed form is used because spelling the second term as
    ``(1 - p) * (1 - (1 - p))`` adds rounding error: ``1 - (1 - p)`` is not
    exactly ``p`` in floating point (e.g. p = 0.1).

    Parameters
    ----------
    p : float or numpy.ndarray
        Probability of class 1, expected in [0, 1]. Arrays are evaluated
        element-wise.

    Returns
    -------
    float or numpy.ndarray
        Impurity in [0, 0.5]; 0 for a pure node, 0.5 at p = 0.5.
    """
    return 2 * p * (1 - p)


def entropy(p):
    """Binary Shannon entropy, in bits, of a node whose class-1 fraction is ``p``.

    H(p) = -p*log2(p) - (1-p)*log2(1-p), using the limit 0*log2(0) = 0 so
    that a pure node (p = 0 or p = 1) has entropy 0 rather than nan.

    Parameters
    ----------
    p : float or array_like
        Probability of class 1, expected in [0, 1]. Arrays are evaluated
        element-wise.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Entropy in [0, 1] bits (maximum 1 at p = 0.5), same shape as ``p``.
    """
    p = np.asarray(p, dtype=float)
    q = 1 - p
    # log2(0) is -inf and 0 * -inf is nan; silence those warnings here and
    # overwrite the endpoints with their limiting value just below.
    with np.errstate(divide='ignore', invalid='ignore'):
        h = -p * np.log2(p) - q * np.log2(q)
    h = np.where((p == 0) | (p == 1), 0.0, h)
    # For scalar input, unwrap the 0-d array into a numpy scalar.
    return h[()]


def error(p):
    """Misclassification error of a binary node: ``1 - max(p, 1 - p)``.

    Parameters
    ----------
    p : float or array_like
        Probability of class 1, expected in [0, 1]. Arrays are evaluated
        element-wise.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Error in [0, 0.5], same shape as ``p``.
    """
    p = np.asarray(p, dtype=float)
    # np.maximum compares element-wise. np.max([p, 1 - p]) would reduce over
    # every element and collapse an array input to a single scalar.
    return 1 - np.maximum(p, 1 - p)

In [6]:
# Probability grid over the closed interval [0, 1]. np.arange(0.0, 1.0, 0.01)
# stops at 0.99, so no curve would reach p = 1 on the [0, 1] x-axis.
x = np.linspace(0.0, 1.0, 101)

# Entropy is evaluated only strictly inside (0, 1), where log2 is finite, and
# set to its limiting value 0 at the endpoints. It is scaled by 0.5 so that its
# peak (1 bit at p = 0.5) lines up with the peak of the Gini impurity (0.5).
sc_ent = [0.5 * entropy(p) if 0 < p < 1 else 0.0 for p in x]
err = [error(p) for p in x]

curves = {
    'Entropy (scaled)': sc_ent,
    'Gini Impurity': gini(x),
    'Misclassification Error': err,
}

fig, ax = plt.subplots()
for label, values in curves.items():
    ax.plot(x, values, label=label, linestyle='-', lw=2, marker='None')

ax.grid()
ax.legend()
ax.set(
    xlim=(0, 1.0),
    ylim=(0, 0.55),
    xlabel='p(class=1)',
    ylabel='Impurity',
    title='Impurity comparison for binary classification (0 or 1)',
)

# Resolve the output location from the current user's home rather than a
# hard-coded /home/<user> path. Create the directory if needed so the cell
# survives Restart & Run All on a fresh machine.
FIGURE_PATH = Path.home() / 'public_html' / 'figures' / 'impurity-comparison.png'
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH)
plt.show()



In [ ]: