In [1]:
%matplotlib notebook

In [2]:
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt

import random
import math
import sys
import os

In [3]:
class dotdict:
    """Empty attribute container; settings are attached as plain attributes.

    Despite the name it is not a dict subclass: it defines no methods and
    only relies on ordinary instance attribute assignment (``obj.name = v``).
    """

In [4]:
# Schedule settings consumed by the calc_* helpers and the sweep cell below.
FLAGS = dotdict()

# Multiplier applied to the temperature in calc_model_weight.
FLAGS.rl_weight = 10.0
# Starting epsilon (value returned by calc_epsilon at step 0).
FLAGS.rl_epsilon = 0.8
# Number of steps over which epsilon shrinks by a factor of e.
FLAGS.rl_epsilon_decay = 8000.0
# Period, in steps, of the sinusoidal temperature from calc_temperature.
FLAGS.rl_confidence_interval = 5000.0
# Scale on the (temp - 0.5) deviation inside calc_model_temp.
FLAGS.rl_confidence_penalty = 3.0

In [5]:
# Epsilon
def calc_epsilon(eps, eps_decay, step):
    """Return ``eps`` decayed exponentially with ``step``.

    Args:
        eps: Value returned at ``step == 0``.
        eps_decay: Decay constant; the result shrinks by a factor of e each
            time ``step`` grows by this amount.
        step: Current step.

    Returns:
        ``eps * exp(-step / eps_decay)``.
    """
    decay_factor = math.exp(-step / eps_decay)
    return eps * decay_factor

In [6]:
# Temperature
def calc_temperature(conf_interval, step):
    """Return a sinusoidal value in [0, 1] that is periodic in ``step``.

    The wave has period ``conf_interval`` and is phase-shifted by pi/2, so
    it starts at 1 for ``step == 0`` and reaches 0 half a period later.

    Args:
        conf_interval: Period of the oscillation, in steps (coerced to float
            so integer inputs do not truncate the division).
        step: Current step.

    Returns:
        ``(sin(pi/2 + 2*pi*step/conf_interval) + 1) / 2``.
    """
    phase = step / float(conf_interval) * 2 * math.pi
    wave = math.sin(math.pi / 2 + phase)
    # Rescale from sin's [-1, 1] range to [0, 1].
    return (wave + 1) / 2

In [7]:
# Model Temperature
def calc_model_temp(temp, conf_penalty, eps):
    """Map a temperature to a model temperature centred on 1.

    The deviation of ``temp`` from 0.5 is scaled by ``conf_penalty * eps``
    and added to 1; the result is floored at 1e-3.

    Args:
        temp: Input temperature (0.5 maps to exactly 1).
        conf_penalty: Scale applied to the deviation from 0.5.
        eps: Additional scale applied to the deviation.

    Returns:
        ``max(1e-3, 1 + (temp - 0.5) * conf_penalty * eps)``.
    """
    # The model temperature must stay strictly positive; 0 causes
    # numerical instability, so anything at or below the floor is clamped.
    floor = 1e-3
    adjusted = 1 + (temp - .5) * conf_penalty * eps
    if adjusted > floor:
        return adjusted
    return floor

In [8]:
# Wake Sleep Weight
def calc_model_weight(temp, weight):
    """Scale ``weight`` by ``temp``.

    Args:
        temp: Multiplicative factor.
        weight: Base weight.

    Returns:
        The product ``temp * weight``.
    """
    scaled_weight = temp * weight
    return scaled_weight

In [9]:
# Sweep the step counter and record every schedule value at each sample
# point as a (step, eps, temp, model_temp, model_weight) tuple.
data = []

for step in range(0, 30000, 10):
    eps = calc_epsilon(FLAGS.rl_epsilon, FLAGS.rl_epsilon_decay, step)
    temp = calc_temperature(FLAGS.rl_confidence_interval, step)
    model_temp = calc_model_temp(temp, FLAGS.rl_confidence_penalty, eps)
    model_weight = calc_model_weight(temp, FLAGS.rl_weight)

    _data = (step, eps, temp, model_temp, model_weight)
    data.append(_data)

# Transpose the per-step records into one sequence per quantity, once.
# Unpacking the zip result (rather than indexing it) works on Python 2 and
# on Python 3, where zip() returns an iterator that cannot be subscripted.
xs, eps_series, temp_series, model_temp_series, model_weight_series = zip(*data)

plt.figure()
plt.plot(xs, eps_series, label='eps')
plt.plot(xs, temp_series, label='temp')
plt.plot(xs, model_temp_series, label='model_temp')
plt.plot(xs, model_weight_series, label='model_weight')
# Constant reference line at y = 1.
plt.plot(xs, [1] * len(xs), label='y=1')
plt.xlabel('step')
plt.legend(loc='best')
plt.show()



In [ ]: