行为修正-checkpoint



In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
%matplotlib inline

_EPSILON = 1e-6  # avoid nan

In [2]:
def sigmoid(x):
    return 1 / (np.exp(-x) + 1)

def softmax(x):
    exp_array = np.exp(x)
    return exp_array/np.sum(exp_array)

# https://stackoverflow.com/questions/37292872/how-can-i-one-hot-encode-in-python
def one_hot_encode(x, n_classes):
    """
    One hot encode a list of sample labels. Return a one-hot encoded vector for each label.
    : x: List of sample Labels
    : return: Numpy array of one-hot encoded labels
     """
    return np.eye(n_classes)[x]

In [3]:
actor_outputs = np.random.randint(-100, 100, size=(32, 3))
critic_outputs = np.random.randint(-1000, 1000, size=(32))

In [4]:
policy = map(softmax, actor_outputs)
policy = np.stack(list(policy), axis=0)

stochastic_actions = np.random.randint(3, size=(32))
stochastic_onehot = one_hot_encode(stochastic_actions, 3)
stochastic_policy_action = np.sum(stochastic_onehot * policy, axis=1)

greedy_actions = np.argmax(policy, axis=1)
greedy_onehot = one_hot_encode(greedy_actions, 3)
greedy_policy_action = np.sum(greedy_onehot * policy, axis=1)

In [5]:
tmp = pd.DataFrame()
tmp['greedy_actions'] = -np.log(greedy_policy_action + _EPSILON)
tmp['stochastic_actions'] = -np.log(stochastic_policy_action + _EPSILON)

tmp.describe()


Out[5]:
greedy_actions stochastic_actions
count 3.200000e+01 3.200000e+01
mean 4.395785e-03 1.004203e+01
std 2.242002e-02 5.874599e+00
min -9.999995e-07 -9.999995e-07
25% -9.999995e-07 5.006566e+00
50% -9.999995e-07 1.381551e+01
75% -9.956233e-07 1.381551e+01
max 1.269269e-01 1.381551e+01

In [6]:
action_size = 3
action_space = list(range(action_size))
action_space


Out[6]:
[0, 1, 2]

In [7]:
policy


Out[7]:
array([[  1.80485139e-35,   9.35762297e-14,   1.00000000e+00],
       [  9.93307149e-01,   7.93146247e-30,   6.69285092e-03],
       [  1.97925988e-32,   6.99199000e-56,   1.00000000e+00],
       [  1.00000000e+00,   1.21609930e-37,   1.38879439e-11],
       [  9.99999985e-01,   2.50656744e-46,   1.52299795e-08],
       [  7.86844816e-63,   9.08666032e-80,   1.00000000e+00],
       [  1.00000000e+00,   7.98490425e-30,   6.47023493e-26],
       [  2.03109266e-42,   1.00000000e+00,   1.75879220e-25],
       [  9.99997740e-01,   4.71115515e-58,   2.26032430e-06],
       [  3.97544974e-31,   1.00000000e+00,   1.75879220e-25],
       [  8.31528028e-07,   9.99999168e-01,   9.46262160e-57],
       [  1.18506486e-27,   2.28569368e-49,   1.00000000e+00],
       [  7.58256042e-10,   9.99999999e-01,   8.75651076e-27],
       [  1.00000000e+00,   2.17052201e-29,   2.61027907e-23],
       [  1.00000000e+00,   1.03770332e-53,   2.38026641e-26],
       [  1.71390843e-15,   1.00000000e+00,   1.64581143e-38],
       [  4.90609473e-35,   1.00000000e+00,   1.97925988e-32],
       [  8.80797078e-01,   1.19202922e-01,   6.09018341e-13],
       [  8.40859712e-50,   1.75879220e-25,   1.00000000e+00],
       [  1.92874985e-22,   3.97544974e-31,   1.00000000e+00],
       [  5.10908903e-12,   4.83454164e-68,   1.00000000e+00],
       [  9.99999958e-01,   7.58256011e-10,   4.13993754e-08],
       [  3.35350130e-04,   2.93649703e-30,   9.99664650e-01],
       [  2.10097478e-19,   6.69285092e-03,   9.93307149e-01],
       [  2.61027907e-23,   4.35961000e-28,   1.00000000e+00],
       [  1.00000000e+00,   5.24288566e-22,   1.26641655e-14],
       [  1.00000000e+00,   1.56288219e-18,   1.21609930e-37],
       [  6.37586958e-59,   1.00000000e+00,   5.16642063e-55],
       [  6.05460190e-39,   3.30570063e-37,   1.00000000e+00],
       [  6.05460190e-39,   6.05460190e-39,   1.00000000e+00],
       [  1.00000000e+00,   1.82510451e-78,   6.54284062e-69],
       [  6.47023493e-26,   1.00000000e+00,   4.07955867e-41]])

In [8]:
np.random.choice(action_space, 1, p=policy[0])[0]


Out[8]:
2

In [9]:
np.random.randint(3)


Out[9]:
0

In [ ]: