Training PyTorch models with TVM computation

a tutorial by Thomas Viehmann tv@lernapparat.de

Acknowledgement & Disclosure: The creation of this tutorial was sponsored by AMD. Thank you!

Following our tutorial on running and tuning PyTorch models on TVM, we can look at using TVM to speed up training, too. Of course, this opens an entirely new can of worms as we need to deal with autodifferentiation.

Our goal in this tutorial is to take a non-trivial module (we'll use BertLayer from HuggingFace transformers' BertModel) and divert the computation during training to TVM. So the user can take a (traceable) module and do

add_tvm_dispatch(module, sample_input)

and then, if she calls the module with inputs of the same shape as sample_input, she'll get the outputs computed by TVM (as PyTorch tensors, of course); if not, it'll just use the regular forward.
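
Just to make the goal concrete, here is a minimal sketch of what such a dispatch could look like. This is an assumption for illustration, not the implementation we build in this tutorial; in particular, tvm_forward is a hypothetical stand-in for the TVM-backed computation.

def add_tvm_dispatch(module, sample_inputs):
    # record the shapes the TVM function was built for
    sample_shapes = [tuple(t.shape) for t in sample_inputs]
    orig_forward = module.forward
    def forward(*inputs):
        if [tuple(t.shape) for t in inputs] == sample_shapes:
            return module.tvm_forward(*inputs)  # hypothetical TVM-backed path
        return orig_forward(*inputs)            # fall back to regular PyTorch
    module.forward = forward
    return module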

The bad news first: this tutorial shows how to do these things, but we will not yet achieve a great speedup.

But enough talk, let us dive right in!

The first thing to do is import things and get the model we want.


In [1]:
import inspect
import types
import sys

# I sometimes need to choose PyTorch...
#sys.path.insert(0, '/home/tv/pytorch/pytorch/build/lib.linux-x86_64-3.8//')
import torch
import torch.utils.dlpack

# import TVM
import sys
import os


tvm_root = '/home/tv/rocm/tvm/tvm/'
tvm_paths = [os.path.join(tvm_root, p) for p in ['python', 'topi/python', 'nnvm/python']]
os.environ['PYTHONPATH'] = ':'.join([os.environ.get('PYTHONPATH', '')] + tvm_paths)
for p in tvm_paths:
    sys.path.insert(0, p)
    

import tvm
import tvm.relay

torch.cuda.get_device_name()


Out[1]:
'Device 66af'

Helpfully, transformers supports tracing its models with the PyTorch JIT. We follow their tutorial on it; the code below is copied straight from there.


In [2]:
import transformers

from transformers import BertModel, BertTokenizer, BertConfig
import numpy

import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

model.eval()
for p in model.parameters():
    p.requires_grad_(False)

transformers.__version__


Out[2]:
'3.0.0'

Now we can trace our model. As we want to do inference, we put the model into evaluation mode and do not require gradients for the parameters.


In [3]:
dtype = torch.float32
dtype_str = str(dtype).split('.')[-1]

In [4]:
# Creating the trace
model.to(dtype)
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
traced_model.eval()
for p in traced_model.parameters():
    p.requires_grad_(False)

Readers of the PyTorch Bert & TVM tutorial will recall the wrapper we had for getting inputs and outputs of a submodule of the model.


In [5]:
class DebugWrap(torch.nn.Module):
    def __init__(self, root, target_qn):
        super().__init__()
        self.root = (root,) # Hide from PyTorch
        parent, = self.root
        target_qn = target_qn.split('.')
        self.target_basename = target_qn[-1]
        for nc in target_qn[:-1]:
            parent = getattr(parent, nc)
        self.parent = (parent,)
        target = getattr(parent, self.target_basename)
        self.wrapped = target
        setattr(parent, self.target_basename, self)
    def remove(self):
        parent, = self.parent
        setattr(parent, self.target_basename, self.wrapped)
        self.root = None
    def forward(self, *inp, **kwinp):
        assert self.root is not None
        self.DEBUG_INP = inp
        self.DEBUG_KWINP = kwinp
        out = self.wrapped(*inp, **kwinp)
        self.DEBUG_OUT = out
        return out

We also had a fancy visualization. We now have a small addition: a dictionary to specify attributes for nodes. This will come in handy later.


In [6]:
import graphviz
def visualize(expr, collapse_small=True, node_attr_dict = {}):
    def collect_ops(node):
        ops = set()
        def visitor(e):
            if isinstance(e, tvm.ir.Op):
                ops.add(e.name)
        tvm.relay.analysis.post_order_visit(node, visitor)
        return ops

    # node_dict maps a Relay node to an index (node ID)
    def _traverse_expr(node, node_dict):
        if node in node_dict:
            return
        node_dict[node] = len(node_dict)

    node_dict = {}
    tvm.relay.analysis.post_order_visit(expr, lambda x: _traverse_expr(x, node_dict))

    relayviz_nodes = []

    dot = graphviz.Digraph(format='svg')
    dot.attr('node', shape = 'box')

    def to_str(node):
        if isinstance(node, tvm.relay.Constant):
            return repr(node).lstrip('Constant(')[:-1]
        else:
            raise NotImplementedError("to_str:" + repr(node))

    def is_small_const(c):
        if not (collapse_small and isinstance(c, tvm.relay.Constant)):
            return False
        if isinstance(c.data, tvm.runtime.ndarray.NDArray):
            return numpy.prod(c.data.shape) < 10
        return True
            
    # Sort by node ID
    for node, node_id in sorted(node_dict.items(), key=lambda x: x[1]):
        if isinstance(node, tvm.relay.Function):
            dot.node(str(node_id), 'Function', **node_attr_dict.get(node, {}))
            dot.edge(str(node_dict[node.body]), str(node_id))
        elif isinstance(node, tvm.relay.Var):
            if node.type_annotation is not None:
                if hasattr(node.type_annotation, 'shape'):
                    shape = tuple([int(x) for x in node.type_annotation.shape])
                    dtype = node.type_annotation.dtype
                    typstr = 'Tensor[{}, {}]'.format(shape, dtype)
                else:
                    typstr = str(node.type_annotation)
            else:
                typstr = '?'
            d = dict(shape = 'ellipse')
            d.update(node_attr_dict.get(node, {}))
            dot.node(str(node_id),
                     '{}: {}'.format(
                         node.name_hint, typstr
                     ), **d)
        elif isinstance(node, tvm.relay.Tuple):
            dot.node(str(node_id), 'Tuple[...]', **node_attr_dict.get(node, {}))
            for field in node.fields:
                dot.edge(str(node_dict[field]), str(node_id))
        elif isinstance(node, tvm.relay.Constant):
            
            if not is_small_const(node): # small consts are shown in ops
                dot.node(str(node_id), 'Constant({}, {})'.format(node.data.shape, node.data.dtype),
                        **node_attr_dict.get(node, {}))
        elif isinstance(node, tvm.relay.Call):
            args_with_edge = []
            arg_str_list = []
            for arg in node.args:
                if is_small_const(arg):
                    arg_str_list.append(to_str(arg))
                else:
                    arg_str_list.append('·')
                    args_with_edge.append(arg)
            arg_str = ', '.join(arg_str_list)
            if isinstance(node.op, tvm.ir.Op):
                name = node.op.name
                attrs = {k:getattr(node.attrs, k) for k in node.attrs.keys()} if hasattr(node.attrs, 'keys') else {}
                #attrs = inspect.getmembers(node.attrs)
                attr_str_list = [k+'='+(str(v) if len(str(v))<15 else "...") for k, v in attrs.items()]
                if attr_str_list:
                    attr_str = '| '+ ', '.join(attr_str_list)
                else:
                    attr_str = ''
            else:
                ops = collect_ops(node)
                if ops:
                    name = '_'.join(ops)
                else:
                    name = '...'
                attr_str = ''
            s = f'{name}({arg_str}{attr_str})'
            dot.node(str(node_id), s, **node_attr_dict.get(node, {}))
            for arg in args_with_edge:
                dot.edge(str(node_dict[arg]), str(node_id))
        elif isinstance(node, tvm.ir.Op):
            # dot.node(str(node_id), 'Op {}'.format(node.name))
            pass # covered in call
        elif isinstance(node, tvm.relay.TupleGetItem):
            dot.node(str(node_id), 'TupleGetItem(idx={})'.format(node.index), **node_attr_dict.get(node, {}))
            dot.edge(str(node_dict[node.tuple_value]), str(node_id))
        elif isinstance(node, tvm.relay.Let):
            dot.node(str(node_id), 'Let(XX)', **node_attr_dict.get(node, {}))
            dot.edge(str(node_dict[node.value]), str(node_id))
            dot.edge(str(node_id), str(node_dict[node.var]))
        else:
            raise RuntimeError(
                'Unknown node type. node_id: {}, node: {}'.format(node_id, type(node)))

    return dot

Let's wrap the first BertLayer in our model. You could also take smaller bits if you run my tutorials on your phone and want smaller graphs.


In [7]:
try:
    debug_wrap = DebugWrap(model, "encoder.layer.0") # encoder.layer.0.attention.self
    tt = tokens_tensor.cpu()
    st = segments_tensors.cpu()
    model(tt, st)
finally:
    debug_wrap.remove()

We trace the module.


In [8]:
model.train()
traced_module = torch.jit.trace(debug_wrap.wrapped, [i.to(dtype) for i in debug_wrap.DEBUG_INP[:2]])


/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py:954: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
With rtol=1e-05 and atol=1e-05, found 10750 element(s) (out of 10752) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 2.075359344482422 (-4.184629917144775 vs. -2.1092705726623535), which occurred at index (0, 6, 381).
  _check_trace(

Let's convert the traced model to TVM. This works just as before.


In [9]:
shape_list = [(i.debugName().split('.')[0], i.type().sizes()) for i in  list(traced_module.graph.inputs())[1:]]
mod, mod_params = tvm.relay.frontend.from_pytorch(traced_module, shape_list, default_dtype=dtype_str)


ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32

One thing we'll do in between is to move from a module interface - with named parameters - to a functional interface (which is what TVM can do for us). The first thing we want to do for that is arrange for the function arguments to be in an order that we can work with - i.e. first the direct inputs to the module and then the parameters in the same order that PyTorch uses them.


In [10]:
# the converter will output arguments in an arbitrary order (well, by position of use), we want that of the input
fn = mod['main']
# Careful traced module's vs. non-traced module's parameter ordering.
# Anecdotally, I have not seen orderings differ between the two, though.
arg_order = ([n for n, _ in shape_list]
             +[n for n, _ in traced_module.named_parameters()])
tmp_arg_idx = {p.name_hint: i for i, p in enumerate(fn.params)}

fn = tvm.relay.Function([fn.params[tmp_arg_idx[n]] for n in arg_order], fn.body)

Let's look at our function.


In [11]:
visualize(fn)


Out[11]:
[Graphviz rendering of the Relay graph for the traced BertLayer: from input and attention_mask through the query/key/value projections, batched matmuls, softmax and dropout of the self-attention, the attention output projection with nn.layer_norm, and the intermediate/output feed-forward block (GELU via erf) with a final nn.layer_norm.]

As in the BERT inference, we want to run some optimization passes. It'll be convenient to do this at the function level, so we're wrapping some standard TVM passes to work like this, too.

We already know the ShapeConstDedupMutator and the TransposeDedupMutator from the inference notebook, deduplicating some of the things that came with the PyTorch conversion.

But we also have a few new transformations:

  • One particularity of autodifferentiation is that it'll use a lot of ..._like operations to broadcast or "unbroadcast" (summation is the dual of broadcasting w.r.t. autodifferentiation) things. But this means that you now have two tensor arguments, even if the latter doesn't really need a gradient. LikeZapp replaces those operations with the corresponding functions taking a shape parameter instead.
  • Another thing is the "rooting" of derivatives. TVM generates tensors of ones with the same shape as the return values of our function as the starting point for the chain rule. These are then multiplied with the derivatives of our operations. But multiplication with ones does not do much, so we strike that. Similarly, TVM initializes the gradient of a variable (an input) to zeros of the same shape. If the input isn't used, the gradient stays zero, but if it is, the "real gradient" will be added to that zero. Adding zero can be eliminated as well. These cases are taken care of by ZeroZapp and OneZapp.
  • TVM doesn't have a training variant of LayerNorm (or BatchNorm or others). So we implement a pass to spell out the computation.
  • TVM also doesn't have training dropout. Here the problem is somewhat harder to fix, as TVM currently doesn't have random number generation. We instead replace the dropout by a construct taking a random Bernoulli draw (of 0/1 values) and mimicking dropout with that. The idea is that we'll use PyTorch to generate this mask for us (see the sketch right after this list). This has the added benefit that (if we generate dropout masks in the same order as PyTorch) we'll get the exact same result.
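
A side note on the last point: the externalized dropout variables expect a 0/1 mask drawn with keep probability 1 - rate, while the scaling by 1 / (1 - rate) stays in the rewritten graph. Such a mask could be drawn in PyTorch roughly as below; the shape and rate here are just assumed example values matching the first attention dropout.

# sketch: draw a 0/1 dropout mask to feed in as e.g. the "dropout:0" input;
# the 1/(1 - rate) scaling is already part of the rewritten graph
rate = 0.1
mask = torch.bernoulli(torch.full((1, 12, 14, 14), 1.0 - rate))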

So here is this bit of infrastructure:


In [12]:
import numpy

def work_on_fn(pass_cls):
    def apply_pass(fn_or_mod):
        if isinstance(fn_or_mod, tvm.IRModule):
            return pass_cls()(fn_or_mod)
        if isinstance(fn_or_mod, tvm.relay.Function):
            return pass_cls()(
                       tvm.IRModule({'main': fn_or_mod}))['main']
        raise NotImplementedError("unsupported type {}".format(type(fn_or_mod)))
    return apply_pass

infer_type = work_on_fn(tvm.relay.transform.InferType)
to_graph_normal_form = work_on_fn(tvm.relay.transform.ToGraphNormalForm)
dead_code_elimination = work_on_fn(tvm.relay.transform.DeadCodeElimination)
eliminate_common_subexpr = work_on_fn(tvm.relay.transform.EliminateCommonSubexpr)

class ShapeConstDedupMutator(tvm.relay.ExprMutator):
    def __init__(self):
        super().__init__()
        self.shape_consts = {}

    def visit_call(self, call):
        if (isinstance(call.op, tvm.ir.Op) 
            and call.op.name in {"reshape", "broadcast_to", "collapse_sum_to"}
            and isinstance(call.args[1], tvm.relay.Constant)):
            # assert list(call.attrs.newshape) == list(call.args[1].data.asnumpy())
            new_fn = self.visit(call.op)
            new_args = [self.visit(arg) for arg in call.args]
            const = new_args[1]
            assert const.data.dtype.startswith('int') and len(const.data.shape)==1
            key = tuple(const.data.asnumpy())
            if key in self.shape_consts:
                new_args[1] = self.shape_consts[key]
            else:
                self.shape_consts[key] = new_args[1]
            return tvm.relay.Call(new_fn, new_args, call.attrs)
        return super().visit_call(call)


class TransposeDedupMutator(tvm.relay.ExprMutator):
    def visit_call(self, call):
        if (isinstance(call.op, tvm.ir.Op) and call.op.name == "transpose"
            and isinstance(call.args[0], tvm.relay.Call) 
            and isinstance(call.args[0].op, tvm.ir.Op) and call.args[0].op.name == "transpose"):
            axes = [call.args[0].attrs.axes[int(i)] for i in call.attrs.axes]
            new_inp = self.visit(call.args[0].args[0])
            if axes == list(range(len(axes))): # neutral permutation, should really do this separately...
                return new_inp
            return tvm.relay.transpose(new_inp, axes)
        return super().visit_call(call)

#@tvm.relay.transform.function_pass(opt_level=1)
#def TransposeDedup(fn, mod, ctx):
#    return TransposeDedupMutator().visit(fn)

class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.zeros = tvm.relay.dataflow_pattern.is_op("zeros")(tvm.relay.dataflow_pattern.wildcard())
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.zeros + self.other_tensor) | (self.other_tensor + self.zeros)

    def callback(self, pre, post, node_map):
        rt = node_map[self.pattern][0]
        ot = node_map[self.other_tensor][0]
        if (ot._checked_type_ == rt._checked_type_):
            return ot
        else:
            return tvm.relay.broadcast_to(ot, list(rt._checked_type_.shape))

class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    # note: this second definition supersedes the ZeroZapp above and also matches constant zeros
    def __init__(self):
        self.ones = tvm.relay.dataflow_pattern.is_op("zeros")(tvm.relay.dataflow_pattern.wildcard()) | tvm.relay.dataflow_pattern.is_constant()
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.ones + self.other_tensor) | (self.other_tensor + self.ones)

    def callback(self, pre, post, node_map):
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        if isinstance(ones, tvm.relay.Constant):
            val = ones.data.asnumpy()
            if not ((val == 0) if numpy.isscalar(val) else (val == 0).all()):
                return rt
        # I don't know why I don't reliably get checked types here...
        if (((rt._checked_type_ is not None) and (ot._checked_type_ == rt._checked_type_))
            or (rt.type_args[0] == rt.type_args[1])):
            return ot
        elif (rt._checked_type_ is not None):
            return tvm.relay.broadcast_to(ot, list(rt._checked_type_.shape))
        return rt

class OneZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.ones = tvm.relay.dataflow_pattern.is_op("ones")(tvm.relay.dataflow_pattern.wildcard()) | tvm.relay.dataflow_pattern.is_constant()
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.ones * self.other_tensor) | (self.other_tensor * self.ones)

    def callback(self, pre, post, node_map):
        global val
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        if isinstance(ones, tvm.relay.Constant):
            val = ones.data.asnumpy()
            if not ((val == 1) if numpy.isscalar(val) else (val == 1).all()):
                return rt
        if (((rt._checked_type_ is not None) and (ot._checked_type_ == rt._checked_type_))
            or (rt.type_args[0] == rt.type_args[1])):
            return ot
        if (rt._checked_type_ is not None):
            return tvm.relay.broadcast_to(ot, list(rt._checked_type_.shape))
        return rt


class LikeZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    def __init__(self):
        self.translations_with_dt = {'zeros_like': tvm.relay.zeros,
                                     'ones_like': tvm.relay.ones}
        self.data_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = ((tvm.relay.dataflow_pattern.is_op("zeros_like")
                        | tvm.relay.dataflow_pattern.is_op("ones_like")
                        )(self.data_tensor)
                        ) | ((
                        tvm.relay.dataflow_pattern.is_op("collapse_sum_like")
                        | tvm.relay.dataflow_pattern.is_op("reshape_like")
                        | tvm.relay.dataflow_pattern.is_op("broadcast_to_like")
                       )(self.data_tensor, self.pattern_tensor))

    def callback(self, pre, post, node_map):
        data = node_map[self.data_tensor][0]
        res = node_map[self.pattern][0]
        if res.op.name in self.translations_with_dt:
            ret = self.translations_with_dt[res.op.name](list(res.type_args[0].shape),
                                                              res.type_args[0].dtype) # which dtype?
            return ret
        if (res.type_args[0] is not None and res.type_args[0] == res.type_args[1]):
            return data
        if res.op.name == 'broadcast_to_like':
            return tvm.relay.broadcast_to(data, list(res.type_args[1].shape))
        if res.op.name == 'reshape_like':
            return tvm.relay.reshape(data, list(res.type_args[1].shape))
        if res.op.name == 'collapse_sum_like':
            return tvm.relay.collapse_sum_to(data, list(res.type_args[1].shape))
        return res


class DecomposeLayerNorm(tvm.relay.dataflow_pattern.DFPatternCallback):
    # TVM doesn't have a LayerNorm backward
    def __init__(self):
        self.pattern = tvm.relay.dataflow_pattern.is_op("nn.layer_norm")(
            tvm.relay.dataflow_pattern.wildcard(),
            tvm.relay.dataflow_pattern.wildcard(),
            tvm.relay.dataflow_pattern.wildcard())

    def callback(self, pre, post, node_map):
        # probably only 1d...
        res = node_map[self.pattern][0]
        inp, weight, bias = res.args
        mean = tvm.relay.mean(inp, axis=res.attrs.axis, keepdims=True)
        std = tvm.relay.std(inp, axis=res.attrs.axis, keepdims=True)
        res_new = ((inp - mean) / (std + tvm.relay.const(res.attrs.epsilon, dtype=res.type_args[0].dtype))) * weight + bias
        return res_new

class ExternalizeDropout(tvm.relay.dataflow_pattern.DFPatternCallback):
    # TVM doesn't have a Dropout defined (for inference it can be deleted)
    # but it also does not appear to have random, so we make the random draw
    # an input
    def __init__(self):
        self.dropout_info = {}
        self.counter = 0
        self.inp = tvm.relay.dataflow_pattern.wildcard()
        self.dropout = tvm.relay.dataflow_pattern.is_op("nn.dropout")(self.inp)
        self.pattern = tvm.relay.dataflow_pattern.is_tuple_get_item(self.dropout, 0)

    def callback(self, pre, post, node_map):
        res = node_map[self.pattern][0]
        dropout = node_map[self.dropout][0]
        inp = node_map[self.inp][0]
        typ = dropout.type_args[0]
        rate = dropout.attrs.rate
        name = f"dropout:{self.counter}"
        self.counter += 1
        do_var = tvm.relay.var(name, type_annotation=typ)
        self.dropout_info[name] = (rate, typ)
        return inp * (do_var * tvm.relay.const(1 / (1 - rate), dtype=typ.dtype))

def externalize_dropout(fn):
    edo = ExternalizeDropout()
    fn = tvm.relay.dataflow_pattern.rewrite(edo, fn)
    return fn, edo.dropout_info


As hinted at above, TVM's gradient taking assumes that our function is the last element in the computation (hence the ones tensors discussed above). This isn't a good fit with PyTorch's modular view, which expects a grad_out to be given for each output. Happily, this is computationally equivalent to multiplying by grad_out and summing, so we amend our function with that. We wish to be flexible, so we allow both functions returning a single tensor and those returning a tuple of tensors. We also apply the passes handling the layer norm and the dropout.
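
In formulas (a quick way to see the equivalence): for our function $f : x \mapsto y$ and a given $grad\_out$, the gradient of $\sum_i y_i \cdot grad\_out_i$ with respect to $x$ is $J_f(x)^T grad\_out$, which is exactly the vector-Jacobian product that PyTorch's backward computes when it is handed that $grad\_out$.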


In [14]:
fn = TransposeDedupMutator().visit(fn)
fn = infer_type(fn)
output_type = fn.body.checked_type

if isinstance(output_type, tvm.relay.TensorType):
    gr_out = tvm.relay.var("gr:out", output_type)
    fn_for_gr = tvm.relay.Function(list(fn.params) + [gr_out], tvm.relay.sum(fn.body * gr_out))
else:
    # we can try to handle tuples of tensors, but our nesting patience ends there
    assert (isinstance(output_type, tvm.relay.TupleType) and
            all([isinstance(f, tvm.relay.TensorType) for f in output_type.fields]))
    gr_outs = [tvm.relay.var(f"gr:out:{i}", t) for i, t in enumerate(output_type.fields)]
    prods_with_gr_out = [tvm.relay.sum(tvm.relay.TupleGetItem(fn.body, i) * go_i)
                         for i, go_i in enumerate(gr_outs)]
    s = prods_with_gr_out[0]
    for p in prods_with_gr_out[1:]:
        s = s + p
    fn_for_gr = tvm.relay.Function(list(fn.params) + gr_outs, s)
fn_for_gr = infer_type(fn_for_gr)
fn_for_gr = tvm.relay.dataflow_pattern.rewrite(DecomposeLayerNorm(), fn_for_gr)
fn_for_gr = infer_type(fn_for_gr)
fn_for_gr, dropout_info = externalize_dropout(fn_for_gr)
fn_for_gr = infer_type(fn_for_gr)

visualize(fn_for_gr)


Out[14]:
[Graphviz rendering of fn_for_gr: the same BertLayer graph, now with gr:out:0 and dropout:0-2 as additional inputs, the LayerNorms decomposed into mean/variance arithmetic, the dropouts replaced by multiplications with the externalized masks, and the output multiplied with gr:out:0 and summed at the end.]

Finally we can take the gradient. As we get a lot of let nodes, we bring the result to graph normal form.


In [15]:
grfn = tvm.relay.transform.gradient(fn_for_gr, mode='first_order')
grfn = to_graph_normal_form(grfn)

TVM's gradient-taking returns a function that has the same parameters as the original function (in our case amended with the grad_out and dropout inputs) and returns a tuple of the original output and a tuple containing the gradients for all inputs. The first thing we do is to drop the gradients for grad_out and dropout, which we don't need. Then we run our simplification passes.


In [16]:
# Now we have (sum(orig_out * grad_out), (grad_inp_1, ..., grad_inp_n, grad_grad_out, gr_dropout ...))
# but we only want orig_out and grad_inp_1, ..., grad_inp_n
def is_aux_input(p):
    return p.name_hint.startswith('dropout:') or p.name_hint.startswith('gr:out:')

# the gr_out and dropout parameters will have gradients computed, but we do not want that
grads_to_keep = tvm.relay.Tuple([g for p, g in zip(grfn.params, grfn.body.fields[1].fields)
                                   if not is_aux_input(p)])

assert grfn.body.fields[0].op.name == 'sum'
assert grfn.body.fields[0].args[0].op.name == 'multiply'
if isinstance(output_type, tvm.relay.TensorType):
    orig_out = grfn.body.fields[0].args[0].args[0]
else:
    assert isinstance(output_type, tvm.relay.TupleType)
    orig_out = grfn.body.fields[0].args[0].args[0].tuple_value
out_and_grad = tvm.relay.Tuple([orig_out, grads_to_keep])
out_and_grad_fn = tvm.relay.Function(grfn.params, out_and_grad)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = dead_code_elimination(out_and_grad_fn)
out_and_grad_fn = eliminate_common_subexpr(out_and_grad_fn)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(LikeZapp(), out_and_grad_fn)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), out_and_grad_fn)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(OneZapp(), out_and_grad_fn)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(OneZapp(), out_and_grad_fn)
out_and_grad_fn = infer_type(out_and_grad_fn)
out_and_grad_fn = dead_code_elimination(out_and_grad_fn)
out_and_grad_fn = eliminate_common_subexpr(out_and_grad_fn)

Now is a good time to take a look at our graph:


In [17]:
visualize(out_and_grad_fn)


Out[17]:
[Graphviz rendering of out_and_grad_fn: the forward computation of the BertLayer together with the backward, returning a tuple of the original output and the gradients with respect to the input, attention_mask, and all weights and biases.]

But in PyTorch, we first compute the forward and then the backward, so we have to take out the saw and split our graph. One of the difficult problems is what to do with things computed for both the forward and the backward. It is a hard problem, related to the MinCut problem.

Our extremal options could be:

  • One could only keep the inputs and recompute everything as needed.
  • If we had a scalar output, we could compute the gradient and multiply with the derivative of the later layers on the backward. (Loss functions might do that.) This does not, however, work for non-scalar tensor outputs.

We'll do the following: we compute the forward normally, but we keep all things that will be used in the backward. This is too much, unfortunately, and it is very likely the reason we don't see an end-to-end speedup. We'll discuss some potential heuristics below.

We use a coloring here. First we color all nodes of the forward computation in red. Then we traverse the gradient calculation and color the forward nodes it needs in blue. This gives us a chance to show off the attribute support in our visualization.

A bit of (PyTorch) terminology: When we have a function $Layer : x \mapsto y$ followed by some $Loss : y \mapsto l \in \mathbb{R}$, the backward is $BackwardOfLayer : grad\_out \mapsto grad\_in$ with $grad\_out = dl/dy$ and $grad\_in = dl/dx$.


In [18]:
orig_out = out_and_grad_fn.body.fields[0]
grad_ins = out_and_grad_fn.body.fields[1]

color_dict = {}
def color(n, c):
    if n in color_dict:
        return
    color_dict[n] = c
    for a in getattr(n, 'args', []):
        color(a, c)
    for a in getattr(n, 'fields', []):
        color(a, c)
    for nam in ('body', 'tuple_value'):
        b = getattr(n, nam, None)
        if b is not None:
            color(b, c)

color(orig_out, {'color': 'red'})
seen = set()
def color_crossings(n, c):
    if n in seen:
        return
    seen.add(n)
    if n in color_dict:
        color_dict[n] = c
        return
    for a in getattr(n, 'args', []):
        color_crossings(a, c)
    for a in getattr(n, 'fields', []):
        color_crossings(a, c)
    for nam in ('body', 'tuple_value'):
        b = getattr(n, nam, None)
        if b is not None:
            color_crossings(b, c)

color_crossings(grad_ins, {'color': 'blue'})

In [19]:
visualize(out_and_grad_fn, node_attr_dict=color_dict)


Out[19]:
[The same graph as above, now colored: nodes used only in the forward are red, forward nodes that the backward also needs are blue, and backward-only nodes are uncolored.]

Now we can split the function as described above. We collect the blue nodes to capture; constants will just be duplicated, and inputs (Var nodes) need to be treated separately.


In [20]:
nodes_to_capture = [n for n, v in color_dict.items() 
                    if v['color'] == 'blue' and not isinstance(n, (tvm.relay.Constant, tvm.relay.Var))]
capture_tup = tvm.relay.Tuple(nodes_to_capture)
nodes_to_capture_idx = {n:i for i, n in enumerate(nodes_to_capture)}
capture_vars = [tvm.relay.var(f"input:captures:{i}", type_annotation=nodes_to_capture[i].checked_type)
                for i, n in enumerate(nodes_to_capture)]

grads_in = out_and_grad_fn.body.fields[1]
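
The capture variables are named input:captures:0, input:captures:1, and so on; these are exactly the names we will later use to feed the recorded values into the compiled backward module. To get a feel for what we capture, we can peek at the first few (just a quick check, the concrete types depend on the graph above):

[(v.name_hint, v.type_annotation) for v in capture_vars[:3]]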

Now we can split out the backward, replacing all the blue nodes with variables.


In [21]:
needed_vars = set()
class GradientOnlyMutator(tvm.relay.ExprMutator):
    def __init__(self):
        super().__init__()

    def visit_var(self, var):
        needed_vars.add(var)
        return var

    def visit(self, expr):
        if expr in nodes_to_capture_idx:
            return capture_vars[nodes_to_capture_idx[expr]]
        return super().visit(expr)
    
grads_in_only = GradientOnlyMutator().visit(grads_in)
gr_only_fn = tvm.relay.Function(sorted(needed_vars) + capture_vars, grads_in_only)

# TODO: check against output of original
fn_for_gr_input_names = {p.name_hint for p in fn_for_gr.params}
needed_var_names = {v.name_hint for v in needed_vars}

assert needed_var_names <= fn_for_gr_input_names
inputs_to_keep = [n for n in needed_vars if not is_aux_input(n)]

Next we take the forward and amend it to also return the required intermediates.


In [22]:
capture_tup = tvm.relay.Tuple([n for n in nodes_to_capture])
fw_and_cap_params = [p for p in out_and_grad_fn.params if not p.name_hint.startswith('gr:out:')]

fw_and_cap_fn = tvm.relay.Function(fw_and_cap_params,
                                   tvm.relay.Tuple((out_and_grad_fn.body.fields[0],) + (capture_tup,)))
visualize(fw_and_cap_fn)


Out[22]:
[Graph rendering omitted: the forward-and-capture function. Inputs, parameters, and the externalized dropout masks feed the BertLayer forward computation, and the intermediates needed for the backward are gathered into a capture tuple that is returned alongside the layer output.]

TVM cannot return nested tuples, so we flatten the output in the function. Again we differentiate between tensor-valued functions and tuple-valued ones (i.e. those potentially returning multiple tensors).


In [23]:
if isinstance(fn.body, tvm.relay.Tuple):
    # tuple of tensors output
    fw_and_cap_fn_flattened = tvm.relay.Function(fw_and_cap_fn.params, tvm.relay.Tuple(list(fw_and_cap_fn.body.fields[0].fields) # or single tensor
                                                + list(fw_and_cap_fn.body.fields[1].fields)))
else:
    # single tensor output
    fw_and_cap_fn_flattened = tvm.relay.Function(fw_and_cap_fn.params, tvm.relay.Tuple([fw_and_cap_fn.body.fields[0]]
                                                + list(fw_and_cap_fn.body.fields[1].fields)))
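
The flattened output thus has a fixed layout: first the original forward output(s), then the captured intermediates in nodes_to_capture order. The backward side will rely on this convention when it reads the compiled module's outputs, roughly like this (a sketch, assuming a single tensor output):

num_captures = len(capture_vars)
num_regular_outputs = len(fw_and_cap_fn_flattened.body.fields) - num_captures
# outputs[0:num_regular_outputs]  -> the original forward output(s)
# outputs[num_regular_outputs:]   -> the captured intermediates, in nodes_to_capture order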

And at last, we can let TVM do its magic and compile our functions.


In [24]:
target = 'rocm -model=gfx906'
target_host = 'llvm'
ctx = tvm.context(target)

fw_and_cap_mod = tvm.IRModule({"main": fw_and_cap_fn_flattened})
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = tvm.relay.build(fw_and_cap_mod,
                                         target=target,
                                         target_host=target_host,
                                         params={})
fw_and_cap_compiled_module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
fw_and_cap_compiled_module.set_input(**params)

gr_only_mod = tvm.IRModule({"main": gr_only_fn})
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = tvm.relay.build(gr_only_mod,
                                     target=target,
                                     target_host=target_host,
                                     params={})
gr_only_compiled_module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
gr_only_compiled_module.set_input(**params) # we do have funny const tensors from TVM :/


WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (1, 14, 3072), 'float32'), ('TENSOR', (1, 768, 3072), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (1, 14, 768), 'float32'), ('TENSOR', (1, 3072, 768), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (1, 14, 768), 'float32'), ('TENSOR', (1, 768, 768), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (12, 14, 14), 'float32'), ('TENSOR', (12, 64, 14), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (12, 14, 64), 'float32'), ('TENSOR', (12, 14, 64), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (12, 64, 14), 'float32'), ('TENSOR', (12, 14, 14), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (1, 768, 14), 'float32'), ('TENSOR', (1, 768, 14), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (1, 3072, 14), 'float32'), ('TENSOR', (1, 768, 14), 'float32')). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=rocm -keys=rocm,gpu -max_num_threads=256 -model=gfx906 -thread_warp_size=64, workload=('batch_matmul.cuda', ('TENSOR', (1, 768, 14), 'float32'), ('TENSOR', (1, 3072, 14), 'float32')). A fallback configuration is used, which may bring great performance regression.

Time to give it a spin. We define convenience functions to move tensors between PyTorch and TVM and get the model parameters as a TVM dictionary.


In [25]:
def tensor_to_tvm(t):
    return tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(t))
def tensor_from_tvm(a):
    return torch.utils.dlpack.from_dlpack(a.to_dlpack())

debug_wrap.wrapped.cuda()
traced_module.cuda()

model_params_tvm = {k: tensor_to_tvm(v) for k, v in debug_wrap.wrapped.state_dict().items()}
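
As a quick sanity check of the DLPack round trip (a sketch; the conversion shares memory rather than copying, so this should hold exactly):

t = torch.randn(2, 3, device="cuda")
assert torch.equal(tensor_from_tvm(tensor_to_tvm(t)), t)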

Similarly, we get the inputs on the GPU in PyTorch and TVM.


In [26]:
inp_c = [i.cuda() for i in debug_wrap.DEBUG_INP[:2]]
inp_tvm = [tensor_to_tvm(i) for i in inp_c]

We need to deal with the dropout. It will turn out that our record of the dropout random draws happens in the same order as the dropout in the model: we did a depth-first search on the computational graph to find them, and if the outputs of the dropout are connected in the graph rather than sitting on independent branches, this is also the order in which PyTorch draws the masks. If not, good luck fiddling with the order.


In [27]:
dropout_info


Out[27]:
{'dropout:0': (0.1, TensorType([1, 12, 14, 14], float32)),
 'dropout:1': (0.1, TensorType([1, 14, 768], float32)),
 'dropout:2': (0.1, TensorType([1, 14, 768], float32))}

In [28]:
torch.manual_seed(12345)
drop_c = {}
for k in dropout_info.keys(): # we don't know the order
    p, typ = dropout_info[k]
    drop_c[k] = torch.nn.functional.dropout(torch.ones([int(i) for i in typ.shape], 
                                              dtype=getattr(torch, typ.dtype), device="cuda"), p=p)*(1-p)

drop_tvm = {n: tensor_to_tvm(t) for n, t in drop_c.items()}
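
A small sanity check on this approach (just a sketch): because PyTorch's dropout draws from the global RNG, resetting the seed and drawing the masks again in the same order must reproduce them exactly.

torch.manual_seed(12345)
for k in dropout_info.keys():
    p, typ = dropout_info[k]
    mask = torch.nn.functional.dropout(torch.ones([int(i) for i in typ.shape],
                                                   dtype=getattr(torch, typ.dtype), device="cuda"), p=p) * (1 - p)
    assert torch.equal(mask, drop_c[k])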

Now we can run the forward.


In [29]:
fw_and_cap_compiled_module.set_input('input', inp_tvm[0])
fw_and_cap_compiled_module.set_input('attention_mask', inp_tvm[1])
fw_and_cap_compiled_module.set_input(**model_params_tvm)
fw_and_cap_compiled_module.set_input(**drop_tvm)
fw_and_cap_compiled_module.run()

And we can compare the output to PyTorch's:


In [30]:
torch.manual_seed(12345)
debug_wrap.wrapped.train()
numpy.abs(fw_and_cap_compiled_module.get_output(0).asnumpy()-debug_wrap.wrapped(*inp_c)[0].detach().cpu().numpy()).max()


Out[30]:
2.026558e-06

Supergood. Let's also try the backward. We generate a grad_out, set all the variables, and run the backward model.


In [31]:
gr_out_c = torch.randn(debug_wrap.DEBUG_OUT[0].shape, device="cuda", dtype=debug_wrap.DEBUG_OUT[0].dtype)

In [32]:
num_captures = len(capture_vars)
num_regular_outputs = len(fw_and_cap_fn_flattened.body.fields) - num_captures
captured_values = {v.name_hint: fw_and_cap_compiled_module.get_output(num_regular_outputs + i) for i, v in enumerate(capture_vars)}

#gr_only_compiled_module.set_input('input', inp_tvm[0])
#gr_only_compiled_module.set_input('attention_mask', inp_tvm[1])
gr_only_compiled_module.set_input(**drop_tvm)
gr_only_compiled_module.set_input(**model_params_tvm)
gr_only_compiled_module.set_input(**captured_values)
gr_only_compiled_module.set_input('gr:out:0', tensor_to_tvm(gr_out_c))
gr_only_compiled_module.run()

On the PyTorch side, it is easiest to re-run the forward (remembering to reset the random seed) and get the grads.


In [33]:
torch.manual_seed(12345)
debug_wrap.wrapped.train()
inp_c_rq = [i.requires_grad_() for i in inp_c]
for p in debug_wrap.wrapped.parameters():
    p.requires_grad_()
res = debug_wrap.wrapped(*inp_c_rq)[0]
grads_pt = torch.autograd.grad(res, inp_c_rq + list(debug_wrap.wrapped.parameters()), gr_out_c, allow_unused=True)

Did it work? It seems so:


In [34]:
for i, g_pt in enumerate(grads_pt):
    print(numpy.abs(gr_only_compiled_module.get_output(i).asnumpy() - g_pt.cpu().numpy()).max())


5.2452087e-06
1.001358e-05
6.4373016e-06
2.6226044e-06
1.1444092e-05
4.917383e-07
2.861023e-05
6.4373016e-06
1.335144e-05
6.198883e-06
6.556511e-06
4.172325e-06
6.866455e-05
3.33786e-06
8.821487e-06
1.9073486e-06
7.6293945e-06
1.9073486e-06

But we wanted to get something running in PyTorch, right?

Keeping with how PyTorch works, we first define an autograd.Function that does the things we just did manually:

In the forward:

  • Generate the dropout random values,
  • Run the forward,
  • Record the captures, inputs, and dropout values needed for backward.

In the backward, run the backward and return the result (as PyTorch tensors).


In [35]:
fw_input_names = [p.name_hint for p in fw_and_cap_fn_flattened.params if not is_aux_input(p)]
input_to_idx = {n:i for i, n in enumerate(fw_input_names)}
inputs_to_keep_idx = [input_to_idx[i.name_hint] for i in inputs_to_keep]

In [36]:
class TVMFunction(torch.autograd.Function):
    # nb. using the modules is not thread safe...
    @staticmethod
    def forward(ctx, *inputs):
        assert len(inputs) == len(fw_input_names)
        assert all([i.is_cuda for i in inputs])
        drop_c = {}
        for k in dropout_info.keys(): # we don't know the order
            p, typ = dropout_info[k]
            drop_c[k] = torch.nn.functional.dropout(torch.ones([int(i) for i in typ.shape], 
                                                      dtype=getattr(torch, typ.dtype), device="cuda"), p=p)*(1-p)

        # we don't need to worry about PyTorch changing these because they're not visible.
        # so we don't need save_for_backward here
        drop_tvm = {n: tensor_to_tvm(t) for n, t in drop_c.items()}
        ctx.drop_tvm = drop_tvm 

        fw_and_cap_compiled_module.set_input(**drop_tvm)

        inputs_tvm = [tensor_to_tvm(t) for t in inputs]
        for n, i in zip(fw_input_names, inputs_tvm):
            fw_and_cap_compiled_module.set_input(n, i)
        fw_and_cap_compiled_module.run()
        if isinstance(output_type, tvm.relay.TensorType):
            res = tensor_from_tvm(fw_and_cap_compiled_module.get_output(0))
            num_outputs = 1
        else:
            res = tuple(tensor_from_tvm(fw_and_cap_compiled_module.get_output(i))
                        for i in range(len(output_type.fields)))

            num_outputs = len(res)
        ctx.save_for_backward(*([inputs[i] for i in inputs_to_keep_idx]
                               +[tensor_from_tvm(fw_and_cap_compiled_module.get_output(i)) 
                                 for i in range(num_outputs, fw_and_cap_compiled_module.get_num_outputs())]))
        return res

    @staticmethod
    def backward(ctx, *grad_outs):
        saved = ctx.saved_tensors
        kept_inputs = {fw_input_names[i]: tensor_to_tvm(t)
                       for i, t in zip(inputs_to_keep_idx, saved[:len(inputs_to_keep_idx)])}
        gr_only_compiled_module.set_input(**kept_inputs)
        captures = {f'input:captures:{i}': tensor_to_tvm(t) for i, t in enumerate(saved[len(kept_inputs):])}
        gr_only_compiled_module.set_input(**captures)
        grad_outs_tvm = {f"gr:out:{i}": tensor_to_tvm(go) for i, go in enumerate(grad_outs)}
        gr_only_compiled_module.set_input(**grad_outs_tvm)
        gr_only_compiled_module.set_input(**ctx.drop_tvm)
        gr_only_compiled_module.run()
        grad_in = [tensor_from_tvm(gr_only_compiled_module.get_output(i)) for i in range(gr_only_compiled_module.get_num_outputs())]
        return tuple(grad_in)

Because calling TVMFunction.apply does not please the eye, we define a convenience function; and because we always love to have proper signatures, we also give it the names of our inputs.


In [37]:
def tvm_fn(*inputs):
    return TVMFunction.apply(*inputs)

tvm_fn.__signature__ = inspect.signature(tvm_fn).replace(
    parameters=[inspect.Parameter(n.replace('.', '__'), inspect.Parameter.POSITIONAL_ONLY) 
                for n in fw_input_names])
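
Just to see what we got, we can print the signature; the parameter names come from the traced module (dots replaced by double underscores), so the exact list depends on the model:

inspect.signature(tvm_fn)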

Let's check everything still works.


In [38]:
inp_all = (inp_c_rq + list(traced_module.parameters()))

torch.manual_seed(12345)
res_tvm = tvm_fn(*inp_all)

grad_outs = tuple(torch.randn_like(r) for r in res_tvm)
grads_tvm = torch.autograd.grad(res_tvm, inp_all, grad_outs)

In [39]:
assert len(grads_tvm) == len(grads_pt)
list((g1-g2).abs().max().item() for g1, g2 in zip(grads_tvm, grads_pt))


Out[39]:
[5.245208740234375e-06,
 1.0013580322265625e-05,
 6.4373016357421875e-06,
 2.6226043701171875e-06,
 1.1444091796875e-05,
 4.917383193969727e-07,
 2.86102294921875e-05,
 6.4373016357421875e-06,
 1.33514404296875e-05,
 6.198883056640625e-06,
 6.556510925292969e-06,
 4.172325134277344e-06,
 6.866455078125e-05,
 3.337860107421875e-06,
 8.821487426757812e-06,
 1.9073486328125e-06,
 7.62939453125e-06,
 1.9073486328125e-06]

Yay!

Let us wrap everything we did into a function that goes from traced model to autograd-wrapping function.

End-to-end converter


In [43]:
def create_tvm_function_from_traced_module(traced_module):
    assert traced_module.training, "We only do training right now"
    dt = next(traced_module.parameters()).dtype.__str__().split('.')[-1]
    shape_list = [(i.debugName().split('.')[0], i.type().sizes()) for i in  list(traced_module.graph.inputs())[1:]]
    mod, mod_params = tvm.relay.frontend.pytorch.from_pytorch(traced_module, shape_list, default_dtype=dt)

    # the converter will output arguments in an arbitrary order (well, by position of use); we want the order of the inputs
    fn = mod['main']
    # Careful traced module's vs. non-traced module's parameter ordering.
    # Anecdotally, I have not seen orderings differ between the two, though.
    arg_order = ([n for n, _ in shape_list]
                 +[n for n, _ in traced_module.named_parameters()])
    tmp_arg_idx = {p.name_hint: i for i, p in enumerate(fn.params)}

    fn = tvm.relay.Function([fn.params[tmp_arg_idx[n]] for n in arg_order], fn.body)

    fn = TransposeDedupMutator().visit(fn)

    # prepare function to also use grad_out
    fn = infer_type(fn)
    output_type = fn.body.checked_type # fn.ret_type :)

    if isinstance(output_type, tvm.relay.TensorType):
        gr_out = tvm.relay.var("gr:out", output_type)
        fn_for_gr = tvm.relay.Function(list(fn.params) + [gr_out], tvm.relay.sum(fn.body * gr_out))
    else:
        # we can try to handle tuples of tensors, but our nesting patience ends there
        assert (isinstance(output_type, tvm.relay.TupleType) and
                all([isinstance(f, tvm.relay.TensorType) for f in output_type.fields]))
        gr_outs = [tvm.relay.var(f"gr:out:{i}", t) for i, t in enumerate(output_type.fields)]
        prods_with_gr_out = [tvm.relay.sum(tvm.relay.TupleGetItem(fn.body, i) * go_i)
                             for i, go_i in enumerate(gr_outs)]
        s = prods_with_gr_out[0]
        for p in prods_with_gr_out[1:]:
            s = s + p
        fn_for_gr = tvm.relay.Function(list(fn.params) + gr_outs, s)
    fn_for_gr = infer_type(fn_for_gr)
    fn_for_gr = tvm.relay.dataflow_pattern.rewrite(DecomposeLayerNorm(), fn_for_gr)
    fn_for_gr = infer_type(fn_for_gr)
    fn_for_gr, dropout_info = externalize_dropout(fn_for_gr)
    fn_for_gr = infer_type(fn_for_gr)

    # take the gradient
    grfn = tvm.relay.transform.gradient(fn_for_gr, mode='first_order')
    grfn = to_graph_normal_form(grfn)

    # removing of unneeded outputs and simplifications of the gradient
    
    # Now we have (sum(orig_out * grad_out), (grad_inp_1, ..., grad_inp_n, grad_grad_out, gr_dropout ...))
    # but we only want orig_out and grad_inp_1, ..., grad_inp_n
    def is_aux_input(p):
        return p.name_hint.startswith('dropout:') or p.name_hint.startswith('gr:out:')

    # the gr_out and dropout parameters will have gradients computed, but we do not want that
    grads_to_keep = tvm.relay.Tuple([g for p, g in zip(grfn.params, grfn.body.fields[1].fields)
                                       if not is_aux_input(p)])

    assert grfn.body.fields[0].op.name == 'sum'
    assert grfn.body.fields[0].args[0].op.name == 'multiply'
    if isinstance(output_type, tvm.relay.TensorType):
        orig_out = grfn.body.fields[0].args[0].args[0]
    else:
        assert isinstance(output_type, tvm.relay.TupleType)
        orig_out = grfn.body.fields[0].args[0].args[0].tuple_value
    out_and_grad = tvm.relay.Tuple([orig_out, grads_to_keep])
    out_and_grad_fn = tvm.relay.Function(grfn.params, out_and_grad)
    out_and_grad_fn = infer_type(out_and_grad_fn)
    out_and_grad_fn = dead_code_elimination(out_and_grad_fn)
    out_and_grad_fn = eliminate_common_subexpr(out_and_grad_fn)
    out_and_grad_fn = infer_type(out_and_grad_fn)
    out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(LikeZapp(), out_and_grad_fn)
    out_and_grad_fn = infer_type(out_and_grad_fn)
    out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), out_and_grad_fn)
    out_and_grad_fn = infer_type(out_and_grad_fn)
    out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(OneZapp(), out_and_grad_fn)
    out_and_grad_fn = infer_type(out_and_grad_fn)
    out_and_grad_fn = tvm.relay.dataflow_pattern.rewrite(OneZapp(), out_and_grad_fn)
    out_and_grad_fn = infer_type(out_and_grad_fn)
    out_and_grad_fn = dead_code_elimination(out_and_grad_fn)
    out_and_grad_fn = eliminate_common_subexpr(out_and_grad_fn)

    # split the graph into forward and backward

    orig_out = out_and_grad_fn.body.fields[0]
    grad_ins = out_and_grad_fn.body.fields[1]

    color_dict = {}
    def color(n, c):
        if n in color_dict:
            return
        color_dict[n] = c
        for a in getattr(n, 'args', []):
            color(a, c)
        for a in getattr(n, 'fields', []):
            color(a, c)
        for nam in ('body', 'tuple_value'):
            b = getattr(n, nam, None)
            if b is not None:
                color(b, c)

    color(orig_out, {'color': 'red'})
    seen = set()
    def color_crossings(n, c):
        if n in seen:
            return
        seen.add(n)
        if n in color_dict:
            color_dict[n] = c
            return
        for a in getattr(n, 'args', []):
            color_crossings(a, c)
        for a in getattr(n, 'fields', []):
            color_crossings(a, c)
        for nam in ('body', 'tuple_value'):
            b = getattr(n, nam, None)
            if b is not None:
                color_crossings(b, c)

    color_crossings(grad_ins, {'color': 'blue'})

    nodes_to_capture = [n for n, v in color_dict.items() 
                        if v['color'] == 'blue' and not isinstance(n, (tvm.relay.Constant, tvm.relay.Var))]
    capture_tup = tvm.relay.Tuple(nodes_to_capture)
    nodes_to_capture_idx = {n:i for i, n in enumerate(nodes_to_capture)}
    capture_vars = [tvm.relay.var(f"input:captures:{i}", type_annotation=nodes_to_capture[i].checked_type)
                    for i, n in enumerate(nodes_to_capture)]

    grads_in = out_and_grad_fn.body.fields[1]

    needed_vars = set()
    class GradientOnlyMutator(tvm.relay.ExprMutator):
        def __init__(self):
            super().__init__()

        def visit_var(self, var):
            needed_vars.add(var)
            return var

        def visit(self, expr):
            if expr in nodes_to_capture_idx:
                return capture_vars[nodes_to_capture_idx[expr]]
            return super().visit(expr)

    grads_in_only = GradientOnlyMutator().visit(grads_in)

    # TODO: check against output of original
    fn_for_gr_input_names = {p.name_hint for p in fn_for_gr.params}
    needed_var_names = {v.name_hint for v in needed_vars}
    gr_only_fn = tvm.relay.Function(sorted(needed_vars) + capture_vars, grads_in_only)
    assert needed_var_names <= fn_for_gr_input_names

    inputs_to_keep = [n for n in needed_vars if not is_aux_input(n)]

    # build the forward function that also returns the data for the backward
    capture_tup = tvm.relay.Tuple([n for n in nodes_to_capture])
    fw_and_cap_params = [p for p in out_and_grad_fn.params if not p.name_hint.startswith('gr:out:')]

    fw_and_cap_fn = tvm.relay.Function(fw_and_cap_params,
                                       tvm.relay.Tuple((out_and_grad_fn.body.fields[0],) + (capture_tup,)))

    if isinstance(fn.body, tvm.relay.Tuple):
        # tuple of tensors output
        fw_and_cap_fn_flattened = tvm.relay.Function(fw_and_cap_fn.params, tvm.relay.Tuple(list(fw_and_cap_fn.body.fields[0].fields) # or single tensor
                                                    + list(fw_and_cap_fn.body.fields[1].fields)))
    else:
        # single tensor output
        fw_and_cap_fn_flattened = tvm.relay.Function(fw_and_cap_fn.params, tvm.relay.Tuple([fw_and_cap_fn.body.fields[0]]
                                                    + list(fw_and_cap_fn.body.fields[1].fields)))

    target = 'rocm -model=gfx906'
    target_host = 'llvm'
    ctx = tvm.context(target)

    fw_and_cap_mod = tvm.IRModule({"main": fw_and_cap_fn_flattened})
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = tvm.relay.build(fw_and_cap_mod,
                                             target=target,
                                             target_host=target_host,
                                             params={})
    fw_and_cap_compiled_module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
    fw_and_cap_compiled_module.set_input(**params)

    gr_only_mod = tvm.IRModule({"main": gr_only_fn})
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, params = tvm.relay.build(gr_only_mod,
                                         target=target,
                                         target_host=target_host,
                                         params={})
    gr_only_compiled_module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
    gr_only_compiled_module.set_input(**params) # we may have funny const tensors from TVM

    fw_input_names = [p.name_hint for p in fw_and_cap_fn_flattened.params if not is_aux_input(p)]
    input_to_idx = {n:i for i, n in enumerate(fw_input_names)}
    inputs_to_keep_idx = [input_to_idx[i.name_hint] for i in inputs_to_keep]

    

    class TVMFunction(torch.autograd.Function):
        # nb. using the compiled_modules is not thread safe...
        @staticmethod
        def forward(ctx, *inputs):
            assert len(inputs) == len(fw_input_names)
            assert all([i.is_cuda for i in inputs])
            drop_c = {}
            for k in dropout_info.keys(): # we don't know the order
                p, typ = dropout_info[k]
                drop_c[k] = torch.nn.functional.dropout(torch.ones([int(i) for i in typ.shape], 
                                                          dtype=getattr(torch, typ.dtype), device="cuda"), p=p)*(1-p)

            # we don't need to worry about PyTorch changing these because they're not visible.
            # so we don't need save_for_backward here
            drop_tvm = {n: tensor_to_tvm(t) for n, t in drop_c.items()}
            ctx.drop_tvm = drop_tvm 

            fw_and_cap_compiled_module.set_input(**drop_tvm)

            inputs_tvm = [tensor_to_tvm(t) for t in inputs]
            for n, i in zip(fw_input_names, inputs_tvm):
                fw_and_cap_compiled_module.set_input(n, i)
            fw_and_cap_compiled_module.run()
            if isinstance(output_type, tvm.relay.TensorType):
                res = tensor_from_tvm(fw_and_cap_compiled_module.get_output(0))
                num_outputs = 1
            else:
                res = tuple(tensor_from_tvm(fw_and_cap_compiled_module.get_output(i))
                            for i in range(len(output_type.fields)))

                num_outputs = len(res)
            ctx.save_for_backward(*([inputs[i] for i in inputs_to_keep_idx]
                                   +[tensor_from_tvm(fw_and_cap_compiled_module.get_output(i)) 
                                     for i in range(num_outputs, fw_and_cap_compiled_module.get_num_outputs())]))
            return res

        @staticmethod
        def backward(ctx, *grad_outs):
            saved = ctx.saved_tensors
            kept_inputs = {fw_input_names[i]: tensor_to_tvm(t)
                           for i, t in zip(inputs_to_keep_idx, saved[:len(inputs_to_keep_idx)])}
            gr_only_compiled_module.set_input(**kept_inputs)
            captures = {f'input:captures:{i}': tensor_to_tvm(t) for i, t in enumerate(saved[len(kept_inputs):])}
            gr_only_compiled_module.set_input(**captures)
            grad_outs_tvm = {f"gr:out:{i}": tensor_to_tvm(go) for i, go in enumerate(grad_outs)}
            gr_only_compiled_module.set_input(**grad_outs_tvm)
            gr_only_compiled_module.set_input(**ctx.drop_tvm)
            gr_only_compiled_module.run()
            grad_in = [tensor_from_tvm(gr_only_compiled_module.get_output(i)) for i in range(gr_only_compiled_module.get_num_outputs())]
            return tuple(grad_in)

    def tvm_fn(*inputs):
        return TVMFunction.apply(*inputs)

    tvm_fn.__signature__ = inspect.signature(tvm_fn).replace(
        parameters=[inspect.Parameter(n.replace('.', '__'), inspect.Parameter.POSITIONAL_ONLY) 
                    for n in fw_input_names])
    return tvm_fn

Let's give it a spin and see that it hasn't stopped working.


In [44]:
tvm_fn = create_tvm_function_from_traced_module(traced_module)


ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32

In [45]:
inp_all = (inp_c_rq + list(traced_module.parameters()))
torch.manual_seed(12345)
res_tvm = tvm_fn(*inp_all)

grad_outs = tuple(torch.randn_like(r) for r in res_tvm)

grads_tvm = torch.autograd.grad(res_tvm, inp_all, grad_outs)

In [46]:
torch.manual_seed(12345)
res_pt = traced_module(*inp_c_rq)
grads_pt = torch.autograd.grad(res_pt, inp_all, grad_outs)

In [47]:
assert len(res_tvm) == len(res_pt) and len(grads_tvm) == len(grads_pt)
(list((r1-r2).abs().max().item() for r1, r2 in zip(res_tvm, res_pt)),
 list((g1-g2).abs().max().item() for g1, g2 in zip(grads_tvm, grads_pt)))


Out[47]:
([2.0265579223632812e-06],
 [5.7220458984375e-06,
  1.33514404296875e-05,
  7.152557373046875e-06,
  3.039836883544922e-06,
  1.0728836059570312e-05,
  6.854534149169922e-07,
  3.24249267578125e-05,
  7.152557373046875e-06,
  1.33514404296875e-05,
  5.9604644775390625e-06,
  7.271766662597656e-06,
  4.291534423828125e-06,
  6.866455078125e-05,
  3.337860107421875e-06,
  8.821487426757812e-06,
  1.9073486328125e-06,
  7.62939453125e-06,
  1.9073486328125e-06])

Even better: Auto-Dispatching to TVM

But we promised that we could have a function that takes a module and sample inputs and modifies the model to use TVM if applicable.

Well, here it is, just a bit of messing with Python method magic. For cleanliness, we also include a removal method.


In [48]:
def add_tvm_dispatch(module, sample_inputs):
    traced_module = torch.jit.trace(module, sample_inputs)
    tvm_fn = create_tvm_function_from_traced_module(traced_module)
    tvm_input_shapes = [(i.shape, i.dtype, i.device) for i in sample_inputs]
    old_forward = module.forward
    old_remove_tvm_dispatch = getattr(module, 'remove_tvm_dispatch', None)


    def forward(self, *inputs):
        input_shapes = [(i.shape, i.dtype, i.device) for i in inputs]
        if tvm_input_shapes != input_shapes:
            res = old_forward(*inputs)
        else:
            inp_all = inputs + tuple(self.parameters())
            res = tvm_fn(*inp_all)
        return res

    def remove_tvm_dispatch(self):
        self.forward = old_forward
        if old_remove_tvm_dispatch is not None:
            self.remove_tvm_dispatch = old_remove_tvm_dispatch

    module.remove_tvm_dispatch = types.MethodType(remove_tvm_dispatch, module)
    module.forward = types.MethodType(forward, module)

All done!

Now let us run it for both a compatible input and an incompatible one. Notice the grad_fn printed at the end of the tensor output.


In [49]:
module = debug_wrap.wrapped
inp_c2 = [torch.cat([i, i], dim=0) for i in inp_c] # batch size 2 will be new

In [50]:
type(module)


Out[50]:
transformers.modeling_bert.BertLayer

In [51]:
add_tvm_dispatch(module, inp_c)


/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py:954: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
With rtol=1e-05 and atol=1e-05, found 10752 element(s) (out of 10752) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 2.4348974227905273 (-2.1344871520996094 vs. -4.569384574890137), which occurred at index (0, 13, 381).
  _check_trace(
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32

In [52]:
module.forward(*inp_c)


Out[52]:
(tensor([[[ 0.2507, -0.2194, -0.1747,  ..., -0.0382,  0.0428,  0.1907],
          [-0.1018,  0.6621, -0.3244,  ...,  0.2131, -0.0591, -0.5416],
          [-0.3190, -0.5550,  0.0830,  ...,  0.0665,  0.2982,  0.1724],
          ...,
          [ 0.8956, -0.0658, -0.9987,  ...,  0.0883, -0.2493,  0.8897],
          [ 1.0403,  0.0970, -0.6477,  ...,  0.2595, -0.2993,  0.1683],
          [-0.2900,  0.1849,  0.1094,  ..., -0.3210,  0.4615,  0.0437]]],
        device='cuda:0', grad_fn=<TVMFunctionBackward>),)

In [53]:
module(*inp_c2)  # different shape


Out[53]:
(tensor([[[ 0.1872, -0.2335, -0.1570,  ..., -0.0749,  0.0080,  0.2251],
          [-0.1784,  0.6730, -0.2436,  ...,  0.2280, -0.0746, -0.7620],
          [-0.5389, -0.6264, -0.1439,  ...,  0.1707,  0.2541,  0.1657],
          ...,
          [ 0.8209, -0.4704, -0.6749,  ..., -0.1276, -0.3264,  0.8429],
          [ 1.0422,  0.2161, -0.3209,  ...,  0.2026, -0.4514,  0.1065],
          [-0.2874,  0.1732,  0.0920,  ..., -0.2110,  0.5125,  0.0438]],
 
         [[ 0.2182, -0.2297, -0.1577,  ..., -0.0670,  0.0161,  0.2142],
          [-0.1877,  0.6781, -0.3514,  ...,  0.2637, -0.1320, -0.7478],
          [-0.4626, -0.7372,  0.0140,  ...,  0.1907,  0.1301,  0.2509],
          ...,
          [ 0.7453,  0.1160, -0.4402,  ..., -0.0357, -0.2483,  1.0130],
          [ 1.0437,  0.3303, -0.4749,  ...,  0.2047, -0.2310, -0.0612],
          [-0.2895,  0.2159,  0.1210,  ..., -0.1664,  0.5055, -0.0207]]],
        device='cuda:0', grad_fn=<NativeLayerNormBackward>),)

In [54]:
module.remove_tvm_dispatch()  # cleaning up

Performance

As I said in the beginning, we aren't quite where we want to eventually be in terms of performance. But let us tune the tasks a bit to see.


In [55]:
tasks1 = tvm.autotvm.task.extract_from_program(fw_and_cap_fn_flattened, target=target, params=params)
tasks2 = tvm.autotvm.task.extract_from_program(gr_only_mod["main"], target=target, params=params)
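
Just to see how much there is to tune, we can count the extracted tasks (the exact numbers depend on the graph and the TVM version):

len(tasks1), len(tasks2)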

In [56]:
log_filename = 'bert-train-0.log'
n_trial = 20  # for real tuning, make this 2000!

def do_tune(tasks, log_filename):
    tmp_log_file = log_filename + ".tmp"
    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " %(i+1, len(tasks))

        # we use threading and tornado here to work around TVM and Jupyter colliding over IOLoops
        # In a regular python command line, you should be able to just call the tuner...
        import threading 
        import tornado

        # create tuner
        tuner = tvm.autotvm.tuner.XGBTuner(tsk, loss_type='rank')
        if os.path.isfile(tmp_log_file):
            tuner.load_history(tvm.autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tsk_trial = min(n_trial, len(tsk.config_space))
        def tune_task_fn():
            iol = tornado.ioloop.IOLoop()  # we need an event loop
            tuner.tune(
                n_trial=n_trial,
                early_stopping=600,
                measure_option=tvm.autotvm.measure_option(
                    builder=tvm.autotvm.LocalBuilder(timeout=10),
                    runner=tvm.autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150)),
                callbacks=[
                    tvm.autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                    tvm.autotvm.callback.log_to_file(tmp_log_file)
                ])

        tuning_thread = threading.Thread(target=tune_task_fn)  # create a thread, start it, and wait on it
        tuning_thread.start()
        tuning_thread.join()
        # done tuning, on to the next task

    # pick best records to a cache file
    tvm.autotvm.record.pick_best(tmp_log_file, log_filename)

#do_tune(tasks1+tasks2, log_filename)

We build with our log.


In [57]:
with tvm.autotvm.apply_history_best(log_filename):
    tvm_fn = create_tvm_function_from_traced_module(traced_module)


ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
ANTLR runtime and generated code versions disagree: 4.8!=4.7.2
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32
WARNING:root:Untyped Tensor found, assume it is float32

In [ ]:


In [58]:
def x():
    for i in range(100):
        res_tvm = tvm_fn(*inp_all)
        grads_tvm = torch.autograd.grad(res_tvm, inp_all, grad_outs)
    ctx.sync()

x()
%timeit x()


621 ms ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [60]:
def x():
    for i in range(100):
        res_pt = traced_module(*inp_c_rq)
        grads_pt = torch.autograd.grad(res_pt, inp_all, grad_outs)
    torch.cuda.synchronize()
x()
%timeit x()


126 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So here it is. We ran our model through TVM all right, but it's not as fast as the usual method yet. Here's to opportunity!

More seriously, we have two things to improve performance:

  • Find a better set of captured nodes.
  • Find optimizations on the TVM graph.

In terms of heuristics for the former (it is quite likely NP-hard; at least I believe it is, but I didn't work out a formal proof), one would want to re-do cheap computation, most prominently point-wise computation (or maybe anything but matmul?). But that is for another day.
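
For illustration, a minimal sketch of what such a heuristic could look like (the names are hypothetical and this is not wired into the splitting code above): classify nodes as cheap or expensive and only treat the outputs of expensive ops as worth capturing, letting point-wise computation be recomputed in the backward.

EXPENSIVE_OPS = {"nn.batch_matmul", "nn.dense", "nn.softmax"}  # a hypothetical notion of "worth saving"

def worth_capturing(node):
    # capture only the results of expensive ops; everything else would be recomputed in the backward
    op = getattr(node, "op", None)
    return getattr(op, "name", None) in EXPENSIVE_OPS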

I hope you enjoyed the tutorial. I look forward to your questions and comments at tv@lernapparat.de.


In [ ]: