JITed RNNs for PyTorch

Thomas Viehmann tv@lernapparat.de

Here I'm trying to explore a flexible JITed LSTM / RNN implementation, with the hope of eventually merging an improved LSTM back into PyTorch.

As usual, it should follow the functional + modular interface convention of PyTorch.

What others do serves as inspiration for the features/API below.

Top desired features (all = LSTM / GRU / RNN):

  • Custom activation functions (all)
  • Layer norm (all)
  • Custom forget gate bias (LSTM)
  • Peephole connections (LSTM)
  • Cell clipping (LSTM) (a sketch of this and the forget gate bias follows the list)
  • Projection layer (LSTM) (lowest priority, but seems useful for large-scale applications)
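
To make two of the items above concrete, here is a minimal sketch (mine, not the eventual implementation) of a custom forget gate bias and of cell clipping. It assumes the (input, forget, cell, output) gate layout used by the cell below; the sizes and the clip value are hypothetical.

In [ ]:
import torch

hidden_size = 16  # hypothetical size, for illustration only
cell_clip = 3.0   # hypothetical clipping value

# custom forget gate bias: with gates laid out as (in, forget, cell, out),
# the forget gate occupies the second chunk of the bias vector
b_ih = torch.zeros(4 * hidden_size)
b_ih[hidden_size:2 * hidden_size] = 1.0

# cell clipping: clamp the updated cell state to [-cell_clip, cell_clip]
cy = torch.randn(2, hidden_size)  # stand-in for the updated cell state
cy = cy.clamp(-cell_clip, cell_clip)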

Larger wishlist with links / details: PyTorch issue #9572.

Thank you

  • Kai Arulkumaran helped a lot by collating the wishlist, the links, and the top desired features

Errors and such are all my own.

Help wanted!

If you have something to contribute, I'd be most happy. Note that we'll want to balance providing functionality against keeping the core PyTorch library concise. The good news is that you can copy-paste the code here to get completely custom, JITed, fast RNNs!


In [21]:
import torch

Functional Interface for a Cell

We start with a functional interface for a cell.

One wishlist item is to have custom (recurrent?) activations. We define the activation as a pointwise function.


In [26]:
@torch.jit.script
def activation_cell(cx):
    return torch.tanh(cx)

In [27]:
# note: this captures the non-parameter jit.script function activation_cell from the
# surrounding context at definition time! (We'll probably want to do this in a
# factory-style function, even if it's not 100% Pythonic; see the sketch below.)
@torch.jit.script
def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh

    # gates are laid out as (input, forget, cell, output)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = activation_cell(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy
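
The factory-style function alluded to in the comment could look roughly like this. This is a sketch under the assumption that torch.jit.script resolves the activation captured in the closure at scripting time; make_lstm_cell is a name I made up.

In [ ]:
def make_lstm_cell(activation):
    # build a cell around the given pointwise activation, then script it
    def cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
        hx, cx = hidden
        gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        cy = torch.sigmoid(forgetgate) * cx + torch.sigmoid(ingate) * activation(cellgate)
        hy = torch.sigmoid(outgate) * torch.tanh(cy)
        return hy, cy
    return torch.jit.script(cell)

tanh_cell = make_lstm_cell(activation_cell)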

In [28]:
print(lstm_cell.graph.pretty_print())


def graph(self,
    input: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor) -> Tuple[Tensor, Tensor]:
  hx, cx, = hidden
  _0 = torch.add(torch.mm(input, torch.t(w_ih)), torch.mm(hx, torch.t(w_hh)), alpha=1)
  gates = torch.add(torch.add(_0, b_ih, alpha=1), b_hh, alpha=1)
  ingate, forgetgate, cellgate, outgate, = torch.chunk(gates, 4, 1)
  ingate0 = torch.sigmoid(ingate)
  forgetgate0 = torch.sigmoid(forgetgate)
  cellgate0 = torch.tanh(cellgate)
  outgate0 = torch.sigmoid(outgate)
  cy = torch.add(torch.mul(forgetgate0, cx), torch.mul(ingate0, cellgate0), alpha=1)
  hy = torch.mul(outgate0, torch.tanh(cy))
  return (hy, cy)
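
As a quick sanity check that the scripted cell runs, we can feed it random tensors (the sizes are arbitrary, chosen just for this smoke test):

In [ ]:
batch, input_size, hidden_size = 2, 4, 8
inp = torch.randn(batch, input_size)
state = (torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)
b_ih = torch.zeros(4 * hidden_size)
b_hh = torch.zeros(4 * hidden_size)
hy, cy = lstm_cell(inp, state, w_ih, w_hh, b_ih, b_hh)
print(hy.shape, cy.shape)  # torch.Size([2, 8]) torch.Size([2, 8])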



Modules
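
Still to be filled in. As a starting point, here is a rough sketch of the modular counterpart: an nn.Module holding the parameters and calling the functional cell. The name JitLSTMCell and the init scheme (uniform in ±1/sqrt(hidden_size), loosely following nn.LSTMCell) are my assumptions.

In [ ]:
import math
import torch.nn as nn

class JitLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.weight_ih = nn.Parameter(torch.empty(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(4 * hidden_size, hidden_size))
        self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size))
        self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size))
        stdv = 1.0 / math.sqrt(hidden_size)
        for p in self.parameters():
            nn.init.uniform_(p, -stdv, stdv)

    def forward(self, input, hidden):
        # hidden is the (hx, cx) tuple, as in the functional interface
        return lstm_cell(input, hidden, self.weight_ih, self.weight_hh,
                         self.bias_ih, self.bias_hh)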

Tests
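
Also still to be filled in. One obvious test (again a sketch; the tolerance is a guess) is to compare against torch.nn.LSTMCell with shared weights, since it uses the same (input, forget, cell, output) gate layout:

In [ ]:
torch.manual_seed(0)
ref = torch.nn.LSTMCell(4, 8)
inp = torch.randn(2, 4)
state = (torch.randn(2, 8), torch.randn(2, 8))
hy, cy = lstm_cell(inp, state, ref.weight_ih, ref.weight_hh,
                   ref.bias_ih, ref.bias_hh)
hy_ref, cy_ref = ref(inp, state)
print(torch.allclose(hy, hy_ref, atol=1e-6),
      torch.allclose(cy, cy_ref, atol=1e-6))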

