Thomas Viehmann tv@lernapparat.de
Here I'm exploring a flexible JITed LSTM / RNN implementation, with the hope of eventually merging an improved LSTM back into PyTorch.
As usual, it should follow the functional + modular interface convention of PyTorch.
What others do, as inspiration for the features / API:
Top desired features (for all of LSTM / GRU / RNN):
Larger wishlist with links / details: PyTorch issue #9572.
Errors and such are all my own.
If you have something to contribute, I'd be most happy. Note that we'll want to balance providing functionality with keeping the core PyTorch library concise. The good news is that you can copy-paste the code here to get completely custom, JITed, fast RNNs!
In [21]:
import torch
In [26]:
@torch.jit.script
def activation_cell(cx):
    return torch.tanh(cx)
In [27]:
# Note: this picks up the non-parameter jit.script function activation_cell
# from the surrounding context at definition time! (We'll probably want to do
# this in a factory-style function even if that's not 100% Pythonic; see the
# sketch after this cell.)
@torch.jit.script
def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
    hx, cx = hidden
    gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = activation_cell(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy
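As a sketch of that factory-style approach: the builder below closes over the activation and compiles a fresh cell. The names make_lstm_cell and lstm_cell_tanh are mine, and whether a closure variable is resolvable at script-compilation time depends on the JIT version, so take this as an assumption rather than settled API.

def make_lstm_cell(activation):
    # 'activation' must itself be a @torch.jit.script function; it is
    # resolved from this enclosing scope when 'cell' is compiled.
    @torch.jit.script
    def cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
        # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
        hx, cx = hidden
        gates = torch.mm(input, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        cy = (torch.sigmoid(forgetgate) * cx) + (torch.sigmoid(ingate) * activation(cellgate))
        hy = torch.sigmoid(outgate) * torch.tanh(cy)
        return hy, cy
    return cell

lstm_cell_tanh = make_lstm_cell(activation_cell)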
In [28]:
print(lstm_cell.graph.pretty_print())
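A quick smoke test (the sizes are arbitrary, my own choice): the stacked gate weights need 4 * hidden_size rows, since gates.chunk(4, 1) splits them back into the four gates.

batch, input_size, hidden_size = 3, 10, 20
inp = torch.randn(batch, input_size)
hx = torch.randn(batch, hidden_size)
cx = torch.randn(batch, hidden_size)
w_ih = torch.randn(4 * hidden_size, input_size)
w_hh = torch.randn(4 * hidden_size, hidden_size)
b_ih = torch.randn(4 * hidden_size)
b_hh = torch.randn(4 * hidden_size)
hy, cy = lstm_cell(inp, (hx, cx), w_ih, w_hh, b_ih, b_hh)
print(hy.shape, cy.shape)  # both should be (batch, hidden_size)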