In [24]:
import argparse
import gym
import numpy as np
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.autograd import Variable


seed = 12
gamma = 0.95
render = True
log_interval = 5

env = gym.make('CartPole-v0')
env.seed(seed)
torch.manual_seed(seed)


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)

        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.affine2(x)
        return F.softmax(action_scores)


policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)


def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(Variable(state))
    action = probs.multinomial()
    policy.saved_actions.append(action)
    return action.data


def finish_episode():
    R = 0
    rewards = []
    for r in policy.rewards[::-1]:
        R = r + gamma * R
        rewards.insert(0, R)
    rewards = torch.Tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
    for action, r in zip(policy.saved_actions, rewards):
        action.reinforce(r)
    optimizer.zero_grad()
    autograd.backward(policy.saved_actions, [None for _ in policy.saved_actions])
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_actions[:]


running_reward = 10

for i_episode in count(1):
    state = env.reset()
    
    for t in range(10000): # Don't infinite loop while learning
        action = select_action(state)
        state, reward, done, _ = env.step(action[0,0])
        if render:
            env.render()
        policy.rewards.append(reward)
        if done:
            break

    running_reward = running_reward * 0.99 + t * 0.01
    finish_episode()
    
    if i_episode % log_interval == 0:
        print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
            i_episode, t, running_reward))
    if running_reward > 200:
        print("Solved! Running reward is now {} and "
              "the last episode runs to {} time steps!".format(running_reward, t))
        break


[2017-07-25 11:46:30,955] Making new env: CartPole-v0
Episode 5	Last length:    13	Average length: 10.29
Episode 10	Last length:    16	Average length: 11.17
Episode 15	Last length:    57	Average length: 14.22
Episode 20	Last length:   108	Average length: 17.45
Episode 25	Last length:    66	Average length: 19.96
Episode 30	Last length:   165	Average length: 23.00
Episode 35	Last length:   199	Average length: 29.26
Episode 40	Last length:   199	Average length: 36.82
Episode 45	Last length:   167	Average length: 43.33
Episode 50	Last length:   199	Average length: 50.39
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-24-a2a82503de80> in <module>()
     75         state, reward, done, _ = env.step(action[0,0])
     76         if render:
---> 77             env.render()
     78         policy.rewards.append(reward)
     79         if done:

/Users/Max/gym/gym/core.py in render(self, mode, close)
    148             elif mode not in modes:
    149                 raise error.UnsupportedMode('Unsupported rendering mode: {}. (Supported modes for {}: {})'.format(mode, self, modes))
--> 150         return self._render(mode=mode, close=close)
    151 
    152     def close(self):

/Users/Max/gym/gym/core.py in _render(self, mode, close)
    284 
    285     def _render(self, mode='human', close=False):
--> 286         return self.env.render(mode, close)
    287 
    288     def _close(self):

/Users/Max/gym/gym/core.py in render(self, mode, close)
    148             elif mode not in modes:
    149                 raise error.UnsupportedMode('Unsupported rendering mode: {}. (Supported modes for {}: {})'.format(mode, self, modes))
--> 150         return self._render(mode=mode, close=close)
    151 
    152     def close(self):

/Users/Max/gym/gym/envs/classic_control/cartpole.py in _render(self, mode, close)
    144         self.poletrans.set_rotation(-x[2])
    145 
--> 146         return self.viewer.render(return_rgb_array = mode=='rgb_array')

/Users/Max/gym/gym/envs/classic_control/rendering.py in render(self, return_rgb_array)
     85         self.transform.enable()
     86         for geom in self.geoms:
---> 87             geom.render()
     88         for geom in self.onetime_geoms:
     89             geom.render()

/Users/Max/gym/gym/envs/classic_control/rendering.py in render(self)
    152         for attr in reversed(self.attrs):
    153             attr.enable()
--> 154         self.render1()
    155         for attr in self.attrs:
    156             attr.disable()

/Users/Max/gym/gym/envs/classic_control/rendering.py in render1(self)
    225         else: glBegin(GL_TRIANGLES)
    226         for p in self.v:
--> 227             glVertex3f(p[0], p[1],0)  # draw each vertex
    228         glEnd()
    229 

/Users/Max/Coding/anaconda2/envs/torch/lib/python3.6/site-packages/pyglet/gl/lib.py in errcheck(result, func, arguments)
     82     pass
     83 
---> 84 def errcheck(result, func, arguments):
     85     if _debug_gl_trace:
     86         try:

KeyboardInterrupt: 

In [31]:
policy.saved_actions


Out[31]:
[Variable containing:
  1
 [torch.LongTensor of size 1x1], Variable containing:
  0
 [torch.LongTensor of size 1x1], Variable containing:
  0
 [torch.LongTensor of size 1x1], Variable containing:
  1
 [torch.LongTensor of size 1x1], Variable containing:
  0
 [torch.LongTensor of size 1x1]]

In [10]:
# Simple Application 
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)

        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.affine2(x)
        return F.softmax(action_scores)

torch.manual_seed(12)
policy = Policy()
state = Variable(torch.randn(1,4),requires_grad=True)
probs = policy(state)
action = probs.multinomial(2, replace=True)
policy.saved_actions.append(action)
action.reinforce(10)
action.backward()
state.grad


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-2c9dcfd9671d> in <module>()
     23 state = Variable(torch.randn(1,4),requires_grad=True)
     24 probs = policy(state)
---> 25 action = probs.multinomial(2, replace=True)
     26 policy.saved_actions.append(action)
     27 action.reinforce(10)

TypeError: multinomial() got an unexpected keyword argument 'replace'

In [ ]:
def gen_matrix_from_cluster_ix(cluster_ix):
	if not isinstance(cluster_ix, np.ndarray):
		cluster_ix = cluster_ix.numpy()
	set_size = cluster_ix.shape[0]
	matrix = np.tile(cluster_ix, (set_size,1))
	matrix = matrix - matrix.T
	matrix = ~matrix.astype(bool)
	return matrix

def plot_matrix(matrix):
	if not isinstance(matrix, np.ndarray):
		matrix = matrix.numpy()

	plt.matshow(matrix, interpolation='nearest')
	plt.show()

def plot_embd(embd):
	"""
	Plots a colormap of the L-kernel given embd,
	such that L = embd * embd.T
	Arguments:
	- embd: numpy array or torch tensor
	"""
	if not isinstance(embd, np.ndarray):
		embd = embd.numpy()
	
	L = embd.dot(embd.T)
	plot_matrix(L)

In [ ]:


In [ ]:


In [4]:
import torch
from torch.autograd import Variable
A = Variable(torch.randn(2,2), requires_grad=True)
B = Variable(torch.randn(2,2), requires_grad=True)
C = Variable(torch.randn(2,2), requires_grad=True)
my_l = [A, B, C]
loss = 0
for var in my_l:
    lo = torch.sin(var)
    loss += lo
loss.sum().backward()
print(A.grad)
print(B.grad)


Variable containing:
-0.2218  0.8815
-0.0302  0.6539
[torch.FloatTensor of size 2x2]

Variable containing:
-0.0219  0.7302
 0.8764  0.9063
[torch.FloatTensor of size 2x2]


In [49]:
from dpp_nets.dpp import score_dpp
import numpy as np
# Gradient Exploration
set_size = 5
kernel_dim = 10
K = torch.randn(set_size, kernel_dim)
L = K.mm(K.t())
subset1 = torch.ByteTensor([1, 1, 1, 0, 0])
subset2 = torch.ByteTensor([1, 1, 0, 0, 0])
# Gradient
embd = K.numpy()
subset1 = subset1.numpy()
subset2 = subset2.numpy()
print('Grad1 :', np.sign(score_dpp(embd, subset1)))
print('Grad2 :',np.sign(score_dpp(embd, subset2)))
print('Grad Agreement', np.sign(score_dpp(embd, subset1)) == np.sign(score_dpp(embd, subset2)))


Grad1 : [[ 1. -1.  1.  1. -1.  1.  1. -1. -1.  1.]
 [ 1.  1. -1. -1.  1. -1. -1.  1. -1. -1.]
 [-1.  1. -1. -1. -1. -1.  1.  1. -1.  1.]
 [ 1.  1.  1.  1.  1. -1. -1. -1. -1. -1.]
 [-1.  1. -1. -1.  1. -1. -1.  1.  1. -1.]]
Grad2 : [[-1. -1.  1.  1. -1.  1. -1. -1.  1. -1.]
 [-1.  1.  1. -1.  1. -1. -1.  1. -1. -1.]
 [-1. -1. -1. -1. -1. -1. -1.  1.  1. -1.]
 [ 1.  1.  1.  1.  1. -1. -1. -1. -1. -1.]
 [-1.  1. -1. -1.  1. -1. -1.  1.  1. -1.]]
Grad Agreement [[False  True  True  True  True  True False  True False False]
 [False  True False  True  True  True  True  True  True  True]
 [ True False  True  True  True  True False  True False False]
 [ True  True  True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True]]

In [26]:
K


Out[26]:
-0.0038  0.2483 -1.4515 -1.5514  0.6025 -0.7342  0.0642 -0.4024  0.5271 -0.6581
 0.5920 -0.4561 -0.0731 -0.8933 -1.2156  1.9095  2.0881  1.5159 -0.1664  0.5997
-0.1518 -0.9442  0.0375  0.9015 -2.3705 -0.0967  1.5697  2.3204 -0.1346  0.0100
-0.7355 -1.9643  0.1382  0.6559  0.4406 -1.2344  1.2915  0.4911 -0.5186  1.2513
-1.4058 -1.1972 -0.9700 -1.5969 -0.0701 -0.0360  1.7563 -0.3487 -0.9293 -0.0764
[torch.FloatTensor of size 5x10]

In [9]:
score_dpp


Out[9]:
<function dpp_nets.dpp.score_dpp.score_dpp>

In [55]:
import numpy as np
import matplotlib.pyplot as plt

#pgf.rcfonts : False
    
# set up figure size
plt.figure(figsize=(2, 2))

# do some plotting here
x = np.linspace(-2, 2, 1e3)
plt.plot(x, x**2)

# save to file
plt.savefig('example.pdf')
plt.savefig('example.pgf')


/Users/Max/Coding/anaconda2/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:10: DeprecationWarning: object of type <class 'float'> cannot be safely interpreted as an integer.
  # Remove the CWD from sys.path while we load stuff.

In [54]:
plt.rcParams


Out[54]:
RcParams({'_internal.classic_mode': False,
          'agg.path.chunksize': 0,
          'animation.avconv_args': [],
          'animation.avconv_path': 'avconv',
          'animation.bitrate': -1,
          'animation.codec': 'h264',
          'animation.convert_args': [],
          'animation.convert_path': 'convert',
          'animation.ffmpeg_args': [],
          'animation.ffmpeg_path': 'ffmpeg',
          'animation.frame_format': 'png',
          'animation.html': 'none',
          'animation.mencoder_args': [],
          'animation.mencoder_path': 'mencoder',
          'animation.writer': 'ffmpeg',
          'axes.autolimit_mode': 'data',
          'axes.axisbelow': 'line',
          'axes.edgecolor': 'k',
          'axes.facecolor': 'w',
          'axes.formatter.limits': [-7, 7],
          'axes.formatter.offset_threshold': 4,
          'axes.formatter.use_locale': False,
          'axes.formatter.use_mathtext': False,
          'axes.formatter.useoffset': True,
          'axes.grid': False,
          'axes.grid.axis': 'both',
          'axes.grid.which': 'major',
          'axes.hold': None,
          'axes.labelcolor': 'k',
          'axes.labelpad': 4.0,
          'axes.labelsize': 'medium',
          'axes.labelweight': 'normal',
          'axes.linewidth': 0.8,
          'axes.prop_cycle': cycler('color', ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']),
          'axes.spines.bottom': True,
          'axes.spines.left': True,
          'axes.spines.right': True,
          'axes.spines.top': True,
          'axes.titlepad': 6.0,
          'axes.titlesize': 'large',
          'axes.titleweight': 'normal',
          'axes.unicode_minus': True,
          'axes.xmargin': 0.05,
          'axes.ymargin': 0.05,
          'axes3d.grid': True,
          'backend': 'module://ipykernel.pylab.backend_inline',
          'backend.qt4': 'PyQt4',
          'backend.qt5': 'PyQt5',
          'backend_fallback': True,
          'boxplot.bootstrap': None,
          'boxplot.boxprops.color': 'k',
          'boxplot.boxprops.linestyle': '-',
          'boxplot.boxprops.linewidth': 1.0,
          'boxplot.capprops.color': 'k',
          'boxplot.capprops.linestyle': '-',
          'boxplot.capprops.linewidth': 1.0,
          'boxplot.flierprops.color': 'k',
          'boxplot.flierprops.linestyle': 'none',
          'boxplot.flierprops.linewidth': 1.0,
          'boxplot.flierprops.marker': 'o',
          'boxplot.flierprops.markeredgecolor': 'k',
          'boxplot.flierprops.markerfacecolor': 'none',
          'boxplot.flierprops.markersize': 6.0,
          'boxplot.meanline': False,
          'boxplot.meanprops.color': 'C2',
          'boxplot.meanprops.linestyle': '--',
          'boxplot.meanprops.linewidth': 1.0,
          'boxplot.meanprops.marker': '^',
          'boxplot.meanprops.markeredgecolor': 'C2',
          'boxplot.meanprops.markerfacecolor': 'C2',
          'boxplot.meanprops.markersize': 6.0,
          'boxplot.medianprops.color': 'C1',
          'boxplot.medianprops.linestyle': '-',
          'boxplot.medianprops.linewidth': 1.0,
          'boxplot.notch': False,
          'boxplot.patchartist': False,
          'boxplot.showbox': True,
          'boxplot.showcaps': True,
          'boxplot.showfliers': True,
          'boxplot.showmeans': False,
          'boxplot.vertical': True,
          'boxplot.whiskerprops.color': 'k',
          'boxplot.whiskerprops.linestyle': '-',
          'boxplot.whiskerprops.linewidth': 1.0,
          'boxplot.whiskers': 1.5,
          'contour.corner_mask': True,
          'contour.negative_linestyle': 'dashed',
          'datapath': '/Users/Max/Coding/anaconda2/envs/torch/lib/python3.6/site-packages/matplotlib/mpl-data',
          'date.autoformatter.day': '%Y-%m-%d',
          'date.autoformatter.hour': '%m-%d %H',
          'date.autoformatter.microsecond': '%M:%S.%f',
          'date.autoformatter.minute': '%d %H:%M',
          'date.autoformatter.month': '%Y-%m',
          'date.autoformatter.second': '%H:%M:%S',
          'date.autoformatter.year': '%Y',
          'docstring.hardcopy': False,
          'errorbar.capsize': 0.0,
          'examples.directory': '',
          'figure.autolayout': False,
          'figure.dpi': 72.0,
          'figure.edgecolor': (1, 1, 1, 0),
          'figure.facecolor': (1, 1, 1, 0),
          'figure.figsize': [6.0, 4.0],
          'figure.frameon': True,
          'figure.max_open_warning': 20,
          'figure.subplot.bottom': 0.125,
          'figure.subplot.hspace': 0.2,
          'figure.subplot.left': 0.125,
          'figure.subplot.right': 0.9,
          'figure.subplot.top': 0.88,
          'figure.subplot.wspace': 0.2,
          'figure.titlesize': 'large',
          'figure.titleweight': 'normal',
          'font.cursive': ['Apple Chancery',
                           'Textile',
                           'Zapf Chancery',
                           'Sand',
                           'Script MT',
                           'Felipa',
                           'cursive'],
          'font.family': ['sans-serif'],
          'font.fantasy': ['Comic Sans MS',
                           'Chicago',
                           'Charcoal',
                           'ImpactWestern',
                           'Humor Sans',
                           'xkcd',
                           'fantasy'],
          'font.monospace': ['DejaVu Sans Mono',
                             'Bitstream Vera Sans Mono',
                             'Computer Modern Typewriter',
                             'Andale Mono',
                             'Nimbus Mono L',
                             'Courier New',
                             'Courier',
                             'Fixed',
                             'Terminal',
                             'monospace'],
          'font.sans-serif': ['DejaVu Sans',
                              'Bitstream Vera Sans',
                              'Computer Modern Sans Serif',
                              'Lucida Grande',
                              'Verdana',
                              'Geneva',
                              'Lucid',
                              'Arial',
                              'Helvetica',
                              'Avant Garde',
                              'sans-serif'],
          'font.serif': ['DejaVu Serif',
                         'Bitstream Vera Serif',
                         'Computer Modern Roman',
                         'New Century Schoolbook',
                         'Century Schoolbook L',
                         'Utopia',
                         'ITC Bookman',
                         'Bookman',
                         'Nimbus Roman No9 L',
                         'Times New Roman',
                         'Times',
                         'Palatino',
                         'Charter',
                         'serif'],
          'font.size': 10.0,
          'font.stretch': 'normal',
          'font.style': 'normal',
          'font.variant': 'normal',
          'font.weight': 'normal',
          'grid.alpha': 1.0,
          'grid.color': '#b0b0b0',
          'grid.linestyle': '-',
          'grid.linewidth': 0.8,
          'hatch.color': 'k',
          'hatch.linewidth': 1.0,
          'hist.bins': 10,
          'image.aspect': 'equal',
          'image.cmap': 'viridis',
          'image.composite_image': True,
          'image.interpolation': 'nearest',
          'image.lut': 256,
          'image.origin': 'upper',
          'image.resample': True,
          'interactive': False,
          'keymap.all_axes': ['a'],
          'keymap.back': ['left', 'c', 'backspace'],
          'keymap.forward': ['right', 'v'],
          'keymap.fullscreen': ['f', 'ctrl+f'],
          'keymap.grid': ['g'],
          'keymap.home': ['h', 'r', 'home'],
          'keymap.pan': ['p'],
          'keymap.quit': ['ctrl+w', 'cmd+w'],
          'keymap.save': ['s', 'ctrl+s'],
          'keymap.xscale': ['k', 'L'],
          'keymap.yscale': ['l'],
          'keymap.zoom': ['o'],
          'legend.borderaxespad': 0.5,
          'legend.borderpad': 0.4,
          'legend.columnspacing': 2.0,
          'legend.edgecolor': '0.8',
          'legend.facecolor': 'inherit',
          'legend.fancybox': True,
          'legend.fontsize': 'medium',
          'legend.framealpha': 0.8,
          'legend.frameon': True,
          'legend.handleheight': 0.7,
          'legend.handlelength': 2.0,
          'legend.handletextpad': 0.8,
          'legend.labelspacing': 0.5,
          'legend.loc': 'best',
          'legend.markerscale': 1.0,
          'legend.numpoints': 1,
          'legend.scatterpoints': 1,
          'legend.shadow': False,
          'lines.antialiased': True,
          'lines.color': 'C0',
          'lines.dash_capstyle': 'butt',
          'lines.dash_joinstyle': 'round',
          'lines.dashdot_pattern': [6.4, 1.6, 1.0, 1.6],
          'lines.dashed_pattern': [3.7, 1.6],
          'lines.dotted_pattern': [1.0, 1.65],
          'lines.linestyle': '-',
          'lines.linewidth': 1.5,
          'lines.marker': 'None',
          'lines.markeredgewidth': 1.0,
          'lines.markersize': 6.0,
          'lines.scale_dashes': True,
          'lines.solid_capstyle': 'projecting',
          'lines.solid_joinstyle': 'round',
          'markers.fillstyle': 'full',
          'mathtext.bf': 'sans:bold',
          'mathtext.cal': 'cursive',
          'mathtext.default': 'it',
          'mathtext.fallback_to_cm': True,
          'mathtext.fontset': 'dejavusans',
          'mathtext.it': 'sans:italic',
          'mathtext.rm': 'sans',
          'mathtext.sf': 'sans',
          'mathtext.tt': 'monospace',
          'nbagg.transparent': True,
          'patch.antialiased': True,
          'patch.edgecolor': 'k',
          'patch.facecolor': 'C0',
          'patch.force_edgecolor': False,
          'patch.linewidth': 1.0,
          'path.effects': [],
          'path.simplify': True,
          'path.simplify_threshold': 0.1111111111111111,
          'path.sketch': None,
          'path.snap': True,
          'pdf.compression': 6,
          'pdf.fonttype': 3,
          'pdf.inheritcolor': False,
          'pdf.use14corefonts': False,
          'pgf.debug': False,
          'pgf.preamble': [],
          'pgf.rcfonts': True,
          'pgf.texsystem': 'xelatex',
          'plugins.directory': '.matplotlib_plugins',
          'polaraxes.grid': True,
          'ps.distiller.res': 6000,
          'ps.fonttype': 3,
          'ps.papersize': 'letter',
          'ps.useafm': False,
          'ps.usedistiller': False,
          'savefig.bbox': None,
          'savefig.directory': '~',
          'savefig.dpi': 'figure',
          'savefig.edgecolor': 'w',
          'savefig.facecolor': 'w',
          'savefig.format': 'png',
          'savefig.frameon': True,
          'savefig.jpeg_quality': 95,
          'savefig.orientation': 'portrait',
          'savefig.pad_inches': 0.1,
          'savefig.transparent': False,
          'scatter.marker': 'o',
          'svg.fonttype': 'path',
          'svg.hashsalt': None,
          'svg.image_inline': True,
          'text.antialiased': True,
          'text.color': 'k',
          'text.dvipnghack': None,
          'text.hinting': 'auto',
          'text.hinting_factor': 8,
          'text.latex.preamble': [],
          'text.latex.preview': False,
          'text.latex.unicode': False,
          'text.usetex': False,
          'timezone': 'UTC',
          'tk.window_focus': False,
          'toolbar': 'toolbar2',
          'verbose.fileo': 'sys.stdout',
          'verbose.level': 'silent',
          'webagg.open_in_browser': True,
          'webagg.port': 8988,
          'webagg.port_retries': 50,
          'xtick.alignment': 'center',
          'xtick.bottom': True,
          'xtick.color': 'k',
          'xtick.direction': 'out',
          'xtick.labelsize': 'medium',
          'xtick.major.bottom': True,
          'xtick.major.pad': 3.5,
          'xtick.major.size': 3.5,
          'xtick.major.top': True,
          'xtick.major.width': 0.8,
          'xtick.minor.bottom': True,
          'xtick.minor.pad': 3.4,
          'xtick.minor.size': 2.0,
          'xtick.minor.top': True,
          'xtick.minor.visible': False,
          'xtick.minor.width': 0.6,
          'xtick.top': False,
          'ytick.alignment': 'center_baseline',
          'ytick.color': 'k',
          'ytick.direction': 'out',
          'ytick.labelsize': 'medium',
          'ytick.left': True,
          'ytick.major.left': True,
          'ytick.major.pad': 3.5,
          'ytick.major.right': True,
          'ytick.major.size': 3.5,
          'ytick.major.width': 0.8,
          'ytick.minor.left': True,
          'ytick.minor.pad': 3.4,
          'ytick.minor.right': True,
          'ytick.minor.size': 2.0,
          'ytick.minor.visible': False,
          'ytick.minor.width': 0.6,
          'ytick.right': False})

In [ ]: