Automatic optimization with the PyTorch JIT

a worked example

by Thomas Viehmann tv@lernapparat.de

Today, I would like to discuss in detail some aspects of optimizing code in models, and in particular how you can let the PyTorch JIT optimize things for you.

We will use the Intersection over Union loss commonly used in training detection models as an example and explore various ways to implement it in PyTorch.

The intersection over union (or IoU) loss arises in training detection networks. Given two axis-parallel rectangles (blue and red in the original figure), we wish to compute the quotient between the area of the intersection (which is again a rectangle) and the area of the union.

As the intersection is always contained in the union, we know that $0 \leq IoU \leq 1$ (with the optimum being $1$, so strictly speaking $-IoU$ would be a loss).

Note that if we have the area of the intersection and of the two rectangles, we can also express the area of the union as the sum of the areas of the two rectangles minus the area of the intersection (which is counted twice in the sum).

Let $(x_1, y_1, w_1, h_1)$ be the top-left coordinates, width, and height of the first rectangle and $(x_2, y_2, w_2, h_2)$ those of the second.

The intersection is easily calculated: if we have the top-left and bottom-right coordinates (and our coordinate system has $y$ increasing from top to bottom), we can take the maximum of the top-left coordinates and the minimum of the bottom-right coordinates. So we have¹ $$ x_I = \max(x_1, x_2), \qquad y_I = \max(y_1, y_2) $$ and - we need to calculate the bottom-right corners, take the minimum, and transform back to width and height - $$ w_I = \min(x_1 + w_1, x_2 + w_2) - x_I. $$

But there is a slight complication when the rectangles do not intersect: then our formulae do not work but instead give us the rectangle "between" the two, with the corner points exchanged. This means that $w_I$ calculated as above is actually negative, so we can fix this by enforcing a minimum of $0$: $$ w_I = \max \left( \min(x_1 + w_1, x_2 + w_2) - x_I, 0 \right), \qquad h_I = \max \left( \min(y_1 + h_1, y_2 + h_2) - y_I, 0 \right). $$ Note that these last maximizations with a constant would be performed in PyTorch using the torch.clamp function, while the elementwise minimum and maximum between two tensors are computed using torch.min and torch.max.
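
For example, for the rectangles $(x_1, y_1, w_1, h_1) = (0, 0, 2, 2)$ and $(x_2, y_2, w_2, h_2) = (1, 1, 2, 2)$ we get $x_I = y_I = 1$ and $w_I = h_I = \min(2, 3) - 1 = 1$, so the intersection has area $1$, the union has area $2 \cdot 2 + 2 \cdot 2 - 1 = 7$, and thus $IoU = 1/7$.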

Speaking of PyTorch, enough of the theory, let's move to practical things!


  1. I use $I$ here to mean Intersection, it's not an index.


In [1]:
import torch
import torch.utils.cpp_extension

The formulas above readily translate into a PyTorch function. Just to be safe, we clamp the denominator to be at least $10^{-5}$.


In [2]:
def ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2, eps=1e-5):
    xi = torch.max(x1, x2)                                 # Intersection
    yi = torch.max(y1, y2)
    wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0)
    hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0)
    area_i = wi * hi                                       # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi                   # Area Union
    return area_i / torch.clamp(area_u, min=eps)

The function is vectorized out of the box: we just pass in multi-dimensional tensors. Let's try it out with some dummy data:


In [3]:
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 100, 1000, device='cuda').exp()
ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2)


Out[3]:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0233],
        [0.0000, 0.0000, 0.0000,  ..., 0.0444, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0614, 0.0000, 0.0000,  ..., 0.1205, 0.0000, 0.0599],
        [0.0000, 0.0000, 0.0000,  ..., 0.2437, 0.0000, 0.0110],
        [0.0000, 0.5228, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

Without looking too much at the results, it seems to work.

Let us take a short digression here. As you may know, PyTorch provides functional interfaces in torch.nn.functional (often imported as F) as well as modules (in torch.nn, commonly imported as nn)¹. It does so for typical neural network components as well as for loss functions. We might wonder which is preferable for our own modelling. It is, in the end, a question of style, but I would suggest the following rule of thumb: if it has (significant) parameters or even state, use the module interface (so subclass nn.Module); if it does not, define a function like the one above. I also follow this when using PyTorch's own functions - e.g. I usually spell out my forward and prefer the function F.relu over the module nn.ReLU.
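
To make the rule of thumb concrete, here is a minimal sketch (mine, not from the original notebook) of the two styles for our parameter-free loss, taking the negative mean as a scalar loss following the remark above; since there is no state, the module version only adds boilerplate:

import torch.nn as nn

# module style - only really worthwhile once there are parameters or state
class IoULoss(nn.Module):
    def forward(self, x1, y1, w1, h1, x2, y2, w2, h2):
        return -ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2).mean()

# functional style - what I would use for a stateless computation like this one
def iou_loss(x1, y1, w1, h1, x2, y2, w2, h2):
    return -ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2).mean()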

But enough of the digression. Can our ratio_iou calculation be made more efficient?

One common thought when trying to make Python things more efficient is moving to C++. Fortunately, PyTorch makes it very straightforward to do so, by way of C++ extensions or custom operators. Both work the same except for the actual bindings. The difference between them is that functions in PyTorch extensions can take arbitrary parameters (by using the PyBind11 library), while custom operators are restricted to the types that PyTorch supports (e.g. Tensor, int64_t, double, std::string, IntList and TensorList). The advantage of custom operators is that they can be used with the JIT and in C++, too.
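
For illustration, here is a rough sketch (not run in this notebook, and the names are my choice) of what the C++-extension route would look like for the same function; with the functions argument, load_inline generates the PyBind11 binding for us and the result is an ordinary Python module rather than an entry under torch.ops:

# sketch of the C++-extension (PyBind11) route; the function body matches the one below
ext_csrc = """
#include <torch/extension.h>

torch::Tensor iou_native(const torch::Tensor& x1, const torch::Tensor& y1,
                         const torch::Tensor& w1, const torch::Tensor& h1,
                         const torch::Tensor& x2, const torch::Tensor& y2,
                         const torch::Tensor& w2, const torch::Tensor& h2) {
    auto xi = torch::max(x1, x2);
    auto yi = torch::max(y1, y2);
    auto wi = torch::clamp(torch::min(x1 + w1, x2 + w2) - xi, 0);
    auto hi = torch::clamp(torch::min(y1 + h1, y2 + h2) - yi, 0);
    auto area_i = wi * hi;
    auto area_u = w1 * h1 + w2 * h2 - area_i;
    return area_i / torch::clamp(area_u, 1e-5);
}
"""
# iou_ext = torch.utils.cpp_extension.load_inline("iou_ext", ext_csrc, functions=["iou_native"])
# iou_ext.iou_native(x1, y1, w1, h1, x2, y2, w2, h2)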

Happily, we can just type our C++ code into a cell and have PyTorch compile it for us. Let's write a custom operator that follows exactly the Python function above:


  1. I should say that I usually just type out the module names instead of importing them under short names.


In [4]:
csrc = """
#include <torch/script.h>

using namespace torch;

Tensor iou_native(const Tensor& x1, const Tensor& y1, const Tensor& w1, const Tensor& h1,
                  const Tensor& x2, const Tensor& y2, const Tensor& w2, const Tensor& h2) {

    auto xi = torch::max(x1, x2);
    auto yi = torch::max(y1, y2);
    auto wi = torch::clamp(torch::min(x1+w1, x2+w2) - xi, 0);
    auto hi = torch::clamp(torch::min(y1+h1, y2+h2) - yi, 0);
    auto area_i = wi * hi;
    auto area_u = w1 * h1 + w2 * h2 - wi * hi;
    return area_i / torch::clamp(area_u, 1e-5);
}


static auto registry =
  torch::jit::RegisterOperators("super_iou::iou_native", &iou_native);
"""

torch.utils.cpp_extension.load_inline("libsuperiou", csrc, is_python_module=False, verbose=True)


Using /tmp/torch_extensions as PyTorch extensions root...
Emitting ninja build file /tmp/torch_extensions/libsuperiou/build.ninja...
Building extension module libsuperiou...
Loading extension module libsuperiou...

That was easy enough! Now we have a custom op under torch.ops; the name it is available under is determined by the string argument to RegisterOperators - here torch.ops.super_iou.iou_native. (Note: If you get an error about "multiple overloads", you'll have to restart your kernel and start again... While PyTorch extensions support re-building and re-loading, custom operators run into trouble with that.) Let's see if it gives the same result as the Python version:


In [5]:
(ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2)==torch.ops.super_iou.iou_native(x1, y1, w1, h1, x2, y2, w2, h2)).all().item()


Out[5]:
1

It works. Note that in general it is safer to use torch.allclose or to print (a-b).abs().max() to deal with numerical precision. But here, == works well, too.
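
For instance, a tolerance-based comparison could look like this (a sketch; the tolerance values are my choice, not from the notebook):

a = ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2)
b = torch.ops.super_iou.iou_native(x1, y1, w1, h1, x2, y2, w2, h2)
print(torch.allclose(a, b, rtol=1e-5, atol=1e-7), (a - b).abs().max().item())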

So how about timings? Note that we need to call torch.cuda.synchronize() to get valid timings on the GPU.


In [6]:
def taketime(fn):
    _ = fn(x1, y1, w1, h1, x2, y2, w2, h2)
    torch.cuda.synchronize()

torch.cuda.synchronize()
%timeit taketime(ratio_iou)
%timeit taketime(torch.ops.super_iou.iou_native)


1000 loops, best of 3: 1.07 ms per loop
1000 loops, best of 3: 987 µs per loop

We see that there is a difference of about 5% on CUDA; if we did this with CPU tensors, there would be no significant difference. Depending on the nature of the calculation, this is a typical result. For the lltm model in the PyTorch C++ extension tutorial, you get a speedup of about 10% by moving to C++. But that model involves a loop over the input sequence, so it issues quite a few tensor operations. Here we only have a handful of operations, so moving to C++ offers little performance gain by itself.

What is relatively slow about our code is that each operation stores its intermediate result in a tensor and the next operation reads it back to continue. If we write our own kernel, that can be helped. I consider this the "classic way" of optimizing models. The TensorAccessor (for the CPU) / PackedTensorAccessor (for transferring sizes and strides to the GPU) classes provide a convenient interface for element access. As you would in production, we multiplex over the floating-point types through templates in scalar_t. For simplicity, we only deal with 1-d tensors (this is the second template argument to the accessor classes).


In [7]:
csrc = """
#include <torch/script.h>
#include <ATen/Parallel.h>

using namespace torch;

// The cuda kernel is easy enough
template<typename scalar_t>
__global__ void iou_kernel_gpu(PackedTensorAccessor<scalar_t, 1> result,
                          PackedTensorAccessor<scalar_t, 1> x1,
                          PackedTensorAccessor<scalar_t, 1> y1,
                          PackedTensorAccessor<scalar_t, 1> w1,
                          PackedTensorAccessor<scalar_t, 1> h1,
                          PackedTensorAccessor<scalar_t, 1> x2,
                          PackedTensorAccessor<scalar_t, 1> y2,
                          PackedTensorAccessor<scalar_t, 1> w2,
                          PackedTensorAccessor<scalar_t, 1> h2
                          ) {
    int i = threadIdx.x + blockDim.x * blockIdx.x;
    if (i >= x1.size(0)) // we might have more threads than work to do in the last block
      return;
    // This should look very familiar. We could try reading each element only once, but let's keep it simple.
    scalar_t xi = max(x1[i], x2[i]);
    scalar_t yi = max(y1[i], y2[i]);
    scalar_t wi = max(min(x1[i]+w1[i], x2[i]+w2[i]) - xi, static_cast<scalar_t>(0));
    scalar_t hi = max(min(y1[i]+h1[i], y2[i]+h2[i]) - yi, static_cast<scalar_t>(0));
    scalar_t area_i = wi * hi;
    scalar_t area_u = w1[i] * h1[i] + w2[i] * h2[i] - area_i;
    result[i] = area_i / max(area_u, static_cast<scalar_t>(0.00001f));
}

// The CPU kernel looks similar; we could also just put it in the main function...
template<typename scalar_t>
void iou_kernel_cpu(TensorAccessor<scalar_t, 1> result,
                    TensorAccessor<scalar_t, 1> x1,
                    TensorAccessor<scalar_t, 1> y1,
                    TensorAccessor<scalar_t, 1> w1,
                    TensorAccessor<scalar_t, 1> h1,
                    TensorAccessor<scalar_t, 1> x2,
                    TensorAccessor<scalar_t, 1> y2,
                    TensorAccessor<scalar_t, 1> w2,
                    TensorAccessor<scalar_t, 1> h2) {

    // we use CPU parallelization
    constexpr int64_t GRAIN_SIZE = 8192; // minimum grain size for parallel execution
    at::parallel_for(0, x1.size(0), GRAIN_SIZE, [&](int64_t i_begin, int64_t i_end) {
        for (int64_t i = i_begin; i < i_end; ++i) {
            scalar_t xi = max(x1[i], x2[i]);
            scalar_t yi = max(y1[i], y2[i]);
            scalar_t wi = max(min(x1[i]+w1[i], x2[i]+w2[i]) - xi, static_cast<scalar_t>(0));
            scalar_t hi = max(min(y1[i]+h1[i], y2[i]+h2[i]) - yi, static_cast<scalar_t>(0));
            scalar_t area_i = wi * hi;
            scalar_t area_u = w1[i] * h1[i] + w2[i] * h2[i] - area_i;
            result[i] = area_i / max(area_u, static_cast<scalar_t>(0.00001f));
        }
    });
}


torch::Tensor iou_forward(const Tensor& x1, const Tensor& y1, const Tensor& w1, const Tensor& h1,
                          const Tensor& x2, const Tensor& y2, const Tensor& w2, const Tensor& h2) {
  auto res = torch::empty_like(x1);
  for (auto& t : {x1, y1, w1, h1, x2, y2, w2, h2}) {
     AT_ASSERTM(t.dim()==1 && t.size(0)==x1.size(0) && t.device()==x1.device() && t.dtype()==x1.dtype(),
                "tensors are not of same shape and kind");
  }
  if (x1.is_cuda()) {
    dim3 block(512);
    dim3 grid((x1.size(0)+511)/512);
    AT_DISPATCH_FLOATING_TYPES(x1.type(), "iou", [&] {
      iou_kernel_gpu<scalar_t><<<grid,block>>>(res.packed_accessor<scalar_t, 1>(),
                              x1.packed_accessor<scalar_t, 1>(),
                              y1.packed_accessor<scalar_t, 1>(),
                              w1.packed_accessor<scalar_t, 1>(),
                              h1.packed_accessor<scalar_t, 1>(),
                              x2.packed_accessor<scalar_t, 1>(),
                              y2.packed_accessor<scalar_t, 1>(),
                              w2.packed_accessor<scalar_t, 1>(),
                              h2.packed_accessor<scalar_t, 1>());
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES(x1.type(), "iou", [&] {
      iou_kernel_cpu<scalar_t>(res.accessor<scalar_t, 1>(),
                              x1.accessor<scalar_t, 1>(),
                              y1.accessor<scalar_t, 1>(),
                              w1.accessor<scalar_t, 1>(),
                              h1.accessor<scalar_t, 1>(),
                              x2.accessor<scalar_t, 1>(),
                              y2.accessor<scalar_t, 1>(),
                              w2.accessor<scalar_t, 1>(),
                              h2.accessor<scalar_t, 1>());
    });  
  }
  return res;
}

torch::Tensor iou_native(const Tensor& x1, const Tensor& y1, const Tensor& w1, const Tensor& h1,
                         const Tensor& x2, const Tensor& y2, const Tensor& w2, const Tensor& h2) {

    auto xi = torch::max(x1, x2);
    auto yi = torch::max(y1, y2);
    auto wi = torch::clamp(torch::min(x1+w1, x2+w2) - xi, 0);
    auto hi = torch::clamp(torch::min(y1+h1, y2+h2) - yi, 0);
    auto area_i = wi * hi;
    auto area_u = w1 * h1 + w2 * h2 - wi * hi;
    return area_i / torch::clamp(area_u, 1e-5);
}


static auto registry =
  torch::jit::RegisterOperators("super_iou2::iou_forward", &iou_forward)
    .op("super_iou2::iou_native", &iou_native);
"""

torch.utils.cpp_extension.load_inline("iou_op", "", csrc, is_python_module=False, verbose=True)


Using /tmp/torch_extensions as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /tmp/torch_extensions/iou_op/build.ninja...
Building extension module iou_op...
Loading extension module iou_op...

Phew. That was a bit tedious, but let's see if it works!


In [8]:
x1, y1, w1, h1, x2, y2, w2, h2 = [t.view(-1) for t in [x1, y1, w1, h1, x2, y2, w2, h2]]

print ("check gpu", (ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2)==torch.ops.super_iou.iou_native(x1, y1, w1, h1, x2, y2, w2, h2)).all().item())
print ("check cpu", (torch.ops.super_iou2.iou_forward(x1.cpu(), y1.cpu(), w1.cpu(), h1.cpu(), x2.cpu(), y2.cpu(), w2.cpu(), h2.cpu())
       == torch.ops.super_iou2.iou_forward(x1.cpu(), y1.cpu(), w1.cpu(), h1.cpu(), x2.cpu(), y2.cpu(), w2.cpu(), h2.cpu())).all().item())


check gpu 1
check cpu 1

So it seems to work, let's time things again.


In [9]:
torch.cuda.synchronize()
%timeit taketime(torch.ops.super_iou2.iou_forward)
%timeit taketime(ratio_iou)


The slowest run took 8.34 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 81.6 µs per loop
1000 loops, best of 3: 1.05 ms per loop

Now that is a lot faster!

However, it is not usable as is: We do not have a backward. So we need two more kernels? Can we get something that is fast and doesn't need us to write all the infrastructure?

It turns out we can. The PyTorch JIT has two awesome components, the fuser and the autodiff, that will automatically create kernels for us. (There is a limitation here: we need to specify the max argument to clamp in order for this to work.)


In [10]:
import math
@torch.jit.script
def ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2):
    xi = torch.max(x1, x2)                                    # Intersection
    yi = torch.max(y1, y2)
    wi = torch.clamp(torch.min(x1+w1, x2+w2) - xi, min=0, max=math.inf)
    hi = torch.clamp(torch.min(y1+h1, y2+h2) - yi, min=0, max=math.inf)
    area_i = wi * hi                                      # Area Intersection
    area_u = w1 * h1 + w2 * h2 - wi * hi    # Area Union
    return area_i / torch.clamp(area_u, min=1e-5, max=math.inf)

In [11]:
print("check", (ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2)-ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2)).abs().max().item())


check 1.7881393432617188e-07

Let's time it again:


In [12]:
torch.cuda.synchronize()
%timeit taketime(torch.ops.super_iou2.iou_forward)
%timeit taketime(ratio_iou_scripted)
%timeit taketime(ratio_iou)


10000 loops, best of 3: 81.5 µs per loop
10000 loops, best of 3: 157 µs per loop
1000 loops, best of 3: 1.05 ms per loop

Not bad! We got a more than 6x speedup just by putting @torch.jit.script above our function. While the apparent factor of two relative to the hand-crafted kernel still isn't ideal, part of that is that the tensors aren't that large. Going to 10 million elements, we are only about 25% slower than the handwritten kernel:


In [13]:
x1, y1, w1, h1, x2, y2, w2, h2 = torch.randn(8, 10_000_000, device='cuda').exp()
torch.cuda.synchronize()
%timeit taketime(torch.ops.super_iou2.iou_forward)
%timeit taketime(ratio_iou_scripted)


1000 loops, best of 3: 1.02 ms per loop
1000 loops, best of 3: 1.28 ms per loop

How did that work? We can look at the graph the JIT has built for our calculation: you see that the main graph defers to a FusionGroup. The fusion group represents the part of the graph that will be compiled into our custom kernel. (Note: I assume here that you run this with parameters not requiring gradients; we'll repeat the same with gradients below.)


In [14]:
ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)


Out[14]:
graph(%x1 : Float(*)
      %y1 : Float(*)
      %w1 : Float(*)
      %h1 : Float(*)
      %x2 : Float(*)
      %y2 : Float(*)
      %w2 : Float(*)
      %h2 : Float(*)) {
  %32 : Float(*) = prim::FusionGroup_0(%w2, %h2, %w1, %h1, %y2, %y1, %x2, %x1)
  return (%32);
}
with prim::FusionGroup_0 = graph(%14 : Float(*)
      %15 : Float(*)
      %17 : Float(*)
      %18 : Float(*)
      %34 : Float(*)
      %37 : Float(*)
      %51 : Float(*)
      %54 : Float(*)) {
  %xi : Float(*) = aten::max(%54, %51)
  %yi : Float(*) = aten::max(%37, %34)
  %55 : int = prim::Constant[value=1]()
  %56 : Float(*) = aten::add(%54, %17, %55)
  %52 : int = prim::Constant[value=1]()
  %53 : Float(*) = aten::add(%51, %14, %52)
  %50 : Float(*) = aten::min(%56, %53)
  %46 : int = prim::Constant[value=1]()
  %47 : Float(*) = aten::sub(%50, %xi, %46)
  %41 : int = prim::Constant[value=0]()
  %42 : float = prim::Constant[value=inf]()
  %wi : Float(*) = aten::clamp(%47, %41, %42)
  %38 : int = prim::Constant[value=1]()
  %39 : Float(*) = aten::add(%37, %18, %38)
  %35 : int = prim::Constant[value=1]()
  %36 : Float(*) = aten::add(%34, %15, %35)
  %33 : Float(*) = aten::min(%39, %36)
  %29 : int = prim::Constant[value=1]()
  %30 : Float(*) = aten::sub(%33, %yi, %29)
  %24 : int = prim::Constant[value=0]()
  %25 : float = prim::Constant[value=inf]()
  %hi : Float(*) = aten::clamp(%30, %24, %25)
  %area_i : Float(*) = aten::mul(%wi, %hi)
  %19 : Float(*) = aten::mul(%17, %18)
  %16 : Float(*) = aten::mul(%14, %15)
  %12 : int = prim::Constant[value=1]()
  %13 : Float(*) = aten::add(%19, %16, %12)
  %8 : int = prim::Constant[value=1]()
  %area_u : Float(*) = aten::sub(%13, %area_i, %8)
  %4 : float = prim::Constant[value=1e-05]()
  %5 : float = prim::Constant[value=inf]()
  %6 : Float(*) = aten::clamp(%area_u, %4, %5)
  %2 : Float(*) = aten::div(%area_i, %6)
  return (%2);
}

Note that even if operations are shown in a fusion group, it can sometimes happen that the fuser decides it cannot create a kernel after all. You can observe kernel creation by setting the environment variable PYTORCH_FUSION_DEBUG=1 (this works best on the console; the generated kernel source code is written to the terminal).
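
One way to set it, sketched here under the assumption that the variable only needs to be set before the first fused kernel is compiled (setting it in the shell before launching Python is the most reliable):

import os
os.environ["PYTORCH_FUSION_DEBUG"] = "1"  # set before the fuser compiles its first kernel
# equivalently, from the shell: PYTORCH_FUSION_DEBUG=1 python your_script.py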

But we really wanted to get forward and backward, so let's do that.

Here is a bit of a digression again, but I'll keep it very short: note that I use requires_grad_() below instead of a requires_grad=True argument to randn. This is because this way x1 and friends are the leaf variables of the autograd graph; otherwise the random tensor (not assigned to a Python variable) would be the leaf variable and accumulate the grads! This is something you can easily fool yourself with (I can't say it never happened to me, and it's a not-so-infrequent cause of questions on the forums, too). I prefer .requires_grad_() over setting the attribute .requires_grad = True because the former is not only shorter but will also fail if I misspell it for any reason.
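
Here is a minimal sketch of the pitfall (the variable names are mine):

# leaf is t itself, so gradients end up in t.grad
t = torch.randn(5).exp().requires_grad_()
t.sum().backward()
print(t.grad is not None)   # True

# leaf is the unnamed result of randn, not u, so u.grad stays None
u = torch.randn(5, requires_grad=True).exp()
u.sum().backward()
print(u.grad is None)       # True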

So here is the timing with backward (I always evaluate the scripted function once beforehand so that the one-off compilation time is not included in the timing):


In [15]:
x1, y1, w1, h1, x2, y2, w2, h2 = [t.requires_grad_() for t in torch.randn(8, 100_000, device='cuda').exp()]
l1 = ratio_iou(x1, y1, w1, h1, x2, y2, w2, h2)
l2 = ratio_iou_scripted(x1, y1, w1, h1, x2, y2, w2, h2)
grad_out = torch.randn_like(l1)
grads1 = torch.autograd.grad(l1, [x1, y1, w1, h1, x2, y2, w2, h2], grad_out)
grads2 = torch.autograd.grad(l2, [x1, y1, w1, h1, x2, y2, w2, h2], grad_out)

print ("check:", (l1-l2).abs().max().item(), max([(g1-g2).abs().max().item() for g1, g2 in zip(grads1, grads2)]))

def time_loss_and_backward(fn):
    l = fn(x1, y1, w1, h1, x2, y2, w2, h2)
    grads = torch.autograd.grad(l, [x1, y1, w1, h1, x2, y2, w2, h2], grad_out)
    torch.cuda.synchronize()
torch.cuda.synchronize()
%timeit time_loss_and_backward(ratio_iou)
%timeit time_loss_and_backward(ratio_iou_scripted)


check: 1.1920928955078125e-07 9.5367431640625e-07
100 loops, best of 3: 5.3 ms per loop
1000 loops, best of 3: 1.17 ms per loop

I get a 4.5x speedup. Not bad for just adding @torch.jit.script!

My measurements have been done on my PR #14957 branch. The backward optimization had a bit of a bumpy ride in PyTorch in November 2018, as a late fix for correct gradients of broadcasted tensors inserted summations into the backward that cannot be fused. I hope that this will be fixed soon.

Let's look at the graph again. You see that it is now wrapped in a DifferentiableGraph. This means that the JIT autodiff has identified a block that it knows how to differentiate. Inside, you have the FusionGroup we already saw and a bit of broadcasting bookkeeping.


In [16]:
ratio_iou_scripted.graph_for(x1, y1, w1, h1, x2, y2, w2, h2)


Out[16]:
graph(%x1 : Float(*)
      %y1 : Float(*)
      %w1 : Float(*)
      %h1 : Float(*)
      %x2 : Float(*)
      %y2 : Float(*)
      %w2 : Float(*)
      %h2 : Float(*)) {
  %32 : Float(*) = prim::DifferentiableGraph_0(%w2, %h2, %w1, %h1, %y2, %y1, %x2, %x1)
  return (%32);
}
with prim::DifferentiableGraph_0 = graph(%14 : Float(*)
      %15 : Float(*)
      %17 : Float(*)
      %18 : Float(*)
      %34 : Float(*)
      %37 : Float(*)
      %51 : Float(*)
      %54 : Float(*)) {
  %334 : Float(*), %335 : Float(*), %area_u.1 : Float(*), %area_i.1 : Float(*), %hi.1 : Float(*), %342 : Float(*), %344 : Float(*), %345 : Float(*), %wi.1 : Float(*), %347 : Float(*), %349 : Float(*), %350 : Float(*) = prim::FusionGroup_0(%14, %15, %17, %18, %34, %37, %51, %54)
  %353 : int[] = aten::size(%14)
  %354 : int[] = aten::size(%15)
  %355 : int[] = aten::size(%17)
  %356 : int[] = aten::size(%18)
  %357 : int[] = aten::size(%34)
  %358 : int[] = aten::size(%37)
  %359 : int[] = aten::size(%51)
  %360 : int[] = aten::size(%54)
  %367 : int[] = aten::size(%344)
  %368 : int[] = aten::size(%345)
  %371 : int[] = aten::size(%349)
  %372 : int[] = aten::size(%350)
  %373 : int[] = prim::BroadcastSizes(%360, %359)
  %374 : int[] = prim::BroadcastSizes(%358, %357)
  %377 : int[] = prim::BroadcastSizes(%372, %371)
  %381 : int[] = prim::BroadcastSizes(%368, %367)
  %384 : int[] = prim::BroadcastSizes(%355, %356)
  %385 : int[] = prim::BroadcastSizes(%353, %354)
  %386 : int[] = prim::BroadcastSizes(%384, %385)
  return (%334, %350, %349, %377, %373, %347, %wi.1, %345, %344, %381, %374, %342, %hi.1, %area_i.1, %384, %385, %386, %area_u.1, %335);
}
with prim::FusionGroup_0 = graph(%14 : Float(*)
      %15 : Float(*)
      %17 : Float(*)
      %18 : Float(*)
      %34 : Float(*)
      %37 : Float(*)
      %51 : Float(*)
      %54 : Float(*)) {
  %xi : Float(*) = aten::max(%54, %51)
  %yi : Float(*) = aten::max(%37, %34)
  %55 : int = prim::Constant[value=1]()
  %56 : Float(*) = aten::add(%54, %17, %55)
  %52 : int = prim::Constant[value=1]()
  %53 : Float(*) = aten::add(%51, %14, %52)
  %50 : Float(*) = aten::min(%56, %53)
  %46 : int = prim::Constant[value=1]()
  %47 : Float(*) = aten::sub(%50, %xi, %46)
  %41 : int = prim::Constant[value=0]()
  %42 : float = prim::Constant[value=inf]()
  %wi.1 : Float(*) = aten::clamp(%47, %41, %42)
  %38 : int = prim::Constant[value=1]()
  %39 : Float(*) = aten::add(%37, %18, %38)
  %35 : int = prim::Constant[value=1]()
  %36 : Float(*) = aten::add(%34, %15, %35)
  %33 : Float(*) = aten::min(%39, %36)
  %29 : int = prim::Constant[value=1]()
  %30 : Float(*) = aten::sub(%33, %yi, %29)
  %24 : int = prim::Constant[value=0]()
  %25 : float = prim::Constant[value=inf]()
  %hi.1 : Float(*) = aten::clamp(%30, %24, %25)
  %area_i.1 : Float(*) = aten::mul(%wi.1, %hi.1)
  %19 : Float(*) = aten::mul(%17, %18)
  %16 : Float(*) = aten::mul(%14, %15)
  %12 : int = prim::Constant[value=1]()
  %13 : Float(*) = aten::add(%19, %16, %12)
  %8 : int = prim::Constant[value=1]()
  %area_u.1 : Float(*) = aten::sub(%13, %area_i.1, %8)
  %4 : float = prim::Constant[value=1e-05]()
  %5 : float = prim::Constant[value=inf]()
  %6 : Float(*) = aten::clamp(%area_u.1, %4, %5)
  %2 : Float(*) = aten::div(%area_i.1, %6)
  return (%2, %6, %area_u.1, %area_i.1, %hi.1, %30, %36, %39, %wi.1, %47, %53, %56);
}

Let's look at the backward graph, too. I extracted the code to get the backward graph from PyTorch's test suite. I re-define the function so that only a single backward graph is defined. The helper tries to extract the backward graph from the latest(?) forward run, so it might be a bit fragile; if you run into trouble, rerun the definition of ratio_iou_scripted and the timing with backward before calling backward_graph. Note that in the output here, the bulk of the calculation (except for a few GradSumToSize operations) is done in one large fusion group again. On 1.0 this would have been split into piecemeal fusion groups with SumToSize operations in between.


In [17]:
def backward_graph(script_module):
    # magic debugging stuff I learned about in the PyTorch JIT test suite
    graph_executor_state = script_module.get_debug_state()
    fwd_plan = list(graph_executor_state.execution_plans.values())[-1]
    grad_executor = list(fwd_plan.code.grad_executors())[-1]
    bwd_plan = list(grad_executor.get_debug_state().execution_plans.values())[-1]
    return bwd_plan.graph.copy() # in order to own the graph, we need to make a copy

In [18]:
backward_graph(ratio_iou_scripted)


Out[18]:
graph(%0 : Float(*)
      %1 : UndefinedTensor
      %2 : UndefinedTensor
      %3 : UndefinedTensor
      %4 : UndefinedTensor
      %5 : UndefinedTensor
      %6 : UndefinedTensor
      %7 : UndefinedTensor
      %8 : UndefinedTensor
      %9 : UndefinedTensor
      %10 : UndefinedTensor
      %11 : UndefinedTensor
      %12 : Float(*)
      %13 : Float(*)
      %14 : Float(*)
      %15 : Float(*)
      %16 : Float(*)
      %17 : Float(*)
      %18 : Float(*)
      %19 : Float(*)
      %20 : Float(*)
      %21 : Float(*)
      %22 : int[]
      %23 : int[]
      %24 : Float(*)
      %wi : Float(*)
      %26 : Float(*)
      %27 : Float(*)
      %28 : int[]
      %29 : int[]
      %30 : Float(*)
      %hi : Float(*)
      %area_i : Float(*)
      %33 : int[]
      %34 : int[]
      %35 : int[]
      %area_u : Float(*)
      %37 : Float(*)) {
  %38 : int[] = aten::size(%13)
  %39 : int[] = aten::size(%15)
  %40 : int[] = aten::size(%12)
  %41 : int[] = aten::size(%14)
  %42 : int[] = aten::size(%16)
  %43 : int[] = aten::size(%17)
  %44 : int[] = aten::size(%18)
  %45 : int[] = aten::size(%19)
  %46 : Tensor, %47 : Tensor, %48 : Tensor, %49 : Tensor, %50 : Tensor, %51 : Tensor, %52 : Tensor, %53 : Tensor = prim::FusionGroup_0(%15, %14, %12, %wi, %13, %18, %19, %21, %20, %24, %hi, %area_u, %area_i, %0, %37, %27, %26, %30, %16, %17)
  %54 : Tensor = prim::AutodiffGradSumToSize(%46, %45)
  %55 : Tensor = prim::AutodiffGradSumToSize(%47, %41)
  %56 : Tensor = prim::AutodiffGradSumToSize(%48, %39)
  %57 : Tensor = prim::AutodiffGradSumToSize(%49, %43)
  %58 : Tensor = prim::AutodiffGradSumToSize(%50, %42)
  %59 : Tensor = prim::AutodiffGradSumToSize(%51, %38)
  %60 : Tensor = prim::AutodiffGradSumToSize(%52, %40)
  %61 : Tensor = prim::AutodiffGradSumToSize(%53, %44)
  return (%60, %59, %55, %56, %58, %57, %61, %54);
}
with prim::FusionGroup_0 = graph(%23 : Float(*)
      %42 : Float(*)
      %103 : Float(*)
      %122 : Float(*)
      %139 : Float(*)
      %161 : Float(*)
      %162 : Float(*)
      %169 : Float(*)
      %170 : Float(*)
      %198 : Float(*)
      %235 : Float(*)
      %271 : Float(*)
      %312 : Float(*)
      %314 : Float(*)
      %316 : Float(*)
      %321 : Float(*)
      %322 : Float(*)
      %336 : Float(*)
      %367 : Float(*)
      %368 : Float(*)) {
  %373 : Byte(*) = aten::gt(%162, %161)
  %372 : Byte(*) = aten::lt(%170, %169)
  %371 : Byte(*) = aten::lt(%322, %321)
  %370 : Byte(*) = aten::gt(%368, %367)
  %369 : Byte(*) = aten::gt(%367, %368)
  %366 : float = prim::Constant[value=inf]()
  %365 : float = prim::Constant[value=inf]()
  %364 : float = prim::Constant[value=inf]()
  %363 : float = prim::Constant[value=inf]()
  %362 : float = prim::Constant[value=inf]()
  %361 : float = prim::Constant[value=inf]()
  %359 : float = prim::Constant[value=inf]()
  %360 : Byte(*) = aten::ge(%336, %359)
  %358 : Float(*) = aten::type_as(%360, %336)
  %356 : Float(*) = aten::neg(%358)
  %354 : int = prim::Constant[value=1]()
  %353 : int = prim::Constant[value=1]()
  %352 : int = prim::Constant[value=1]()
  %351 : int = prim::Constant[value=1]()
  %350 : int = prim::Constant[value=1]()
  %349 : int = prim::Constant[value=1]()
  %347 : int = prim::Constant[value=1]()
  %348 : Float(*) = aten::add(%356, %347, %347)
  %345 : int = prim::Constant[value=0]()
  %344 : int = prim::Constant[value=0]()
  %343 : int = prim::Constant[value=0]()
  %342 : int = prim::Constant[value=0]()
  %341 : int = prim::Constant[value=0]()
  %340 : int = prim::Constant[value=0]()
  %338 : int = prim::Constant[value=0]()
  %339 : Byte(*) = aten::le(%336, %338)
  %337 : Float(*) = aten::type_as(%339, %336)
  %334 : Float(*) = aten::neg(%337)
  %332 : int = prim::Constant[value=1]()
  %331 : int = prim::Constant[value=1]()
  %330 : int = prim::Constant[value=1]()
  %329 : int = prim::Constant[value=1]()
  %328 : int = prim::Constant[value=1]()
  %327 : int = prim::Constant[value=1]()
  %325 : int = prim::Constant[value=1]()
  %326 : Float(*) = aten::add(%334, %325, %325)
  %323 : Byte(*) = aten::lt(%321, %322)
  %320 : Float(*) = aten::div(%314, %316)
  %317 : Float(*) = aten::mul(%316, %316)
  %315 : Float(*) = aten::neg(%314)
  %313 : Float(*) = aten::mul(%315, %312)
  %310 : Float(*) = aten::div(%313, %317)
  %304 : float = prim::Constant[value=inf]()
  %303 : float = prim::Constant[value=inf]()
  %302 : float = prim::Constant[value=inf]()
  %301 : float = prim::Constant[value=inf]()
  %300 : float = prim::Constant[value=inf]()
  %299 : float = prim::Constant[value=inf]()
  %298 : float = prim::Constant[value=inf]()
  %296 : float = prim::Constant[value=inf]()
  %297 : Byte(*) = aten::ge(%271, %296)
  %295 : Float(*) = aten::type_as(%297, %271)
  %293 : Float(*) = aten::neg(%295)
  %291 : int = prim::Constant[value=1]()
  %290 : int = prim::Constant[value=1]()
  %289 : int = prim::Constant[value=1]()
  %288 : int = prim::Constant[value=1]()
  %287 : int = prim::Constant[value=1]()
  %286 : int = prim::Constant[value=1]()
  %285 : int = prim::Constant[value=1]()
  %283 : int = prim::Constant[value=1]()
  %284 : Float(*) = aten::add(%293, %283, %283)
  %281 : float = prim::Constant[value=1e-05]()
  %280 : float = prim::Constant[value=1e-05]()
  %279 : float = prim::Constant[value=1e-05]()
  %278 : float = prim::Constant[value=1e-05]()
  %277 : float = prim::Constant[value=1e-05]()
  %276 : float = prim::Constant[value=1e-05]()
  %275 : float = prim::Constant[value=1e-05]()
  %273 : float = prim::Constant[value=1e-05]()
  %274 : Byte(*) = aten::le(%271, %273)
  %272 : Float(*) = aten::type_as(%274, %271)
  %269 : Float(*) = aten::neg(%272)
  %267 : int = prim::Constant[value=1]()
  %266 : int = prim::Constant[value=1]()
  %265 : int = prim::Constant[value=1]()
  %264 : int = prim::Constant[value=1]()
  %263 : int = prim::Constant[value=1]()
  %262 : int = prim::Constant[value=1]()
  %261 : int = prim::Constant[value=1]()
  %259 : int = prim::Constant[value=1]()
  %260 : Float(*) = aten::add(%269, %259, %259)
  %257 : Tensor = aten::mul(%310, %260)
  %254 : Tensor = aten::mul(%257, %284)
  %251 : Tensor = aten::neg(%254)
  %247 : int = prim::Constant[value=1]()
  %246 : int = prim::Constant[value=1]()
  %245 : int = prim::Constant[value=1]()
  %244 : int = prim::Constant[value=1]()
  %243 : int = prim::Constant[value=1]()
  %242 : int = prim::Constant[value=1]()
  %241 : int = prim::Constant[value=1]()
  %239 : int = prim::Constant[value=1]()
  %240 : Tensor = aten::add(%320, %251, %239)
  %236 : Tensor = aten::mul(%240, %235)
  %231 : float = prim::Constant[value=inf]()
  %230 : float = prim::Constant[value=inf]()
  %229 : float = prim::Constant[value=inf]()
  %228 : float = prim::Constant[value=inf]()
  %227 : float = prim::Constant[value=inf]()
  %226 : float = prim::Constant[value=inf]()
  %225 : float = prim::Constant[value=inf]()
  %223 : float = prim::Constant[value=inf]()
  %224 : Byte(*) = aten::ge(%198, %223)
  %222 : Float(*) = aten::type_as(%224, %198)
  %220 : Float(*) = aten::neg(%222)
  %218 : int = prim::Constant[value=1]()
  %217 : int = prim::Constant[value=1]()
  %216 : int = prim::Constant[value=1]()
  %215 : int = prim::Constant[value=1]()
  %214 : int = prim::Constant[value=1]()
  %213 : int = prim::Constant[value=1]()
  %212 : int = prim::Constant[value=1]()
  %210 : int = prim::Constant[value=1]()
  %211 : Float(*) = aten::add(%220, %210, %210)
  %208 : int = prim::Constant[value=0]()
  %207 : int = prim::Constant[value=0]()
  %206 : int = prim::Constant[value=0]()
  %205 : int = prim::Constant[value=0]()
  %204 : int = prim::Constant[value=0]()
  %203 : int = prim::Constant[value=0]()
  %202 : int = prim::Constant[value=0]()
  %200 : int = prim::Constant[value=0]()
  %201 : Byte(*) = aten::le(%198, %200)
  %199 : Float(*) = aten::type_as(%201, %198)
  %196 : Float(*) = aten::neg(%199)
  %194 : int = prim::Constant[value=1]()
  %193 : int = prim::Constant[value=1]()
  %192 : int = prim::Constant[value=1]()
  %191 : int = prim::Constant[value=1]()
  %190 : int = prim::Constant[value=1]()
  %189 : int = prim::Constant[value=1]()
  %188 : int = prim::Constant[value=1]()
  %186 : int = prim::Constant[value=1]()
  %187 : Float(*) = aten::add(%196, %186, %186)
  %184 : Tensor = aten::mul(%236, %187)
  %181 : Tensor = aten::mul(%184, %211)
  %176 : Tensor = aten::neg(%181)
  %171 : Byte(*) = aten::lt(%169, %170)
  %168 : Tensor = aten::type_as(%171, %181)
  %166 : Tensor = aten::mul(%181, %168)
  %163 : Byte(*) = aten::gt(%161, %162)
  %160 : Tensor = aten::type_as(%163, %176)
  %158 : Tensor = aten::mul(%176, %160)
  %153 : int = prim::Constant[value=1]()
  %152 : int = prim::Constant[value=1]()
  %151 : int = prim::Constant[value=1]()
  %150 : int = prim::Constant[value=1]()
  %149 : int = prim::Constant[value=1]()
  %148 : int = prim::Constant[value=1]()
  %147 : int = prim::Constant[value=1]()
  %145 : int = prim::Constant[value=1]()
  %146 : Tensor = aten::add(%166, %158, %145)
  %140 : Tensor = aten::mul(%254, %139)
  %133 : int = prim::Constant[value=1]()
  %132 : int = prim::Constant[value=1]()
  %131 : int = prim::Constant[value=1]()
  %130 : int = prim::Constant[value=1]()
  %129 : int = prim::Constant[value=1]()
  %128 : int = prim::Constant[value=1]()
  %126 : int = prim::Constant[value=1]()
  %127 : Tensor = aten::add(%140, %166, %126)
  %123 : Tensor = aten::mul(%240, %122)
  %117 : Tensor = aten::mul(%123, %326)
  %114 : Tensor = aten::mul(%117, %348)
  %109 : Tensor = aten::type_as(%323, %114)
  %107 : Tensor = aten::mul(%114, %109)
  %104 : Tensor = aten::mul(%254, %103)
  %98 : int = prim::Constant[value=1]()
  %97 : int = prim::Constant[value=1]()
  %96 : int = prim::Constant[value=1]()
  %95 : int = prim::Constant[value=1]()
  %94 : int = prim::Constant[value=1]()
  %92 : int = prim::Constant[value=1]()
  %93 : Tensor = aten::add(%104, %107, %92)
  %89 : Tensor = aten::neg(%114)
  %84 : Tensor = aten::type_as(%369, %89)
  %82 : Tensor = aten::mul(%89, %84)
  %75 : int = prim::Constant[value=1]()
  %74 : int = prim::Constant[value=1]()
  %73 : int = prim::Constant[value=1]()
  %72 : int = prim::Constant[value=1]()
  %70 : int = prim::Constant[value=1]()
  %71 : Tensor = aten::add(%107, %82, %70)
  %67 : Tensor = aten::type_as(%371, %114)
  %65 : Tensor = aten::mul(%114, %67)
  %61 : Tensor = aten::type_as(%370, %89)
  %59 : Tensor = aten::mul(%89, %61)
  %53 : int = prim::Constant[value=1]()
  %52 : int = prim::Constant[value=1]()
  %51 : int = prim::Constant[value=1]()
  %49 : int = prim::Constant[value=1]()
  %50 : Tensor = aten::add(%65, %59, %49)
  %43 : Tensor = aten::mul(%254, %42)
  %36 : int = prim::Constant[value=1]()
  %35 : int = prim::Constant[value=1]()
  %33 : int = prim::Constant[value=1]()
  %34 : Tensor = aten::add(%43, %65, %33)
  %30 : Tensor = aten::type_as(%372, %181)
  %28 : Tensor = aten::mul(%181, %30)
  %24 : Tensor = aten::mul(%254, %23)
  %18 : int = prim::Constant[value=1]()
  %16 : int = prim::Constant[value=1]()
  %17 : Tensor = aten::add(%24, %28, %16)
  %13 : Tensor = aten::type_as(%373, %176)
  %11 : Tensor = aten::mul(%176, %13)
  %2 : int = prim::Constant[value=1]()
  %3 : Tensor = aten::add(%28, %11, %2)
  return (%3, %17, %34, %50, %71, %93, %127, %146);
}

That's all for now. I hope you enjoyed this little demo, and I appreciate your feedback and comments at tv@lernapparat.de.

On my blog https://lernapparat.de/ you'll find the slides from the talk that this demonstration accompanies.

