Tutorial: From physics to tuned GPU kernels

This tutorial is designed to show you the whole process, starting from modeling a physical process, to a simple Python implementation, to creating an optimized and auto-tuned GPU application using the Kernel Tuner.

In this tutorial, we will use diffusion as an example application.

We start by modeling the physical process of diffusion, for which we create a simple numerical implementation in Python. Then we create a CUDA kernel that performs the same computation, but on the GPU. Once we have a CUDA kernel, we start using the Kernel Tuner to auto-tune our GPU application. Finally, we introduce a few code optimizations to our CUDA kernel that improve performance, but also add more parameters to tune using the Kernel Tuner.

**Note:** If you are reading this tutorial on the Kernel Tuner's documentation pages, note that you can actually run this tutorial as a Jupyter Notebook. Just clone the Kernel Tuner's [GitHub repository](http://github.com/benvanwerkhoven/kernel_tuner). Install the Kernel Tuner and Jupyter Notebooks and you're ready to go! You can start the tutorial by typing "jupyter notebook" in the "kernel_tuner/tutorial" directory.

Diffusion

Put simply, diffusion is the redistribution of something from a region of high concentration to a region of low concentration without bulk motion. The concept of diffusion is widely used in many fields, including physics, chemistry, biology, and many more.

Suppose that we take a metal sheet in which the temperature is exactly equal to one degree everywhere. Now imagine that we instantly heat a number of points on the sheet to a very high temperature, say a thousand degrees. We could then watch the heat diffuse from these hotspots to the cooler areas. We assume that the metal does not melt, and we ignore any heat loss from radiation or other causes in this example.

We can use the diffusion equation to model how the heat diffuses through our metal sheet:

\begin{equation*} \frac{\partial u}{\partial t}= D \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right) \end{equation*}

Here $x$ and $y$ are the spatial coordinates of our 2D domain, $u$ is the quantity that is being diffused (in our example, the temperature), $t$ is time, and the constant $D$ determines how fast the diffusion takes place.

In this example, we will assume a very simple discretization of our problem. We assume that our 2D domain has $nx$ equidistant grid points in the x-direction and $ny$ equidistant grid points in the y-direction. Be sure to execute every cell as you read through this document, by selecting it and pressing shift+enter.


In [1]:
nx = 1024
ny = 1024

This results in a constant distance of $\delta x$ between all grid points in the $x$ dimension. Using central differences, we can numerically approximate the derivative for a given point $x_i$:

\begin{equation*} \left. \frac{\partial^2 u}{\partial x^2} \right|_{x_{i}} \approx \frac{u_{x_{i+1}}-2u_{{x_i}}+u_{x_{i-1}}}{(\delta x)^2} \end{equation*}

We do the same for the partial derivative in $y$:

\begin{equation*} \left. \frac{\partial^2 u}{\partial y^2} \right|_{y_{i}} \approx \frac{u_{y_{i+1}}-2u_{y_{i}}+u_{y_{i-1}}}{(\delta y)^2} \end{equation*}

If we combine the above equations, we can obtain a numerical estimate of the temperature field of our metal sheet at the next time step, using $\delta t$ as the time between time steps. But before we do, we simplify the expression a little: we assume that $\delta x$ and $\delta y$ are always equal to 1, and we absorb the constant $D$ into $\delta t$. Writing $u_{x,y}$ for the value of $u$ at grid point $(x, y)$, we get:

\begin{equation*} u'_{x,y} = u_{x,y} + \delta t \times \left( \left( u_{x+1,y}-2u_{x,y}+u_{x-1,y} \right) + \left( u_{x,y+1}-2u_{x,y}+u_{x,y-1} \right) \right) \end{equation*}

In this formula $u'_{x,y}$ refers to the temperature field at time $t + \delta t$. As a final step, we further simplify this equation to:

\begin{equation*} u'_{x,y} = u_{x,y} + \delta t \times \left( u_{x,y+1}+u_{x+1,y}-4u_{x,y}+u_{x-1,y}+u_{x,y-1} \right) \end{equation*}
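
To get a feel for this update rule, here is a small worked example in Python that applies the formula to the centre point of a 3x3 patch, with a single hot point of 5 surrounded by values of 1, using $\delta t = 0.225$:

import numpy

dt = 0.225
patch = numpy.ones((3, 3), dtype=numpy.float32)
patch[1, 1] = 5.0   #a single hot point surrounded by cooler neighbours

#apply the update formula to the centre point only
u = patch[1, 1]
u_new = u + dt * (patch[0, 1] + patch[1, 2] - 4.0 * u + patch[1, 0] + patch[2, 1])
print(u_new)   #5 + 0.225 * (1 + 1 - 20 + 1 + 1) = 5 - 3.6 = 1.4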

Python implementation

We can create a Python function that implements the numerical approximation defined in the above equation. For simplicity, we only update the interior grid points and leave the values on the domain boundary unchanged.


In [2]:
def diffuse(field, dt=0.225):
    #update all interior grid points, leaving the boundary values untouched
    #note that axis 0 is the y-dimension and axis 1 is the x-dimension
    field[1:ny-1,1:nx-1] = field[1:ny-1,1:nx-1] + dt * (
        field[1:ny-1,2:nx]+field[2:ny,1:nx-1]-4*field[1:ny-1,1:nx-1]+
        field[0:ny-2,1:nx-1]+field[1:ny-1,0:nx-2] )
    return field

To give our Python function a test run, we will now do some imports and generate the input data for the initial conditions of our metal sheet, with a few very hot points. We'll also make two plots: one after a thousand time steps, and a second plot after another two thousand time steps. Note that the two plots use different color ranges. Also, executing the following cell may take a little while.


In [3]:
import numpy

#setup initial conditions
def get_initial_conditions(nx, ny):
    #a sheet at one degree everywhere, with 10 randomly placed very hot points
    field = numpy.ones((ny, nx)).astype(numpy.float32)
    field[numpy.random.randint(0, ny, size=10), numpy.random.randint(0, nx, size=10)] = 1e3
    return field
field = get_initial_conditions(nx, ny)

We can now use this initial condition to solve the diffusion problem and plot the results.


In [4]:
from matplotlib import pyplot
%matplotlib inline

#run the diffuse function 1000 times, then another 2000 times, and make plots
fig, (ax1, ax2) = pyplot.subplots(1,2)
cpu=numpy.copy(field)
for i in range(1000):
    cpu = diffuse(cpu)
ax1.imshow(cpu)
for i in range(2000):
    cpu = diffuse(cpu)
ax2.imshow(cpu)


Out[4]:
<matplotlib.image.AxesImage at 0x7f888f8cd7b8>

Now let's take a quick look at the execution time of our diffuse function. Before we do, we make a fresh copy of the initial state of the metal sheet, so that the timing run starts from the same initial conditions.


In [5]:
#run 1000 steps of the diffuse function from the initial conditions and measure the time
from time import time
start = time()
cpu=numpy.copy(field)
for i in range(1000):
    cpu = diffuse(cpu)
end = time()
print("1000 steps of diffuse on a %d x %d grid took" %(nx,ny), (end-start)*1000.0, "ms")
pyplot.imshow(cpu)


1000 steps of diffuse on a 1024 x 1024 grid took 4152.086019515991 ms
Out[5]:
<matplotlib.image.AxesImage at 0x7f8865b51f28>

Computing on the GPU

The next step in this tutorial is to implement a GPU kernel that will allow us to run our problem on the GPU. We store the kernel code in a Python string, because we can directly compile and run the kernel from Python. In this tutorial, we'll use the CUDA programming model to implement our kernels.

If you prefer OpenCL over CUDA, don't worry. Everything in this tutorial applies as much to OpenCL as it does to CUDA. But we will use CUDA for our examples, and CUDA terminology in the text.


In [6]:
def get_kernel_string(nx, ny):
    return """
    #define nx %d
    #define ny %d
    #define dt 0.225f
    __global__ void diffuse_kernel(float *u_new, float *u) {
        int x = blockIdx.x * block_size_x + threadIdx.x;
        int y = blockIdx.y * block_size_y + threadIdx.y;

        if (x>0 && x<nx-1 && y>0 && y<ny-1) {
            u_new[y*nx+x] = u[y*nx+x] + dt * ( 
                u[(y+1)*nx+x]+u[y*nx+x+1]-4.0f*u[y*nx+x]+u[y*nx+x-1]+u[(y-1)*nx+x]);
        }
    }
    """ % (nx, ny)
kernel_string = get_kernel_string(nx, ny)

The above CUDA kernel parallelizes the work such that every grid point is processed by a different CUDA thread. Therefore, the kernel is executed by a 2D grid of threads, which are grouped together into 2D thread blocks. The specific thread block dimensions we choose do not affect the result of the computation in this kernel, but as we will see later, they do have an impact on performance.

In this kernel we are using two currently undefined compile-time constants, block_size_x and block_size_y, because we will auto-tune these parameters later. Fixing the thread block dimensions at compile time is often needed for performance, because the compiler can unroll loops that iterate using the block size, or because you may need to allocate shared memory using the thread block dimensions.

The next bit of Python code initializes PyCUDA and makes preparations so that we can call the CUDA kernel to do the computation on the GPU, as we did earlier in Python.


In [7]:
from pycuda import driver, compiler, gpuarray, tools
import pycuda.autoinit
from time import time

#allocate GPU memory
u_old = gpuarray.to_gpu(field)
u_new = gpuarray.to_gpu(field)

#setup thread block dimensions and compile the kernel
threads = (16,16,1)
grid = (int(nx/16), int(ny/16), 1)
block_size_string = "#define block_size_x 16\n#define block_size_y 16\n"
mod = compiler.SourceModule(block_size_string+kernel_string)
diffuse_kernel = mod.get_function("diffuse_kernel")

The above code is a bit of boilerplate we need to compile a kernel using PyCUDA. We've also, for the moment, fixed the thread block dimensions at 16 by 16. These dimensions serve as our initial guess for a well-performing pair of thread block dimensions.

Now that we've set up everything, let's see how long the computation takes on the GPU.


In [8]:
#call the GPU kernel 1000 times (500 iterations of two calls, ping-ponging the buffers) and measure performance
t0 = time()
for i in range(500):
    diffuse_kernel(u_new, u_old, block=threads, grid=grid)
    diffuse_kernel(u_old, u_new, block=threads, grid=grid)
driver.Context.synchronize()
print("1000 steps of diffuse ona %d x %d grid took" %(nx,ny), (time()-t0)*1000, "ms.")

#copy the result from the GPU to Python for plotting
gpu_result = u_old.get()
fig, (ax1, ax2) = pyplot.subplots(1,2)
ax1.imshow(gpu_result)
ax1.set_title("GPU Result")
ax2.imshow(cpu)
ax2.set_title("Python Result")


1000 steps of diffuse on a 1024 x 1024 grid took 33.46109390258789 ms.
Out[8]:
<matplotlib.text.Text at 0x7f8858b873c8>

That should already be a lot faster than our previous Python implementation, but we can do much better if we optimize our GPU kernel. And that is exactly what the rest of this tutorial is about!

Also, if you think the Python boilerplate code to call a GPU kernel was a bit messy, we've got good news for you! From now on, we'll only use the Kernel Tuner to compile and benchmark GPU kernels, which we can do with much cleaner Python code.

Auto-Tuning with the Kernel Tuner

Remember that we previously set the thread block dimensions to 16 by 16. But how do we actually know that this is the best performing setting? That is where auto-tuning comes into play: it is very difficult to predict the best setting through performance modeling, so instead we use the Kernel Tuner to compile and benchmark all possible kernel configurations.

But before we continue, we'll increase the problem size, because the GPU is very likely underutilized.


In [9]:
nx = 4096
ny = 4096
field = get_initial_conditions(nx, ny)
kernel_string = get_kernel_string(nx, ny)

The above code block has generated new initial conditions and a new string that contains our CUDA kernel using our new domain size.

To call the Kernel Tuner, we have to specify the tunable parameters, in our case block_size_x and block_size_y. For this purpose, we'll create an ordered dictionary to store the tunable parameters. The keys are the names of the tunable parameters, and each value is the list of possible values for that parameter. For the purpose of this tutorial, we'll use a small number of commonly used values for the thread block dimensions, but feel free to try more!


In [10]:
from collections import OrderedDict
tune_params = OrderedDict()
tune_params["block_size_x"] = [16, 32, 48, 64, 128]
tune_params["block_size_y"] = [2, 4, 8, 16, 32]

We also have to tell the Kernel Tuner about the argument list of our CUDA kernel, because the Kernel Tuner will be calling the kernel and measuring its execution time. For this purpose we create a Python list that corresponds to the argument list of the diffuse_kernel CUDA function. This list is only used as input to the kernel during tuning. The objects in the list should be Numpy arrays or scalars.

Because you can specify the arguments as Numpy arrays, the Kernel Tuner will take care of allocating GPU memory and copying the data to the GPU.


In [11]:
args = [field, field]

We're almost ready to call the Kernel Tuner; we just need to specify how large our current problem is by setting a problem_size. The Kernel Tuner knows about thread block dimensions, which it expects to be called block_size_x, block_size_y, and/or block_size_z. From these and the problem_size, the Kernel Tuner will compute the appropriate grid dimensions on the fly.


In [12]:
problem_size = (nx, ny)

And that's everything the Kernel Tuner needs to know to be able to start tuning our kernel. Let's give it a try by executing the next code block!


In [13]:
from kernel_tuner import tune_kernel
result = tune_kernel("diffuse_kernel", kernel_string, problem_size, args, tune_params)


Using: GeForce GTX 1080 Ti
block_size_x=16, block_size_y=2, time=0.916985595226
block_size_x=16, block_size_y=4, time=0.489004802704
block_size_x=16, block_size_y=8, time=0.500524806976
block_size_x=16, block_size_y=16, time=0.513356792927
block_size_x=16, block_size_y=32, time=0.545715200901
block_size_x=32, block_size_y=2, time=0.486515200138
block_size_x=32, block_size_y=4, time=0.449055999517
block_size_x=32, block_size_y=8, time=0.44974719882
block_size_x=32, block_size_y=16, time=0.457427197695
block_size_x=32, block_size_y=32, time=0.492915201187
block_size_x=48, block_size_y=2, time=0.464863997698
block_size_x=48, block_size_y=4, time=0.466118401289
block_size_x=48, block_size_y=8, time=0.475264000893
block_size_x=48, block_size_y=16, time=0.513632011414
block_size_x=64, block_size_y=2, time=0.458412796259
block_size_x=64, block_size_y=4, time=0.457715201378
block_size_x=64, block_size_y=8, time=0.461017608643
block_size_x=64, block_size_y=16, time=0.475987195969
block_size_x=128, block_size_y=2, time=0.460032004118
block_size_x=128, block_size_y=4, time=0.457779198885
block_size_x=128, block_size_y=8, time=0.462649595737
best performing configuration: block_size_x=32, block_size_y=4, time=0.449055999517

Note that the Kernel Tuner prints a lot of useful information. To make sure you can tell what was measured in a particular run, the Kernel Tuner always prints the name of the GPU or OpenCL device that is being used. After that, every line contains a combination of tunable parameter values and the time that was measured during benchmarking. The time that is printed is in milliseconds and is obtained by averaging the execution time of 7 runs of the kernel. Finally, as a matter of convenience, the Kernel Tuner also prints the best performing combination of tunable parameters. Later on in this tutorial, we'll show how to analyze the tuning results further using Python.

Looking at the results printed above, the differences in performance between kernel configurations may seem small. However, even within this limited set the execution time varies by well over 10%, which can add up to a large difference in total runtime if the kernel is executed thousands of times. We can also see that the best configuration in this set reduces the kernel execution time by roughly 12% compared to our initial guess of 16 by 16 thread blocks.

In addition, you may notice that not all possible combinations of values for block_size_x and block_size_y are among the results. For example, 128x32 is not among them. This is because some configurations require more threads per thread block than allowed on our GPU. The Kernel Tuner checks the limitations of your GPU at runtime and automatically skips configurations that use too many threads per block. It does the same for kernels that cannot be compiled because they use too much shared memory, and for kernels that use too many registers to be launched at runtime. If you'd like to know which configurations were skipped automatically, you can pass the optional parameter verbose=True to tune_kernel.
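
For example, the previous call can be repeated with the extra flag to also see which configurations were skipped and why:

#re-run the tuning with verbose output, which also reports skipped configurations
result = tune_kernel("diffuse_kernel", kernel_string, problem_size, args,
                     tune_params, verbose=True)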

However, knowing the best performing combination of tunable parameters becomes even more important when we start to further optimize our CUDA kernel. In the next section, we'll add a simple code optimization and show how this affects performance.

Using shared memory

Shared memory is a special type of memory available in CUDA. It can be used by threads within the same thread block to exchange and share values. It is, in fact, one of the very few ways for threads to communicate on the GPU.

The idea is that we'll try to improve the performance of our kernel by using shared memory as a software-controlled cache. There are already caches on the GPU, but most GPUs only cache accesses to global memory in L2. Shared memory is closer to the multiprocessors where the thread blocks are executed, comparable to an L1 cache.

However, because there are also hardware caches, the performance improvement from this step is not expected to be very large. The more fine-grained control that we get from a software-managed cache, rather than a hardware-implemented cache, comes at the cost of some instruction overhead. In fact, performance is quite likely to degrade a little. However, this intermediate step is necessary for the next optimization we have in mind.


In [14]:
kernel_string_shared = """ 
#define nx %d
#define ny %d
#define dt 0.225f
__global__ void diffuse_kernel(float *u_new, float *u) {

    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int bx = blockIdx.x * block_size_x;
    int by = blockIdx.y * block_size_y;

    __shared__ float sh_u[block_size_y+2][block_size_x+2];

    #pragma unroll
    for (int i = ty; i<block_size_y+2; i+=block_size_y) {
        #pragma unroll
        for (int j = tx; j<block_size_x+2; j+=block_size_x) {
            int y = by+i-1;
            int x = bx+j-1;
            if (x>=0 && x<nx && y>=0 && y<ny) {
                sh_u[i][j] = u[y*nx+x];
            }
        }
    }
    __syncthreads();
    
    int x = bx+tx;
    int y = by+ty;
    if (x>0 && x<nx-1 && y>0 && y<ny-1) {
        int i = ty+1;
        int j = tx+1;
        u_new[y*nx+x] = sh_u[i][j] + dt * ( 
            sh_u[i+1][j] + sh_u[i][j+1] -4.0f * sh_u[i][j] +
            sh_u[i][j-1] + sh_u[i-1][j] );
    }    

}
""" % (nx, ny)

We can now tune this new kernel using the Kernel Tuner.


In [15]:
result = tune_kernel("diffuse_kernel", kernel_string_shared, problem_size, args, tune_params)


Using: GeForce GTX 1080 Ti
block_size_x=16, block_size_y=2, time=1.22673916817
block_size_x=16, block_size_y=4, time=0.826361596584
block_size_x=16, block_size_y=8, time=0.793516802788
block_size_x=16, block_size_y=16, time=0.782112002373
block_size_x=16, block_size_y=32, time=0.776639997959
block_size_x=32, block_size_y=2, time=0.795135998726
block_size_x=32, block_size_y=4, time=0.722777605057
block_size_x=32, block_size_y=8, time=0.762777590752
block_size_x=32, block_size_y=16, time=0.75422719717
block_size_x=32, block_size_y=32, time=0.804876792431
block_size_x=48, block_size_y=2, time=0.778656005859
block_size_x=48, block_size_y=4, time=0.769734406471
block_size_x=48, block_size_y=8, time=0.782495999336
block_size_x=48, block_size_y=16, time=0.932281601429
block_size_x=64, block_size_y=2, time=0.734028804302
block_size_x=64, block_size_y=4, time=0.721625590324
block_size_x=64, block_size_y=8, time=0.736511993408
block_size_x=64, block_size_y=16, time=0.800019192696
block_size_x=128, block_size_y=2, time=0.724966406822
block_size_x=128, block_size_y=4, time=0.722969603539
block_size_x=128, block_size_y=8, time=0.759430396557
best performing configuration: block_size_x=64, block_size_y=4, time=0.721625590324

Tiling GPU Code

One very useful code optimization is called tiling, sometimes also called thread-block-merge. You can look at it this way: currently we have many thread blocks that together work on the entire domain. If we were to use only half the number of thread blocks, every thread block would need to do twice the amount of work to cover the entire domain. However, the threads can then reuse part of the data and computation that is required to process a single output element for every element beyond the first.

This is a code optimization because we are effectively reducing the total number of instructions executed by all threads in all thread blocks. So, in a way, we are condensing the total instruction stream while keeping all the really necessary compute instructions. More importantly, we are increasing data reuse, where previously these values would have had to be fetched again from the cache or, in the worst case, from GPU memory.

We can apply tiling in both the x and y-dimensions. This also introduces two new tunable parameters, namely the tiling factor in x and y, which we will call tile_size_x and tile_size_y. This is what the new kernel looks like:


In [16]:
kernel_string_tiled = """ 
#define nx %d
#define ny %d
#define dt 0.225f
__global__ void diffuse_kernel(float *u_new, float *u) {

    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int bx = blockIdx.x * block_size_x * tile_size_x;
    int by = blockIdx.y * block_size_y * tile_size_y;

    __shared__ float sh_u[block_size_y*tile_size_y+2][block_size_x*tile_size_x+2];

    #pragma unroll
    for (int i = ty; i<block_size_y*tile_size_y+2; i+=block_size_y) {
        #pragma unroll
        for (int j = tx; j<block_size_x*tile_size_x+2; j+=block_size_x) {
            int y = by+i-1;
            int x = bx+j-1;
            if (x>=0 && x<nx && y>=0 && y<ny) {
                sh_u[i][j] = u[y*nx+x];
            }
        }
    }
    __syncthreads();

    int y = by+ty;
    int x = bx+tx;
    
    #pragma unroll
    for (int tj=0; tj<tile_size_y; tj++) {
        int i = ty+tj*block_size_y+1;
        #pragma unroll
        for (int ti=0; ti<tile_size_x; ti++) {
            int j = tx+ti*block_size_x+1;
            if (x>0 && x+ti*block_size_x<nx-1 && y>0 && y+tj*block_size_y<ny-1) {
                u_new[y*nx+x+ti*block_size_x] = sh_u[i][j] + dt * ( 
                    sh_u[i+1][j] + sh_u[i][j+1] -4.0f * sh_u[i][j] +
                    sh_u[i][j-1] + sh_u[i-1][j] );
            }
        }
    
    }

}
""" % (nx, ny)

We can tune our tiled kernel by adding the two new tunable parameters to our dictionary tune_params.

We also need to somehow tell the Kernel Tuner to use fewer thread blocks to launch kernels with tile_size_x or tile_size_y larger than one. For this purpose the Kernel Tuner's tune_kernel function supports two optional arguments, called grid_div_x and grid_div_y. These are the grid divisor lists, which are lists of strings containing all the tunable parameters that divide a certain grid dimension. So far, we have been using the default settings for these, in which case the Kernel Tuner only uses the block_size_x and block_size_y tunable parameters to divide the problem_size.

Note that the Kernel Tuner substitutes the values of the tunable parameters into these strings and uses the product of the parameters in each grid divisor list to compute the corresponding grid dimension, rounded up. You can even use arithmetic operations inside these strings, as they will be evaluated. As such, we could have used ["block_size_x*tile_size_x"] to get the same result.
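
To make this concrete, here is a small sketch of that computation. This is not the Kernel Tuner's actual implementation, and it only handles plain parameter names (not arithmetic expressions), but it shows how a grid dimension follows from the problem size and a grid divisor list:

import numpy

def grid_dim(problem_size_dim, params, grid_div):
    #divide the problem size by the product of the listed tunable parameters, rounded up
    divisor = numpy.prod([params[p] for p in grid_div])
    return int(numpy.ceil(problem_size_dim / divisor))

#example: a 4096-wide domain, 32x8 thread blocks, and 4x4 tiling
params = {"block_size_x": 32, "block_size_y": 8, "tile_size_x": 4, "tile_size_y": 4}
print(grid_dim(4096, params, ["block_size_x", "tile_size_x"]))   #prints 32
print(grid_dim(4096, params, ["block_size_y", "tile_size_y"]))   #prints 128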

We are now ready to call the Kernel Tuner again and tune our tiled kernel. Let's execute the following code block; note that it may take a while, as the number of kernel configurations that the Kernel Tuner will try has just increased by a factor of 9!


In [17]:
tune_params["tile_size_x"] = [1,2,4]            #add tile_size_x to the tune_params
tune_params["tile_size_y"] = [1,2,4]            #add tile_size_y to the tune_params
grid_div_x = ["block_size_x", "tile_size_x"]    #tile_size_x impacts grid dimensions
grid_div_y = ["block_size_y", "tile_size_y"]    #tile_size_y impacts grid dimensions
result = tune_kernel("diffuse_kernel", kernel_string_tiled, problem_size, args,
                     tune_params, grid_div_x=grid_div_x, grid_div_y=grid_div_y)


Using: GeForce GTX 1080 Ti
block_size_x=16, block_size_y=2, tile_size_x=1, tile_size_y=1, time=1.22200961113
block_size_x=16, block_size_y=2, tile_size_x=1, tile_size_y=2, time=0.91601279974
block_size_x=16, block_size_y=2, tile_size_x=1, tile_size_y=4, time=0.752838408947
block_size_x=16, block_size_y=2, tile_size_x=2, tile_size_y=1, time=0.873651194572
block_size_x=16, block_size_y=2, tile_size_x=2, tile_size_y=2, time=0.69833599329
block_size_x=16, block_size_y=2, tile_size_x=2, tile_size_y=4, time=0.586931192875
block_size_x=16, block_size_y=2, tile_size_x=4, tile_size_y=1, time=0.516473591328
block_size_x=16, block_size_y=2, tile_size_x=4, tile_size_y=2, time=0.411392003298
block_size_x=16, block_size_y=2, tile_size_x=4, tile_size_y=4, time=0.384262400866
block_size_x=16, block_size_y=4, tile_size_x=1, tile_size_y=1, time=0.82159358263
block_size_x=16, block_size_y=4, tile_size_x=1, tile_size_y=2, time=0.632607996464
block_size_x=16, block_size_y=4, tile_size_x=1, tile_size_y=4, time=0.506457602978
block_size_x=16, block_size_y=4, tile_size_x=2, tile_size_y=1, time=0.618758392334
block_size_x=16, block_size_y=4, tile_size_x=2, tile_size_y=2, time=0.500288009644
block_size_x=16, block_size_y=4, tile_size_x=2, tile_size_y=4, time=0.429862397909
block_size_x=16, block_size_y=4, tile_size_x=4, tile_size_y=1, time=0.44995200038
block_size_x=16, block_size_y=4, tile_size_x=4, tile_size_y=2, time=0.366150397062
block_size_x=16, block_size_y=4, tile_size_x=4, tile_size_y=4, time=0.342201602459
block_size_x=16, block_size_y=8, tile_size_x=1, tile_size_y=1, time=0.793542397022
block_size_x=16, block_size_y=8, tile_size_x=1, tile_size_y=2, time=0.58026239872
block_size_x=16, block_size_y=8, tile_size_x=1, tile_size_y=4, time=0.494163197279
block_size_x=16, block_size_y=8, tile_size_x=2, tile_size_y=1, time=0.546316814423
block_size_x=16, block_size_y=8, tile_size_x=2, tile_size_y=2, time=0.467059195042
block_size_x=16, block_size_y=8, tile_size_x=2, tile_size_y=4, time=0.404249596596
block_size_x=16, block_size_y=8, tile_size_x=4, tile_size_y=1, time=0.440895992517
block_size_x=16, block_size_y=8, tile_size_x=4, tile_size_y=2, time=0.341376006603
block_size_x=16, block_size_y=8, tile_size_x=4, tile_size_y=4, time=0.339692795277
block_size_x=16, block_size_y=16, tile_size_x=1, tile_size_y=1, time=0.783923208714
block_size_x=16, block_size_y=16, tile_size_x=1, tile_size_y=2, time=0.597920000553
block_size_x=16, block_size_y=16, tile_size_x=1, tile_size_y=4, time=0.50277120471
block_size_x=16, block_size_y=16, tile_size_x=2, tile_size_y=1, time=0.615475213528
block_size_x=16, block_size_y=16, tile_size_x=2, tile_size_y=2, time=0.470937597752
block_size_x=16, block_size_y=16, tile_size_x=2, tile_size_y=4, time=0.418393599987
block_size_x=16, block_size_y=16, tile_size_x=4, tile_size_y=1, time=0.443519997597
block_size_x=16, block_size_y=16, tile_size_x=4, tile_size_y=2, time=0.343961596489
block_size_x=16, block_size_y=16, tile_size_x=4, tile_size_y=4, time=0.342540800571
block_size_x=16, block_size_y=32, tile_size_x=1, tile_size_y=1, time=0.780352008343
block_size_x=16, block_size_y=32, tile_size_x=1, tile_size_y=2, time=0.611705589294
block_size_x=16, block_size_y=32, tile_size_x=1, tile_size_y=4, time=0.515667212009
block_size_x=16, block_size_y=32, tile_size_x=2, tile_size_y=1, time=0.622534394264
block_size_x=16, block_size_y=32, tile_size_x=2, tile_size_y=2, time=0.502195191383
block_size_x=16, block_size_y=32, tile_size_x=2, tile_size_y=4, time=0.437388807535
block_size_x=16, block_size_y=32, tile_size_x=4, tile_size_y=1, time=0.45568639636
block_size_x=16, block_size_y=32, tile_size_x=4, tile_size_y=2, time=0.359289598465
block_size_x=16, block_size_y=32, tile_size_x=4, tile_size_y=4, time=0.426995199919
block_size_x=32, block_size_y=2, tile_size_x=1, tile_size_y=1, time=0.788947200775
block_size_x=32, block_size_y=2, tile_size_x=1, tile_size_y=2, time=0.616556799412
block_size_x=32, block_size_y=2, tile_size_x=1, tile_size_y=4, time=0.496121603251
block_size_x=32, block_size_y=2, tile_size_x=2, tile_size_y=1, time=0.629164803028
block_size_x=32, block_size_y=2, tile_size_x=2, tile_size_y=2, time=0.474841600657
block_size_x=32, block_size_y=2, tile_size_x=2, tile_size_y=4, time=0.407667201757
block_size_x=32, block_size_y=2, tile_size_x=4, tile_size_y=1, time=0.47406719923
block_size_x=32, block_size_y=2, tile_size_x=4, tile_size_y=2, time=0.371507203579
block_size_x=32, block_size_y=2, tile_size_x=4, tile_size_y=4, time=0.352531200647
block_size_x=32, block_size_y=4, tile_size_x=1, tile_size_y=1, time=0.72023679018
block_size_x=32, block_size_y=4, tile_size_x=1, tile_size_y=2, time=0.574816000462
block_size_x=32, block_size_y=4, tile_size_x=1, tile_size_y=4, time=0.481817597151
block_size_x=32, block_size_y=4, tile_size_x=2, tile_size_y=1, time=0.580928003788
block_size_x=32, block_size_y=4, tile_size_x=2, tile_size_y=2, time=0.455724793673
block_size_x=32, block_size_y=4, tile_size_x=2, tile_size_y=4, time=0.394975996017
block_size_x=32, block_size_y=4, tile_size_x=4, tile_size_y=1, time=0.464659202099
block_size_x=32, block_size_y=4, tile_size_x=4, tile_size_y=2, time=0.357107198238
block_size_x=32, block_size_y=4, tile_size_x=4, tile_size_y=4, time=0.324083191156
block_size_x=32, block_size_y=8, tile_size_x=1, tile_size_y=1, time=0.759910392761
block_size_x=32, block_size_y=8, tile_size_x=1, tile_size_y=2, time=0.569177603722
block_size_x=32, block_size_y=8, tile_size_x=1, tile_size_y=4, time=0.481279999018
block_size_x=32, block_size_y=8, tile_size_x=2, tile_size_y=1, time=0.528115200996
block_size_x=32, block_size_y=8, tile_size_x=2, tile_size_y=2, time=0.441734397411
block_size_x=32, block_size_y=8, tile_size_x=2, tile_size_y=4, time=0.393126398325
block_size_x=32, block_size_y=8, tile_size_x=4, tile_size_y=1, time=0.455404800177
block_size_x=32, block_size_y=8, tile_size_x=4, tile_size_y=2, time=0.350457596779
block_size_x=32, block_size_y=8, tile_size_x=4, tile_size_y=4, time=0.322547197342
block_size_x=32, block_size_y=16, tile_size_x=1, tile_size_y=1, time=0.754201591015
block_size_x=32, block_size_y=16, tile_size_x=1, tile_size_y=2, time=0.579827189445
block_size_x=32, block_size_y=16, tile_size_x=1, tile_size_y=4, time=0.491852802038
block_size_x=32, block_size_y=16, tile_size_x=2, tile_size_y=1, time=0.582751989365
block_size_x=32, block_size_y=16, tile_size_x=2, tile_size_y=2, time=0.451283198595
block_size_x=32, block_size_y=16, tile_size_x=2, tile_size_y=4, time=0.391807991266
block_size_x=32, block_size_y=16, tile_size_x=4, tile_size_y=1, time=0.456275194883
block_size_x=32, block_size_y=16, tile_size_x=4, tile_size_y=2, time=0.356716805696
block_size_x=32, block_size_y=16, tile_size_x=4, tile_size_y=4, time=0.362937599421
block_size_x=32, block_size_y=32, tile_size_x=1, tile_size_y=1, time=0.809894394875
block_size_x=32, block_size_y=32, tile_size_x=1, tile_size_y=2, time=0.60433280468
block_size_x=32, block_size_y=32, tile_size_x=1, tile_size_y=4, time=0.507142400742
block_size_x=32, block_size_y=32, tile_size_x=2, tile_size_y=1, time=0.655827200413
block_size_x=32, block_size_y=32, tile_size_x=2, tile_size_y=2, time=0.474092799425
block_size_x=32, block_size_y=32, tile_size_x=2, tile_size_y=4, time=0.408166396618
block_size_x=32, block_size_y=32, tile_size_x=4, tile_size_y=1, time=0.480531209707
block_size_x=32, block_size_y=32, tile_size_x=4, tile_size_y=2, time=0.346707201004
block_size_x=48, block_size_y=2, tile_size_x=1, tile_size_y=1, time=0.780134403706
block_size_x=48, block_size_y=2, tile_size_x=1, tile_size_y=2, time=0.601049602032
block_size_x=48, block_size_y=2, tile_size_x=1, tile_size_y=4, time=0.493900799751
block_size_x=48, block_size_y=2, tile_size_x=2, tile_size_y=1, time=0.620384001732
block_size_x=48, block_size_y=2, tile_size_x=2, tile_size_y=2, time=0.494553589821
block_size_x=48, block_size_y=2, tile_size_x=2, tile_size_y=4, time=0.425414395332
block_size_x=48, block_size_y=2, tile_size_x=4, tile_size_y=1, time=0.467033600807
block_size_x=48, block_size_y=2, tile_size_x=4, tile_size_y=2, time=0.375468802452
block_size_x=48, block_size_y=2, tile_size_x=4, tile_size_y=4, time=0.346079999208
block_size_x=48, block_size_y=4, tile_size_x=1, tile_size_y=1, time=0.771052801609
block_size_x=48, block_size_y=4, tile_size_x=1, tile_size_y=2, time=0.593977594376
block_size_x=48, block_size_y=4, tile_size_x=1, tile_size_y=4, time=0.49723520875
block_size_x=48, block_size_y=4, tile_size_x=2, tile_size_y=1, time=0.583270406723
block_size_x=48, block_size_y=4, tile_size_x=2, tile_size_y=2, time=0.478079998493
block_size_x=48, block_size_y=4, tile_size_x=2, tile_size_y=4, time=0.416320002079
block_size_x=48, block_size_y=4, tile_size_x=4, tile_size_y=1, time=0.443942397833
block_size_x=48, block_size_y=4, tile_size_x=4, tile_size_y=2, time=0.359744000435
block_size_x=48, block_size_y=4, tile_size_x=4, tile_size_y=4, time=0.343545603752
block_size_x=48, block_size_y=8, tile_size_x=1, tile_size_y=1, time=0.780960011482
block_size_x=48, block_size_y=8, tile_size_x=1, tile_size_y=2, time=0.598758399487
block_size_x=48, block_size_y=8, tile_size_x=1, tile_size_y=4, time=0.498617601395
block_size_x=48, block_size_y=8, tile_size_x=2, tile_size_y=1, time=0.57678719759
block_size_x=48, block_size_y=8, tile_size_x=2, tile_size_y=2, time=0.46561280489
block_size_x=48, block_size_y=8, tile_size_x=2, tile_size_y=4, time=0.41324160099
block_size_x=48, block_size_y=8, tile_size_x=4, tile_size_y=1, time=0.431225597858
block_size_x=48, block_size_y=8, tile_size_x=4, tile_size_y=2, time=0.351263999939
block_size_x=48, block_size_y=8, tile_size_x=4, tile_size_y=4, time=0.34440960288
block_size_x=48, block_size_y=16, tile_size_x=1, tile_size_y=1, time=0.933260798454
block_size_x=48, block_size_y=16, tile_size_x=1, tile_size_y=2, time=0.715257608891
block_size_x=48, block_size_y=16, tile_size_x=1, tile_size_y=4, time=0.586604809761
block_size_x=48, block_size_y=16, tile_size_x=2, tile_size_y=1, time=0.711615991592
block_size_x=48, block_size_y=16, tile_size_x=2, tile_size_y=2, time=0.558771193027
block_size_x=48, block_size_y=16, tile_size_x=2, tile_size_y=4, time=0.466284793615
block_size_x=48, block_size_y=16, tile_size_x=4, tile_size_y=1, time=0.44043520093
block_size_x=48, block_size_y=16, tile_size_x=4, tile_size_y=2, time=0.361823999882
block_size_x=64, block_size_y=2, tile_size_x=1, tile_size_y=1, time=0.731839990616
block_size_x=64, block_size_y=2, tile_size_x=1, tile_size_y=2, time=0.57044479847
block_size_x=64, block_size_y=2, tile_size_x=1, tile_size_y=4, time=0.470220798254
block_size_x=64, block_size_y=2, tile_size_x=2, tile_size_y=1, time=0.608800005913
block_size_x=64, block_size_y=2, tile_size_x=2, tile_size_y=2, time=0.472665601969
block_size_x=64, block_size_y=2, tile_size_x=2, tile_size_y=4, time=0.416352003813
block_size_x=64, block_size_y=2, tile_size_x=4, tile_size_y=1, time=0.481376004219
block_size_x=64, block_size_y=2, tile_size_x=4, tile_size_y=2, time=0.380812799931
block_size_x=64, block_size_y=2, tile_size_x=4, tile_size_y=4, time=0.351923197508
block_size_x=64, block_size_y=4, tile_size_x=1, tile_size_y=1, time=0.719257593155
block_size_x=64, block_size_y=4, tile_size_x=1, tile_size_y=2, time=0.55171200037
block_size_x=64, block_size_y=4, tile_size_x=1, tile_size_y=4, time=0.466758400202
block_size_x=64, block_size_y=4, tile_size_x=2, tile_size_y=1, time=0.568435204029
block_size_x=64, block_size_y=4, tile_size_x=2, tile_size_y=2, time=0.459654402733
block_size_x=64, block_size_y=4, tile_size_x=2, tile_size_y=4, time=0.394380801916
block_size_x=64, block_size_y=4, tile_size_x=4, tile_size_y=1, time=0.463052803278
block_size_x=64, block_size_y=4, tile_size_x=4, tile_size_y=2, time=0.36409599781
block_size_x=64, block_size_y=4, tile_size_x=4, tile_size_y=4, time=0.328998398781
block_size_x=64, block_size_y=8, tile_size_x=1, tile_size_y=1, time=0.73579518795
block_size_x=64, block_size_y=8, tile_size_x=1, tile_size_y=2, time=0.564575994015
block_size_x=64, block_size_y=8, tile_size_x=1, tile_size_y=4, time=0.472236800194
block_size_x=64, block_size_y=8, tile_size_x=2, tile_size_y=1, time=0.549024009705
block_size_x=64, block_size_y=8, tile_size_x=2, tile_size_y=2, time=0.438406395912
block_size_x=64, block_size_y=8, tile_size_x=2, tile_size_y=4, time=0.389945602417
block_size_x=64, block_size_y=8, tile_size_x=4, tile_size_y=1, time=0.455193603039
block_size_x=64, block_size_y=8, tile_size_x=4, tile_size_y=2, time=0.364051198959
block_size_x=64, block_size_y=8, tile_size_x=4, tile_size_y=4, time=0.375519996881
block_size_x=64, block_size_y=16, tile_size_x=1, tile_size_y=1, time=0.798195195198
block_size_x=64, block_size_y=16, tile_size_x=1, tile_size_y=2, time=0.588998401165
block_size_x=64, block_size_y=16, tile_size_x=1, tile_size_y=4, time=0.49552000761
block_size_x=64, block_size_y=16, tile_size_x=2, tile_size_y=1, time=0.595462405682
block_size_x=64, block_size_y=16, tile_size_x=2, tile_size_y=2, time=0.460972803831
block_size_x=64, block_size_y=16, tile_size_x=2, tile_size_y=4, time=0.400672000647
block_size_x=64, block_size_y=16, tile_size_x=4, tile_size_y=1, time=0.465132802725
block_size_x=64, block_size_y=16, tile_size_x=4, tile_size_y=2, time=0.364627194405
block_size_x=128, block_size_y=2, tile_size_x=1, tile_size_y=1, time=0.729363203049
block_size_x=128, block_size_y=2, tile_size_x=1, tile_size_y=2, time=0.558815991879
block_size_x=128, block_size_y=2, tile_size_x=1, tile_size_y=4, time=0.466655993462
block_size_x=128, block_size_y=2, tile_size_x=2, tile_size_y=1, time=0.600819194317
block_size_x=128, block_size_y=2, tile_size_x=2, tile_size_y=2, time=0.460281592607
block_size_x=128, block_size_y=2, tile_size_x=2, tile_size_y=4, time=0.404908800125
block_size_x=128, block_size_y=2, tile_size_x=4, tile_size_y=1, time=0.478739196062
block_size_x=128, block_size_y=2, tile_size_x=4, tile_size_y=2, time=0.386668801308
block_size_x=128, block_size_y=2, tile_size_x=4, tile_size_y=4, time=0.385510402918
block_size_x=128, block_size_y=4, tile_size_x=1, tile_size_y=1, time=0.720915210247
block_size_x=128, block_size_y=4, tile_size_x=1, tile_size_y=2, time=0.550668799877
block_size_x=128, block_size_y=4, tile_size_x=1, tile_size_y=4, time=0.466937589645
block_size_x=128, block_size_y=4, tile_size_x=2, tile_size_y=1, time=0.564921605587
block_size_x=128, block_size_y=4, tile_size_x=2, tile_size_y=2, time=0.447974395752
block_size_x=128, block_size_y=4, tile_size_x=2, tile_size_y=4, time=0.394271999598
block_size_x=128, block_size_y=4, tile_size_x=4, tile_size_y=1, time=0.46233600378
block_size_x=128, block_size_y=4, tile_size_x=4, tile_size_y=2, time=0.365190398693
block_size_x=128, block_size_y=4, tile_size_x=4, tile_size_y=4, time=0.387827193737
block_size_x=128, block_size_y=8, tile_size_x=1, tile_size_y=1, time=0.762003195286
block_size_x=128, block_size_y=8, tile_size_x=1, tile_size_y=2, time=0.579007995129
block_size_x=128, block_size_y=8, tile_size_x=1, tile_size_y=4, time=0.486649608612
block_size_x=128, block_size_y=8, tile_size_x=2, tile_size_y=1, time=0.557331204414
block_size_x=128, block_size_y=8, tile_size_x=2, tile_size_y=2, time=0.443033593893
block_size_x=128, block_size_y=8, tile_size_x=2, tile_size_y=4, time=0.396070402861
block_size_x=128, block_size_y=8, tile_size_x=4, tile_size_y=1, time=0.457075202465
block_size_x=128, block_size_y=8, tile_size_x=4, tile_size_y=2, time=0.369555193186
best performing configuration: block_size_x=32, block_size_y=8, tile_size_x=4, tile_size_y=4, time=0.322547197342

We can see that the number of kernel configurations tried by the Kernel Tuner grows rather quickly. Also, the best performing configuration is quite a bit faster than the best kernel before we started optimizing. In the run shown above, the execution time went from 0.449 ms for the best configuration of the naive kernel to 0.323 ms for the best tiled configuration, a performance improvement of roughly 28%!

Note that the thread block dimensions for this kernel configuration are also different. Without optimizations the best performing kernel used thread blocks of 32x4; after adding tiling, the best performing kernel uses thread blocks of 32x8, which is twice as many threads. The amount of work per thread also increased, with tiling factors of 4 in the x-direction and 4 in the y-direction, increasing the amount of work per thread block by a factor of 16. The difference in the area processed per thread block between the naive and the tiled kernel is a factor of 32.

However, there are actually several kernel configurations that come close. The following Python code prints all instances with an execution time within 5% of the best performing configuration.
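
The snippet below is a minimal sketch; it assumes result is either the list of measurement dictionaries itself or, as in more recent Kernel Tuner versions, a (results, env) tuple whose first element is that list:

#print all configurations within 5% of the best performing configuration
results = result[0] if isinstance(result, tuple) else result
best_time = min(r["time"] for r in results)
for r in results:
    if r["time"] <= 1.05 * best_time:
        print(r)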

Using the best parameters in a production run

Now that we have determined which parameters are best for our problem, we can use them to simulate the heat diffusion. There are several ways to do so, depending on the host language you wish to use.

Python run

To use the optimized parameters in a Python run, we simply have to modify the kernel code to specify which values to use for the block and tile sizes. There are of course many different ways to achieve this. In simple cases, one can define a dictionary of fixed values and replace the strings block_size_x and block_size_y (and, for the tiled kernel, tile_size_x and tile_size_y) with their values.


In [18]:
import pycuda.autoinit

# define the optimal parameters
size = [nx,ny,1]
threads = [128,4,1]

# create a dict of fixed parameters
fixed_params = OrderedDict()
fixed_params['block_size_x'] = threads[0]
fixed_params['block_size_y'] = threads[1]

# select the kernel to use
kernel_string = kernel_string_shared

# replace the block/tile size
for k,v in fixed_params.items():
    kernel_string = kernel_string.replace(k,str(v))

We also need to determine the size of the grid. For the regular and shared memory kernels, the grid is simply the problem size divided by the thread block dimensions, rounded up.


In [19]:
# for regular and shared kernel 
grid = [int(numpy.ceil(n/t)) for t,n in zip(threads,size)]

We can then transfer the initial conditions to the two GPU arrays, compile the kernel code, and retrieve the function we want to use.


In [20]:
#allocate GPU memory
u_old = gpuarray.to_gpu(field)
u_new = gpuarray.to_gpu(field)

# compile the kernel
mod = compiler.SourceModule(kernel_string)
diffuse_kernel = mod.get_function("diffuse_kernel")

We now just have to call the kernel with these optimized parameters to run the simulation.


In [21]:
#call the GPU kernel 1000 times and measure performance
t0 = time()
for i in range(500):
    diffuse_kernel(u_new, u_old, block=tuple(threads), grid=tuple(grid))
    diffuse_kernel(u_old, u_new, block=tuple(threads), grid=tuple(grid))
driver.Context.synchronize()
print("1000 steps of diffuse on a %d x %d grid took" %(nx,ny), (time()-t0)*1000, "ms.")

#copy the result from the GPU to Python for plotting
gpu_result = u_old.get()
pyplot.imshow(gpu_result)


1000 steps of diffuse on a 4096 x 4096 grid took 618.2231903076172 ms.
Out[21]:
<matplotlib.image.AxesImage at 0x7f887c3d2358>

C run

If you wish to incorporate the optimized parameters into the kernel and use it in a C run, you can use #ifndef statements at the beginning of the kernel, as demonstrated in the pseudo code below.


In [ ]:
kernel_string = """ 

#ifndef block_size_x 
    #define block_size_x <insert optimal value>
#endif

#ifndef block_size_y 
    #define block_size_y <insert optimal value>
#endif

#define nx %d
#define ny %d
#define dt 0.225f
__global__ void diffuse_kernel(float *u_new, float *u) {
    ......
}
""" % (nx, ny)

This kernel can be used during tuning, because the Kernel Tuner prepends its own #define statements to the kernel; as a result, the #ifndef defaults are skipped during tuning. The same kernel will also work just fine on its own in a larger program, in which case the default values take effect.
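
A small sketch of what the preprocessor effectively sees in both situations (the values 32 and 4 are just example values):

#during tuning: the Kernel Tuner prepends its own #define lines, so the
#   #ifndef guards in the kernel string are false and the defaults are skipped
tuned_source = "#define block_size_x 32\n#define block_size_y 4\n" + kernel_string

#stand-alone build: nothing is prepended, so the #ifndef guards are true and
#   the default values defined inside the guards are used
standalone_source = kernel_string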