In [1]:
%matplotlib notebook
In [2]:
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
import random
import math
import sys
import os
In [3]:
class dotdict:
    """Empty attribute container.

    Instances start with no attributes; callers attach settings to them by
    plain attribute assignment (``obj.name = value``) and read them back the
    same way.
    """
In [4]:
# Hyper-parameters for the schedules computed and plotted in this notebook.
FLAGS = dotdict()
FLAGS.rl_weight = 10.0  # base weight; multiplied by the temperature in calc_model_weight
FLAGS.rl_epsilon = 0.8  # epsilon at step 0; passed as `eps` to calc_epsilon
FLAGS.rl_epsilon_decay = 8000.0  # decay constant (in steps) of the epsilon exponential
FLAGS.rl_confidence_interval = 5000.0  # period (in steps) of the temperature oscillation
FLAGS.rl_confidence_penalty = 3.0  # scales how far the model temperature moves away from 1
In [5]:
# Epsilon
def calc_epsilon(eps, eps_decay, step):
    """Return the exponentially decayed epsilon at a given step.

    Computes ``eps * exp(-step / eps_decay)``.

    Args:
        eps: Epsilon value at step 0.
        eps_decay: Decay constant, in the same unit as ``step``. Must be
            non-zero.
        step: Current step.

    Returns:
        The decayed epsilon as a float.

    Raises:
        ZeroDivisionError: If ``eps_decay`` is 0.
    """
    # Force true division: under Python 2 an integer `step` divided by an
    # integer `eps_decay` would floor, turning the smooth decay into steps.
    # calc_temperature guards its divisor the same way.
    return eps * math.exp(-step / float(eps_decay))
In [6]:
# Temperature
def calc_temperature(conf_interval, step):
    """Return a periodic temperature in [0, 1] for the given step.

    The value follows a sine wave with a period of ``conf_interval`` steps,
    shifted by a quarter period so that it is at its maximum (1) at step 0,
    and then rescaled from [-1, 1] to [0, 1].

    Args:
        conf_interval: Length of one full oscillation, in steps.
        step: Current step.

    Returns:
        The temperature as a float in [0, 1].
    """
    quarter_period = math.pi / 2
    # Fraction of a full cycle completed so far, converted to radians.
    phase = step / float(conf_interval) * 2 * math.pi
    wave = math.sin(quarter_period + phase)
    return (wave + 1) / 2  # map [-1, 1] onto [0, 1]
In [7]:
# Model Temperature
def calc_model_temp(temp, conf_penalty, eps):
    """Return the model temperature derived from the schedule temperature.

    The result is ``1 + (temp - 0.5) * conf_penalty * eps``, clamped from
    below at 1e-3.

    Args:
        temp: Schedule temperature; 0.5 yields a model temperature of 1.
        conf_penalty: Scale applied to the deviation from 1.
        eps: Current epsilon; also scales the deviation from 1.

    Returns:
        The model temperature, never smaller than 1e-3.
    """
    floor = 1e-3
    unclamped = 1 + (temp - .5) * conf_penalty * eps
    # The model temperature must stay strictly positive: 0 causes numerical
    # instability.
    return unclamped if unclamped > floor else floor
In [8]:
# Wake Sleep Weight
def calc_model_weight(temp, weight):
    """Return the base weight scaled by the current temperature.

    Args:
        temp: Multiplier applied to the weight.
        weight: Base weight.

    Returns:
        The product ``temp * weight``.
    """
    scaled_weight = temp * weight
    return scaled_weight
In [9]:
# Sample every schedule at every 10th step over the first 30000 steps.
# Each row of `data` is (step, eps, temp, model_temp, model_weight).
data = []
for step in range(0, 30000, 10):
    eps = calc_epsilon(FLAGS.rl_epsilon, FLAGS.rl_epsilon_decay, step)
    temp = calc_temperature(FLAGS.rl_confidence_interval, step)
    model_temp = calc_model_temp(temp, FLAGS.rl_confidence_penalty, eps)
    model_weight = calc_model_weight(temp, FLAGS.rl_weight)
    data.append((step, eps, temp, model_temp, model_weight))

# Transpose the rows into one series per quantity, once. Unpacking works on
# both Python 2 and Python 3, whereas indexing `zip(*data)[i]` raises
# TypeError on Python 3 because zip() returns an iterator there.
xs, eps_series, temp_series, model_temp_series, model_weight_series = zip(*data)

plt.figure()
plt.plot(xs, eps_series, label='eps')
plt.plot(xs, temp_series, label='temp')
plt.plot(xs, model_temp_series, label='model_temp')
plt.plot(xs, model_weight_series, label='model_weight')
# Reference line at y=1, the value model_temp is offset from.
plt.plot(xs, [1] * len(xs), label='y=1')
plt.xlabel('step')
plt.legend(loc='best')
plt.show()
In [ ]: