From Pixels to Torques: Policy Learning using

Deep Dynamical Convolutional Neural Networks (DDCNN)

by John-Alexander M. Assael, Marc P. Deisenroth

Parameters


In [ ]:
-- Inside an iTorch notebook there is no real command line: parse an empty one.
if itorch then
    arg = {}
end

-- Directory / file name of the running script; the directory is the default
-- value of the -save option below.
require 'sys'
dname, fname = sys.fpath()

cmd = torch.CmdLine()

-- help banner
cmd:text()
cmd:text('From Pixels to Torques:')
cmd:text('Policy Learning using Deep Dynamical Convolutional Neural Networks (DDCNN)')
cmd:text('by John-Alexander M. Assael, Marc P. Deisenroth')
cmd:text()
cmd:text('Options')

-- {flag, default, help text}, registered in exactly this order
local option_specs = {
    -- general options
    {'-seed', 1, 'initial random seed'},
    {'-threads', 4, 'number of threads'},
    -- gpu
    {'-cuda', false, 'cuda'},
    -- model
    {'-lambda', 1, 'lambda'},
    {'-action_size', 1, 'action size'},
    -- training
    {'-batch_size', 20, 'batch size'},
    {'-hist_len', 2, 'history length'},
    {'-learningRate', 3e-4, 'learning rate'},
    -- persistence
    {'-save', dname, 'save path'},
    {'-load', false, 'load pretrained model'},
    -- logging
    {'-v', false, 'be verbose'},
}

for i = 1, #option_specs do
    local spec = option_specs[i]
    cmd:option(spec[1], spec[2], spec[3])
end
cmd:text()

-- global option table used by every later cell
opt = cmd:parse(arg)

Import Packages


In [ ]:
require 'base'
require 'hdf5'
require 'image'
require 'nngraph'
require 'optim'
require 'nn'
require 'KLDistCriterion'
require 'KLDCriterion'
require 'LinearO'
require 'AddCons'
require 'Reparametrize'
require 'unsup'
-- require 'pprint'
Plot = require 'itorch.Plot'

-- Cuda initialisation: only pull in the CUDA packages when requested.
if opt.cuda then
    require 'cutorch'
    require 'cunn'
    local device_id = 1
    cutorch.setDevice(device_id)
    print(cutorch.getDeviceProperties(device_id))
end

-- Set float as default type
torch.setdefaulttensortype('torch.FloatTensor')

-- Seed the RNG and set the thread count from the parsed options
torch.manualSeed(opt.seed)
torch.setnumthreads(opt.threads)

Load Data


In [ ]:
--- Show a single flattened frame inline in the notebook.
-- Does nothing when not running under iTorch.
-- @param img tensor with opt.img_w * opt.img_h elements
function disp_img(img)
    if not itorch then
        return
    end

    local frame = img:float()
    -- undo the standardisation when mean/std were stored in opt
    -- (g_destandarize is defined outside this file, presumably in 'base')
    if opt.y_mean ~= nil then
        frame = g_destandarize(frame, opt.y_mean, opt.y_std):float()
    end

    local picture = frame:reshape(opt.img_w, opt.img_h)
    itorch.image(image.scale(picture, 256))
end

In [ ]:
-- Read the recorded frames ('train_y') and control signals ('train_u').
local myFile = hdf5.open('data/single_gravity_40.h5', 'r')

local y_all = myFile:read('train_y'):all():float()
-- one row of opt.action_size controls per frame
local u_all = myFile:read('train_u'):all():float():reshape(y_all:size(1), opt.action_size)

myFile:close()

-- Scale images
-- local new_size = 10
-- local prev_size = torch.sqrt(y_all:size(2))
-- y_all = image.scale(y_all:reshape(y_all:size(1), prev_size,prev_size), new_size, new_size):reshape(y_all:size(1), new_size^2)

-- Train / test split.
-- The test range starts one row AFTER the last training row; previously both
-- ranges included row 15001, so that frame appeared in both splits.
local train_last = 15001
local test_last = 16001

local y = y_all[{{1, train_last}}]
local u = u_all[{{1, train_last}}]

local ys = y_all[{{train_last + 1, test_last}}]
local us = u_all[{{train_last + 1, test_last}}]


-- Update parameters: frames are stored flattened and assumed to be square
opt.img_w = torch.sqrt(y:size(2))
opt.img_h = torch.sqrt(y:size(2))
opt.max_seq_length = y:size(1) - 1

-- Store data (transfer_data is defined outside this file, presumably in 'base')
state_train = {
  x = transfer_data(y),
  u = transfer_data(u)
}

state_test = {
  x = transfer_data(ys),
  u = transfer_data(us)
}

print('Train=' .. state_train.x:size(1) .. ' Test=' .. state_test.x:size(1) .. ' (' .. opt.img_w .. 'x' .. opt.img_h .. ')')

In [ ]:
-- Preview one training frame; note that idx is assigned as a global here.
idx=1
disp_img(state_train.x[idx])

Define Model Architecture

Network


In [ ]:
--- Build the DDCNN graph: convolutional encoder, transition network and
-- up-sampling decoder, wired together with nngraph.
--
-- Side effects: sets opt.latent_dims and the globals `encoder`, `decoder`,
-- `trans` and `model`.
--
-- Graph inputs : {x_t_prev, x_t, u_t}
-- Graph outputs: {z_t_prev, z_t, dynamics_all, decoder_x_t_cur, decoder_x_t_next}
--
-- NOTE(review): the hard-coded 5*5 / 6*6 feature-map sizes below work out for
-- 40x40 frames (assuming the default stride 1 / no padding of the convolution
-- layers); other image sizes would need different constants.
-- @return the nn.gModule that was also stored in the global `model`
function create_network()
    
    opt.latent_dims = 2
    local enc_dims = 100
    local trans_dims = 100
    
    -- Model Specific parameters
    local f_maps_1 = 32
    local f_size_1 = 5
    local f_maps_2 = 32
    local f_size_2 = 5
    local f_maps_3 = 32
    local f_size_3 = 3
    
    -- Encoder: flattened frame -> latent code of size opt.latent_dims
    encoder = nn.Sequential()
    encoder:add(nn.Reshape(1, opt.img_w, opt.img_h))
    encoder:add(nn.SpatialConvolutionMM(1, f_maps_1, f_size_1, f_size_1))
    encoder:add(nn.ReLU())
    encoder:add(nn.SpatialMaxPooling(2,2,2,2))
    
    --layer 2
    encoder:add(nn.SpatialConvolutionMM(f_maps_1, f_maps_2, f_size_2, f_size_2))
    encoder:add(nn.ReLU())
    encoder:add(nn.SpatialMaxPooling(2,2,2,2))
    
    --layer 3
    encoder:add(nn.SpatialConvolutionMM(f_maps_2, f_maps_3, f_size_3, f_size_3))
    encoder:add(nn.ReLU())
    -- encoder:add(nn.SpatialMaxPooling(2,2,2,2))
    
    -- flatten the f_maps_3 feature maps (expected to be 5x5 each) for the
    -- fully connected layers
    encoder:add(nn.Reshape(f_maps_3*5*5))
    encoder:add(nn.LinearO(f_maps_3*5*5, enc_dims))
    encoder:add(nn.ReLU())
        
    encoder:add(nn.LinearO(enc_dims, enc_dims))
    encoder:add(nn.ReLU())
    
    -- final projection to the latent space (no non-linearity)
    encoder:add(nn.LinearO(enc_dims, opt.latent_dims))
       
    -- Decoder: latent code -> flattened frame with values in (0, 1)
    decoder = nn.Sequential()
    decoder:add(nn.LinearO(opt.latent_dims, enc_dims))
    decoder:add(nn.ReLU())

    decoder:add(nn.LinearO(enc_dims, enc_dims))
    decoder:add(nn.ReLU())    
    
    decoder:add(nn.LinearO(enc_dims, f_maps_3*6*6))
    decoder:add(nn.ReLU())
    
    -- back to f_maps_3 feature maps of 6x6
    decoder:add(nn.Reshape(f_maps_3, 6, 6))
    
    -- layer 3 (each stage: 2x nearest-neighbour up-sampling, then convolution)
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_3, f_maps_3, f_size_3+1, f_size_3+1))
    decoder:add(nn.ReLU())
        
    -- layer 2
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_3, f_maps_2, f_size_2, f_size_2))
    decoder:add(nn.ReLU())
    
    -- layer 1
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_2, f_maps_1, f_size_2+1, f_size_2+1))
    decoder:add(nn.ReLU())
    
    -- output stage: single-channel image
    decoder:add(nn.SpatialUpSamplingNearest(2))
    decoder:add(nn.SpatialConvolutionMM(f_maps_1, 1, f_size_2+2, f_size_2+2))
    
    -- sigmoid output, flattened back to img_w^2 values per sample
    decoder:add(nn.Sigmoid())
    decoder:add(nn.View(opt.img_w^2))
    
    
    -- Clone enc-dec: the clones share weight/bias and their gradients with
    -- the originals, so both copies train the same parameters
    local encoder2 = encoder:clone("weight", "bias", "gradWeight", "gradBias")
    local decoder2 = decoder:clone("weight", "bias", "gradWeight", "gradBias")
    
    -- Define model inputs: previous frame, current frame, control signal
    local x_t_prev = nn.Identity()():annotate{name = 'x_t_prev'}
    local x_t = nn.Identity()():annotate{name = 'x_t'}
    local u_t = nn.Identity()():annotate{name = 'u_t'}
    
    -- Define Encoder Module: the clone encodes the previous frame, the
    -- original (global `encoder`) encodes the current frame
    local z_t_prev = encoder2(x_t_prev):annotate{name = 'z_t_prev'}
    local z_t = encoder(x_t):annotate{name = 'z_t'}
    
        
    -- transition: [z_t_prev, z_t, u_t] -> predicted next latent code
    trans = nn.Sequential()
    trans:add(nn.LinearO(opt.action_size+opt.latent_dims*2, trans_dims))
    trans:add(nn.ReLU())
    trans:add(nn.LinearO(trans_dims, trans_dims))
    trans:add(nn.ReLU())
    trans:add(nn.LinearO(trans_dims, opt.latent_dims))
    
    
    -- concatenate along dimension 2 (the feature dimension of a batch)
    local dynamics_all = trans(nn.JoinTable(2)({z_t_prev, z_t, nn.Reshape(opt.action_size)(u_t)})):annotate{name = 'dynamics'}

    -- Define Output: reconstruction of the next frame (from the predicted
    -- latent code) and of the current frame (from z_t)
    local decoder_x_t_next = decoder(dynamics_all):annotate{name = 'decoder_x_t_next'}
    local decoder_x_t_cur = decoder2(z_t):annotate{name = 'decoder_x_t_cur'}
    
    -- Create model
    
    model = nn.gModule({x_t_prev, x_t, u_t}, {z_t_prev, z_t, dynamics_all, decoder_x_t_cur, decoder_x_t_next})
    
    -- create_links(model)
    
    return model
end

--- Re-bind the globals `encoder`, `trans` and `decoder` to modules stored in
-- a graph model, looked up by fixed positions in its forward node list.
-- @param model nn.gModule whose forwardnodes are inspected
function create_links(model)
    local nodes = model.forwardnodes
    local function module_at(position)
        return nodes[position].data.module
    end

    encoder = module_at(5)
    trans = module_at(13)
    decoder = module_at(14)
end

Setup Network function


In [ ]:
--- Create a fresh network plus the training criteria.
-- Sets the globals model, params, gradParams, criterion and criterion_mse.
function setup()
    print("Creating Neural Net.")

    model = create_network()
    -- flattened views over all trainable parameters and their gradients
    params, gradParams = model:getParameters()

    -- both losses are summed (not averaged) over the batch
    local bce = nn.BCECriterion()
    bce.sizeAverage = false
    criterion = bce

    local mse = nn.MSECriterion()
    mse.sizeAverage = false
    criterion_mse = mse
end

--- Restore a previously saved network and prepare it for further training.
-- Sets the globals model, params, gradParams, criterion and criterion_mse and
-- re-binds encoder / trans / decoder to the loaded graph.
function setup_load()

    print("Loading Neural Net.")

    -- load_model() replaces model, opt, optim_config, train_err and test_err
    load_model()
    create_links(model)
    params, gradParams = model:getParameters()

    -- the loaded `opt` replaced the parsed one: mark it as loaded and point
    -- the save path at the directory of the current script
    opt.load = true
    local script_dir, script_name = sys.fpath()
    dname, fname = script_dir, script_name
    opt.save = dname

    -- both losses are summed (not averaged) over the batch
    local bce = nn.BCECriterion()
    bce.sizeAverage = false
    criterion = bce

    local mse = nn.MSECriterion()
    mse.sizeAverage = false
    criterion_mse = mse

end

Save model


In [ ]:
--- Save the current network and training state under opt.save.
-- An existing checkpoint is kept as '<file>.old' before being replaced.
-- Stored table: {model, opt, optim_config, train_err, test_err}.
function save_model()
    -- Quote a string for use as a single POSIX shell argument, so that save
    -- paths containing spaces or shell metacharacters are handled safely.
    local function shell_quote(s)
        return "'" .. (s:gsub("'", "'\\''")) .. "'"
    end

    -- save/log current net
    local filename = paths.concat(opt.save, 'model/relu_single_gravity_ddcnn.t7')
    os.execute('mkdir -p ' .. shell_quote(paths.dirname(filename)))

    if paths.filep(filename) then
        -- os.rename avoids passing the path through a shell at all
        local ok, rename_err = os.rename(filename, filename .. '.old')
        if not ok then
            print('<trainer> could not back up ' .. filename .. ': ' .. tostring(rename_err))
        end
    end

    -- print('<trainer> saving network to '..filename)
    torch.save(filename, {model, opt, optim_config, train_err, test_err})
end

--- Load a checkpoint written by save_model().
-- Replaces the globals model, opt, optim_config, train_err and test_err.
-- The checkpoint is looked up under opt.save first (where save_model writes
-- it) and, if not found there, at the path relative to the working directory.
function load_model()
    local relative_name = 'model/relu_single_gravity_ddcnn.t7'
    local filename = relative_name

    if type(opt) == 'table' and type(opt.save) == 'string' then
        local candidate = paths.concat(opt.save, relative_name)
        if paths.filep(candidate) then
            filename = candidate
        end
    end

    model, opt, optim_config, train_err, test_err = unpack(torch.load(filename))
end

Initialize Network


In [ ]:
print("Network parameters:")
print(opt)

if not opt.load then
    -- fresh run: new network, new optimiser settings, empty error history
    setup()
    optim_config = {
        learningRate = opt.learningRate,
        beta2 = 0.9,
    }
    train_err = {}
    test_err = {}
else
    -- resume: network, optimiser settings and histories come from the checkpoint
    setup_load()
end

-- continue the epoch count from the recorded training history
epoch = #train_err

Train Function


In [ ]:
--- Build shuffled mini-batches of (previous, current, control, next) rows.
--
-- Every usable index idx needs both a predecessor and a successor frame, so
-- the first and last rows of dataset.x are never used as the "current" frame.
-- Chunks that contain no usable index are skipped, so no zero-row batch is
-- ever stored.
--
-- @param dataset table with tensors `x` (frames) and `u` (controls); the
--        result is stored in dataset.batch as a list of
--        {batch_x_prev, batch_x_cur, batch_u, batch_y}
function g_create_batch(dataset)
    local batches = {}
    local n_samples = dataset.x:size(1)
    local frame_size = opt.img_w^2

    -- shuffle at each epoch
    local shuffle = torch.randperm(n_samples):long()

    for t = 1, n_samples, opt.batch_size do

        -- collect the usable indices of this chunk in a single pass
        local usable = {}
        for i = t, math.min(t + opt.batch_size - 1, n_samples) do
            local idx = shuffle[i]
            if idx - 1 >= 1 and idx + 1 <= n_samples then
                usable[#usable + 1] = idx
            end
        end

        if #usable > 0 then
            -- create mini batch
            local batch_x_prev = torch.Tensor(#usable, frame_size)
            local batch_x_cur = torch.Tensor(#usable, frame_size)
            local batch_u = torch.Tensor(#usable, opt.action_size)
            local batch_y = torch.Tensor(#usable, frame_size)

            for row = 1, #usable do
                local idx = usable[row]
                batch_x_prev[row] = dataset.x[idx-1]
                batch_x_cur[row] = dataset.x[idx]
                batch_y[row] = dataset.x[idx+1]
                batch_u[row] = dataset.u[idx]
            end

            table.insert(batches, {batch_x_prev, batch_x_cur, batch_u, batch_y})
        end
    end

    dataset.batch = batches
end

In [ ]:
--- Train the global `model` for one epoch on `dataset`.
--
-- For every mini-batch the loss is
--   BCE(x_t reconstruction) + BCE(x_{t+1} prediction)
--   + opt.lambda * MSE(predicted z_{t+1}, encoder(x_{t+1}))
-- and one Adam step is taken. The MSE gradient is applied to the transition
-- network only (it is back-propagated through `trans` alone).
--
-- Uses the globals model, encoder, trans, params, gradParams, criterion,
-- criterion_mse, optim_config and increments the global `epoch`.
--
-- @param dataset table with tensors `x` and `u` (see g_create_batch)
-- @return table {all, bce, bce_1, mse} of per-sample average errors
function train(dataset)

    -- rebuild the shuffled mini-batches for the dataset that was passed in
    g_create_batch(dataset)

    -- epoch tracker
    epoch = epoch or 0

    -- local vars
    local err = {all=0, bce=0, bce_1=0, mse=0}

    -- shuffle at each epoch
    local shuffle = torch.randperm(#dataset.batch):long()

    for t = 1,#dataset.batch do

        -- xlua.progress(t, #dataset.batch)

        -- fetch mini batch
        local batch = dataset.batch[shuffle[t]]
        local batch_x_prev = batch[1]
        local batch_x_cur = batch[2]
        local batch_u = batch[3]
        local batch_y = batch[4]

        -- an empty batch may be a dimensionless tensor, for which size(1)
        -- is not defined; treat it as size 0
        local batch_size = (batch_y:dim() > 0) and batch_y:size(1) or 0

        -- create closure to evaluate f(X) and df/dX
        local feval = function(x)

            -- get new parameters
            if x ~= params then
                params:copy(x)
            end

            -- reset gradients
            gradParams:zero()

            -- reset errors
            local mse_err, bce_err, bce_1_err = 0, 0, 0

            -- Target latent code of the next frame. The clone is required:
            -- forward() returns the module's own output buffer, and the
            -- model:forward below runs the encoder again and would overwrite
            -- the target in place.
            local z_t_next_true = encoder:forward(batch_y):clone()

            -- evaluate function for complete mini batch
            local z_t_prev, z_t_cur, z_t_next, x_t, x_t_next = unpack(model:forward({batch_x_prev, batch_x_cur, batch_u}))

            -- BCE x_t
            bce_err = bce_err + criterion:forward(x_t, batch_x_cur)
            local d_x_t = criterion:backward(x_t, batch_x_cur):clone()

            -- BCE x_t+1
            bce_1_err = bce_1_err + criterion:forward(x_t_next, batch_y)
            local d_x_t1 = criterion:backward(x_t_next, batch_y):clone()

            -- MSE z_t+1
            mse_err = mse_err + criterion_mse:forward(z_t_next, z_t_next_true) * opt.lambda
            local d_z_t_next = criterion_mse:backward(z_t_next, z_t_next_true):clone():mul(opt.lambda)

            -- Backpropagate the two reconstruction losses through the whole
            -- graph; the three latent outputs receive zero gradient here
            model:backward({batch_x_prev, batch_x_cur, batch_u}, {
                    torch.zeros(batch_size, opt.latent_dims),
                    torch.zeros(batch_size, opt.latent_dims),
                    torch.zeros(batch_size, opt.latent_dims),
                    d_x_t,
                    d_x_t1
                })

            -- Backpropagate the latent MSE through the transition network
            -- only, using the same input layout as the JoinTable in the graph
            local trans_in = torch.cat(torch.cat(z_t_prev, z_t_cur), batch_u)
            trans:forward(trans_in)
            trans:backward(trans_in, d_z_t_next)

            -- Accumulate errors
            err.mse = err.mse + mse_err
            err.bce = err.bce + bce_err
            err.bce_1 = err.bce_1 + bce_1_err
            err.all = err.all + bce_err + bce_1_err + mse_err

            -- normalize gradients and f(X)
            local batcherr = (bce_err + bce_1_err + mse_err) / batch_size
            gradParams:div(batch_size)

            -- print(bce_err/batch_size, bce_1_err/batch_size, mse_err/batch_size)

            -- return f and df/dX
            return batcherr, gradParams
        end

        if batch_size > 0 then
            optim.adam(feval, params, optim_config)
            -- optim.adagrad(feval, params, optim_config)
            -- optim.rmsprop(feval, params, optim_config)
        end

    end

    -- Normalise errors: the first and last rows are never used as the
    -- current frame, hence size - 2 training samples
    local n_used = dataset.x:size(1) - 2
    err.all = err.all / n_used
    err.mse = err.mse / n_used
    err.bce = err.bce / n_used
    err.bce_1 = err.bce_1 / n_used

    epoch = epoch + 1

    return err
end

Train network


In [ ]:
-- number of epochs to run in this session
opt.max_epoch = 50

-- wall-clock timer for the whole run
local run_timer = torch.tic()

for _ = 1, opt.max_epoch do

    -- train for one epoch and time it
    local started = sys.clock()
    local epoch_err = train(state_train)
    train_err[#train_err+1] = epoch_err
    local epoch_secs = sys.clock() - started

    -- display stats (every epoch)
    local elapsed_mins = g_d(torch.toc(run_timer) / 60)
    local summary = 'epoch=' .. (epoch)
    summary = summary .. ', Train err=' .. g_f3(epoch_err.all)
    summary = summary .. ', mse=' .. g_f3(epoch_err.mse)
    summary = summary .. ', bce=' .. g_f3(epoch_err.bce)
    summary = summary .. ', bce_1=' .. g_f3(epoch_err.bce_1)
    summary = summary .. ', t/epoch = ' .. g_f3(epoch_secs) .. ' sec'
    summary = summary .. ', t = ' .. elapsed_mins .. ' mins.'
    print(summary)

    -- checkpoint every fifth epoch
    if epoch % 5 == 0 then
        save_model()
    end
end

Plot Performance


In [ ]:
--- Extract one error series from a list of per-epoch error tables.
-- @param err list of tables such as {all=..., mse=..., bce=..., bce_1=...}
-- @param criterion key to extract; defaults to 'all'
-- @return 1-D tensor with one value per entry of err
function get_error(err, criterion)
    local key = criterion or 'all'
    local count = #err
    local series = torch.zeros(count)

    local i = 1
    while i <= count do
        series[i] = err[i][key]
        i = i + 1
    end

    return series
end

In [ ]:
-- Plot every recorded training error series against the epoch number.
colors = {'blue', 'green', 'red', 'purple', 'orange', 'magenta', 'cyan'}

-- One x value per recorded epoch. The axis previously started at 2, which
-- gave one point fewer than the y series returned by get_error and failed
-- outright when only a single epoch had been recorded.
local n_epochs = #train_err
local epoch_axis = torch.range(1, n_epochs)

plot = Plot()
plot:title(string.format('Single pendulum gravity - %d epochs', n_epochs))
plot = plot:line(epoch_axis, get_error(train_err,'all'), colors[1], 'L(D)')
plot = plot:line(epoch_axis, get_error(train_err,'mse'), colors[2], '|| z-z_goal ||^2')
plot = plot:line(epoch_axis, get_error(train_err,'bce'), colors[3], 'log p(x_t|z_t)')
plot = plot:line(epoch_axis, get_error(train_err,'bce_1'), colors[4], 'log p(x_t+1|z_t+1)')
plot:legend(true):redraw()

Generate Predictions


In [ ]:
-- One-step prediction for a single sample.
local idx = 600

-- The training set is the one actually used here (swap in state_test to
-- inspect held-out data instead).
local dataset = state_train

-- single-row mini batch around idx
local frames, controls = dataset.x, dataset.u
local batch_x_prev = frames:narrow(1, idx-1, 1)
local batch_x_cur = frames:narrow(1, idx, 1)
local batch_x_next = frames:narrow(1, idx+1, 1)
local batch_u = controls:narrow(1, idx, 1)

-- outputs: {z_t_prev, z_t, z_t_next, x_t, x_t_next}
local outputs = model:forward({batch_x_prev, batch_x_cur, batch_u})
local x_t = outputs[4]
local x_t_next = outputs[5]

-- reconstruction vs. ground truth, then prediction vs. ground truth
disp_img(x_t)
disp_img(frames[idx])
disp_img(x_t_next)
disp_img(frames[idx+1])

In [ ]:
-- Multi-step rollout: feed the model its own predictions for `steps` steps
-- and save true / predicted frames side by side as PNG files.
local steps = 5
local idx = 610

-- The training set is the one actually used here (swap in state_test to
-- roll out on held-out data instead).
local dataset = state_train

-- create mini batch
local x_prev = dataset.x:narrow(1, idx-1, 1)
local x_cur = dataset.x:narrow(1, idx, 1)

-- image.savePNG does not create directories: make sure the target exists
os.execute('mkdir -p preds')

image.savePNG(string.format("preds/true-%05d.png", 0), dataset.x[idx]:view(opt.img_w,opt.img_w))

for i=0,steps-1 do
    -- the true control signal is used at every step
    local batch_u = dataset.u:narrow(1, idx+i, 1)

    local z_t_prev, z_t_cur, z_t_next, x_t, x_t_next = unpack(model:forward({x_prev, x_cur, batch_u}))

    -- shift the window: the prediction becomes the next "current" frame
    -- (cloned because the model reuses its output tensors)
    x_prev = x_cur:clone()
    x_cur = x_t_next:clone():view(1, opt.img_w^2)

    image.savePNG(string.format("preds/true-%05d.png", (i+1)), dataset.x[idx+1+i]:view(opt.img_w,opt.img_w))
    image.savePNG(string.format("preds/pred-%05d.png", (i+1)), x_t_next:view(opt.img_w,opt.img_w))

end

License

Copyright (C) 2015 John-Alexander M. Assael, Marc P. Deisenroth

The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.