NNGraph

A graph-based container for creating deep learning models.

One can use nngraph to create simple networks as well as complicated ones. Its syntax provides a more natural interface: nodes of computational units are connected together through a function-like calling mechanism, and every node can have a variable number of inputs and outputs.

The graph has to be fully defined: every connection must end at a node, and all inputs to the graph must be provided at runtime.


In [ ]:
require 'nngraph';
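
A node can take several inputs by being called on a table of other nodes, which is how the variable number of inputs is expressed. A minimal illustration (the module choices here are arbitrary):


In [ ]:
-- a node with two inputs: CAddTable sums the outputs of a and b
a = nn.Identity()()
b = nn.Identity()()
c = nn.CAddTable()({a, b})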

Start Simple

We can create a simple feedforward neural network easily.


In [ ]:
-- it is common style to mark inputs with identity nodes for clarity.
input = nn.Identity()()

-- each hidden layer is created by connecting it to the previous one
-- here we define a single hidden layer network
h1 = nn.Tanh()(nn.Linear(20, 10)(input))
output = nn.Linear(10, 1)(h1)
mlp = nn.gModule({input}, {output})

x = torch.rand(20)
dx = torch.rand(1)
mlp:updateOutput(x)
mlp:updateGradInput(x, dx)
mlp:accGradParameters(x, dx)

-- draw graph (the forward graph, '.fg')
-- this will produce an SVG in the runtime directory
graph.dot(mlp.fg, 'MLP', 'MLP')
itorch.image('MLP.svg')
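
Since the resulting gModule behaves like any other nn module, it can also be trained with a criterion in the usual way. A minimal sketch of one SGD step (the MSE criterion, the target, and the learning rate are illustrative choices, not part of the example above):


In [ ]:
-- one illustrative gradient step with an MSE criterion
criterion = nn.MSECriterion()
target = torch.rand(1)

mlp:zeroGradParameters()
local pred = mlp:forward(x)
local loss = criterion:forward(pred, target)
mlp:backward(x, criterion:backward(pred, target))
mlp:updateParameters(0.01) -- vanilla SGD with learning rate 0.01
print(loss)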

Node Names

The names of the nodes are important: when we draw complicated graphs, we want to be able to match the nodes in the drawing back to our code.


In [ ]:
local function get_network()
    -- it is common style to mark inputs with identity nodes for clarity.
    local input = nn.Identity()()

    -- each hidden layer is created by connecting it to the previous one
    -- here we define a single hidden layer network
    local h1 = nn.Linear(20, 10)(input)
    local h2 = nn.Sigmoid()(h1)
    local output = nn.Linear(10, 1)(h2)
    
    -- the following function call inspects the local variables in this
    -- function and finds the nodes corresponding to local variables.
    nngraph.annotateNodes()
    return nn.gModule({input}, {output}) 
end
mlp = get_network()
x = torch.rand(20)
dx = torch.rand(1)
mlp:updateOutput(x)
mlp:updateGradInput(x, dx)
mlp:accGradParameters(x, dx)

-- draw graph (the forward graph, '.fg')
-- this will produce an SVG in the runtime directory
graph.dot(mlp.fg, 'MLP', 'MLP_Annotated')
itorch.image('MLP_Annotated.svg')
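
Note that nngraph.annotateNodes() only picks up nodes that are bound to local variables. Nodes can also be named explicitly through :annotate, the same mechanism used for styling in the 'More Debug' section below. A minimal sketch (the names are arbitrary):


In [ ]:
-- naming nodes explicitly instead of relying on annotateNodes()
local input = nn.Identity()():annotate{name = 'x'}
local output = nn.Linear(20, 1)(input):annotate{name = 'prediction'}
local net = nn.gModule({input}, {output})
graph.dot(net.fg, 'MLP_Named', 'MLP_Named')
itorch.image('MLP_Named.svg')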

Identifying Errors at Runtime

nngraph can also help localize runtime errors: with the debug flag set, a failing gModule writes an annotated graph to disk that pinpoints the offending node.


In [ ]:
-- we need to set the debug flag to true
nngraph.setDebug(true)

local function get_network()
    -- it is common style to mark inputs with identity nodes for clarity.
    local input = nn.Identity()()

    -- each hidden layer is created by connecting it to the previous one
    -- here we define a single hidden layer network
    local h1 = nn.Linear(20, 10)(input)
    local h2 = nn.Sigmoid()(h1)
    local output = nn.Linear(10, 1)(h2)
    
    -- the following function call inspects the local variables in this
    -- function and finds the nodes corresponding to local variables.
    nngraph.annotateNodes()
    return nn.gModule({input}, {output}) 
end
mlp = get_network()
mlp.name = 'MyMLPError'
x = torch.rand(15) -- note that this input will cause runtime error

-- we make a protected call so the error does not interrupt the notebook
local o, err = pcall(function() mlp:updateOutput(x) end)
itorch.image('MyMLPError.svg')

Because of the protected call the notebook keeps running, but at the same time an SVG file named 'MyMLPError.svg' is produced in which the node where the error occurred is marked in red. One can easily see that the calculation of 'h1' was the problem.
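
The error message captured by pcall can be printed as well; a small sketch that re-runs the failing forward pass, just to show the textual message alongside the annotated graph:


In [ ]:
-- print the textual error captured by the protected call
local ok, err = pcall(function() mlp:updateOutput(torch.rand(15)) end)
if not ok then print(err) end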

A More Complete Example

Now we will create the core of an RNN module.


In [ ]:
function get_rnn(input_size, rnn_size)
  
    -- there are two inputs: the current input x and the previous hidden state
    local input = nn.Identity()()
    local prev_h = nn.Identity()()

    -- RNN tick
    local i2h = nn.Linear(input_size, rnn_size)(input)
    local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
    local added_h = nn.CAddTable()({i2h, h2h})
    local next_h = nn.Tanh()(added_h)
    
    nngraph.annotateNodes()
    return nn.gModule({input, prev_h}, {next_h})
end

local rnn_net = get_rnn(128, 128)
graph.dot(rnn_net.fg, 'rnn_net', 'rnn_net')
itorch.image('rnn_net.svg')
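
The returned gModule takes a table of inputs and produces the next hidden state. A quick forward pass as a sanity check (the tensor sizes are illustrative):


In [ ]:
-- one RNN tick: the input table is {x, prev_h}, the output is next_h
local net = get_rnn(128, 128)
local x, h0 = torch.rand(128), torch.rand(128)
print(net:updateOutput({x, h0}):size())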

Connecting in Time

Now, let's connect these RNN cores in time.


In [ ]:
local function get_rnn2(input_size, rnn_size)
    local input1 = nn.Identity()()
    local input2 = nn.Identity()()
    local prev_h = nn.Identity()()
    local rnn_net1 = get_rnn(input_size, rnn_size)({input1, prev_h})
    local rnn_net2 = get_rnn(input_size, rnn_size)({input2, rnn_net1})
    nngraph.annotateNodes()
    return nn.gModule({input1, input2, prev_h}, {rnn_net2})
end
local rnn_net2 = get_rnn2(128, 128)
graph.dot(rnn_net2.fg, 'rnn_net2', 'rnn_net2')
itorch.image('rnn_net2.svg')
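
Keep in mind that the two get_rnn calls above create two cores with independent weights. For a true recurrence the timesteps usually share parameters; one way is nn.Module's share method. This is only a sketch (sharing gModules has subtleties, and the field names below are the standard nn parameter names):


In [ ]:
-- tie the parameters of two identical cores together
local core1 = get_rnn(128, 128)
local core2 = get_rnn(128, 128)
core2:share(core1, 'weight', 'bias', 'gradWeight', 'gradBias')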

More Debug

Even with named nodes, the graph becomes complicated very quickly. One can then use custom annotations, such as colors, to mark certain paths.


In [ ]:
local function get_rnn2(input_size, rnn_size)
    local input1 = nn.Identity()():annotate{graphAttributes = {style='filled', fillcolor='blue'}}
    local input2 = nn.Identity()():annotate{graphAttributes = {style='filled', fillcolor='blue'}}
    local prev_h = nn.Identity()():annotate{graphAttributes = {style='filled', fillcolor='blue'}}
    local rnn_net1 = get_rnn(input_size, rnn_size)({input1, prev_h}):annotate{graphAttributes = {style='filled', fillcolor='yellow'}}
    local rnn_net2 = get_rnn(input_size, rnn_size)({input2, rnn_net1}):annotate{graphAttributes = {style='filled', fillcolor='green'}}
    nngraph.annotateNodes()
    return nn.gModule({input1, input2, prev_h}, {rnn_net2})
end
local rnn_net3 = get_rnn2(128, 128)
graph.dot(rnn_net3.fg, 'rnn_net3', 'rnn_net3')
itorch.image('rnn_net3.svg')
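
Besides the forward graph '.fg', every gModule also stores its backward graph, which can be drawn the same way. A short sketch using the earlier mlp (any gModule works):


In [ ]:
-- draw the backward graph ('.bg') of the earlier MLP
graph.dot(mlp.bg, 'MLP_Backward', 'MLP_Backward')
itorch.image('MLP_Backward.svg')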

