3D Grid on GPU with Kernel Tuner

In this tutorial we are going to see how to map a series of Gaussian functions, each centered at a different point, onto a 3D grid. We are going to optimize the GPU code and compare its performance with the CPU implementation.

**Note:** If you are reading this tutorial on the Kernel Tuner's documentation pages, note that you can actually run it as a Jupyter Notebook. Just clone the Kernel Tuner's [GitHub repository](http://github.com/benvanwerkhoven/kernel_tuner), install Kernel Tuner and Jupyter, and you're ready to go! You can start the tutorial by typing "jupyter notebook" in the "kernel_tuner/tutorial" directory.

Let's start on the CPU

Before delving into the GPU implementation, let's start with a simple CPU version. The task at hand is to compute the values of the following function

\begin{equation} \nonumber f = \sum_{i=1}^{N}\exp\left(-\beta \sqrt{(x-x_i)^2+(y-y_i)^2+(z-z_i)^2}\right) \end{equation}

on a 3D grid. The $x$, $y$ and $z$ vectors contain the coordinates of the grid points along each axis of the Cartesian space. We can define a simple Python function that computes the values of $f$ for one given Gaussian. Don't forget to execute all the code cells, like the one below, as you read through this notebook, by selecting the cell and pressing shift+enter.


In [1]:
import numpy as np
import numpy.linalg as la
from time import time

def compute_grid(center,xgrid,ygrid,zgrid):
    # position of the Gaussian center
    x0,y0,z0 = center
    # beta = -0.1 so that np.exp(beta*f) equals exp(-0.1*d) in the equation above
    beta = -0.1
    # distance of every grid point to the center
    f = np.sqrt( (xgrid-x0)**2 + (ygrid-y0)**2 + (zgrid-z0)**2 )
    f = np.exp(beta*f)
    return f

For a given center, this function returns the values of the corresponding Gaussian function mapped onto the 3D grid. The grid points are defined by the variables xgrid, ygrid and zgrid. These variables are themselves 3D arrays obtained, as we will see in an instant, with the numpy.meshgrid function.
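
To illustrate what numpy.meshgrid produces, here is a minimal sketch on a tiny two-point grid (the variable names xs, ys, zs, xg, yg and zg exist only in this example):


In [ ]:
xs = np.array([0., 1.], dtype=np.float32)
ys = np.array([0., 1.], dtype=np.float32)
zs = np.array([0., 1.], dtype=np.float32)

xg, yg, zg = np.meshgrid(xs, ys, zs)
print(xg.shape)     # (2, 2, 2): one x value for every (y, x, z) index
print(xg[0, 1, 0])  # 1.0, i.e. xs[1]; the default 'xy' indexing orders the axes as (y, x, z)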

To use this function we simply have to create the grid, defined by the vectors x, y, and z. Since we later want to send these vectors to the GPU, we define them as 32-bit floats. For simplicity, we select the interval $[-1:1]$ to define our grid, and use $n=256$ grid points in order to have a sufficiently large problem without requiring overly long calculations. We then create the meshgrids to be passed to the function above, and define 100 Gaussian centers randomly distributed within the 3D space.


In [2]:
# dimension of the problem
n = 256

# define the vectors
x = np.linspace(-1,1,n).astype(np.float32)
y = np.linspace(-1,1,n).astype(np.float32)
z = np.linspace(-1,1,n).astype(np.float32)

# create meshgrids
xgrid,ygrid,zgrid = np.meshgrid(x,y,z)
cpu_grid = np.zeros_like(xgrid)

# centers
npts = 100
center = (-1 + 2*np.random.rand(npts,3)).astype(np.float32)

# compute the grid and time the operation
t0 = time()
for xyz in center:
    cpu_grid += compute_grid(xyz,xgrid,ygrid,zgrid)
print('CPU Execution time %f ms' %( (time()-t0)*1000) )


CPU Execution time 52320.160627 ms

Depending on your hardware, it might take a while for the calculation above to finish (close to a minute in our case).

Let's move to the GPU

Let's now see what this looks like on the GPU. We first write a kernel that performs the same calculation as the function above. As you can see below, the variables block_size_x, block_size_y and block_size_z are not yet defined here. These variables set the number of threads per thread block on the GPU and are the main parameters that we will optimize in this tutorial. During tuning, Kernel Tuner will automatically insert #define statements for these parameters at the top of the kernel code, so for now we don't have to specify their values.

The dimensions of the problem nx, ny, and nz are the number of grid points in the x, y, and z directions. We can again use Kernel Tuner to insert these parameters into the code.
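
To make this concrete: for one of the configurations explored below, the code that Kernel Tuner actually compiles would start with a preamble along these lines (an illustration; the exact values change with each configuration being benchmarked):

#define block_size_x 2
#define block_size_y 4
#define block_size_z 32
#define nx 256
#define ny 256
#define nz 256
// ... followed by the kernel code below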


In [3]:
# define a kernel template
# several tunable parameters are available
# block sizes : block_size_x, block_size_y, block_size_z
# dimensions  : nx, ny, nz
kernel_code = """
#include <math.h>

// a simple gaussian function
__host__ __device__ float f(float d){
    float b = 0.1;
    float x = exp(-b*d);
    return x;
}

// the main function called below
__global__ void AddGrid(float x0, float y0, float z0, float *xvect, float *yvect, float *zvect, float *out)
{

    // 3D thread 
    int x = threadIdx.x + block_size_x * blockIdx.x;
    int y = threadIdx.y + block_size_y * blockIdx.y;
    int z = threadIdx.z + block_size_z * blockIdx.z;

    if ( ( x < nx ) && (y < ny) && (z < nz) )
    {

        float dx = xvect[x]-x0;
        float dy = yvect[y]-y0;
        float dz = zvect[z]-z0;
        float d = sqrt(dx*dx + dy*dy + dz*dz);
        out[y * nx * nz + x * nz + z] = f(d);
    }
}
"""

Tune the kernel

We can now use the tuner to optimize the thread block dimensions on our GPU. To do so we define the tunable parameters of our kernel using the tune_params dictionary, which assigns to each block size the values we want the tuner to explore. We also use the tunable parameters to insert the domain dimensions nx, ny, and nz.

We also define a list containing the arguments of the CUDA function (AddGrid) above. Since we only want to optimize the performance of the kernel, we consider here just one center, located in the middle of the grid. Note that Kernel Tuner requires kernel arguments to be numpy.ndarrays or numpy scalars, hence we need to be explicit about the types of the Gaussian's position.


In [ ]:
from collections import OrderedDict
from kernel_tuner import tune_kernel

# create the dictionary containing the tune parameters
tune_params = OrderedDict()
tune_params['block_size_x'] = [2,4,8,16,32]
tune_params['block_size_y'] = [2,4,8,16,32]
tune_params['block_size_z'] = [2,4,8,16,32]
tune_params['nx'] = [n]
tune_params['ny'] = [n]
tune_params['nz'] = [n]

# define the final grid
grid = np.zeros_like(xgrid)

# arguments of the CUDA function
x0,y0,z0 = np.float32(0),np.float32(0),np.float32(0)
args = [x0,y0,z0,x,y,z,grid]

# dimensionality
problem_size = (n,n,n)

As mentioned earlier, the tuner will automatically insert #define statements at the top of the kernel to define the block sizes and domain dimensions, so we don't need to specify them here. Then, we simply call the tune_kernel function.


In [5]:
# call the kernel tuner
result = tune_kernel('AddGrid', kernel_code, problem_size, args, tune_params)


Using: GeForce GTX 1080 Ti
block_size_x=2, block_size_y=2, block_size_z=2, time=3.56833920479
block_size_x=2, block_size_y=2, block_size_z=4, time=1.80796158314
block_size_x=2, block_size_y=2, block_size_z=8, time=0.940044796467
block_size_x=2, block_size_y=2, block_size_z=16, time=0.855628800392
block_size_x=2, block_size_y=2, block_size_z=32, time=0.855359995365
block_size_x=2, block_size_y=4, block_size_z=2, time=4.16174077988
block_size_x=2, block_size_y=4, block_size_z=4, time=2.11877760887
block_size_x=2, block_size_y=4, block_size_z=8, time=1.01592960358
block_size_x=2, block_size_y=4, block_size_z=16, time=0.849273598194
block_size_x=2, block_size_y=4, block_size_z=32, time=0.849235200882
block_size_x=2, block_size_y=8, block_size_z=2, time=4.19029750824
block_size_x=2, block_size_y=8, block_size_z=4, time=2.16199679375
block_size_x=2, block_size_y=8, block_size_z=8, time=1.40401918888
block_size_x=2, block_size_y=8, block_size_z=16, time=1.39618558884
block_size_x=2, block_size_y=8, block_size_z=32, time=1.39508478642
block_size_x=2, block_size_y=16, block_size_z=2, time=5.31647996902
block_size_x=2, block_size_y=16, block_size_z=4, time=2.31470079422
block_size_x=2, block_size_y=16, block_size_z=8, time=1.50787198544
block_size_x=2, block_size_y=16, block_size_z=16, time=1.53760001659
block_size_x=2, block_size_y=16, block_size_z=32, time=1.56709756851
block_size_x=2, block_size_y=32, block_size_z=2, time=5.34500494003
block_size_x=2, block_size_y=32, block_size_z=4, time=2.25130877495
block_size_x=2, block_size_y=32, block_size_z=8, time=1.50662400723
block_size_x=2, block_size_y=32, block_size_z=16, time=1.55267841816
block_size_x=4, block_size_y=2, block_size_z=2, time=4.17987194061
block_size_x=4, block_size_y=2, block_size_z=4, time=2.12309756279
block_size_x=4, block_size_y=2, block_size_z=8, time=1.01125121117
block_size_x=4, block_size_y=2, block_size_z=16, time=0.849631989002
block_size_x=4, block_size_y=2, block_size_z=32, time=0.853708791733
block_size_x=4, block_size_y=4, block_size_z=2, time=4.17051515579
block_size_x=4, block_size_y=4, block_size_z=4, time=2.15584001541
block_size_x=4, block_size_y=4, block_size_z=8, time=1.40074241161
block_size_x=4, block_size_y=4, block_size_z=16, time=1.39547519684
block_size_x=4, block_size_y=4, block_size_z=32, time=1.39331197739
block_size_x=4, block_size_y=8, block_size_z=2, time=5.30295038223
block_size_x=4, block_size_y=8, block_size_z=4, time=2.28725762367
block_size_x=4, block_size_y=8, block_size_z=8, time=1.39589118958
block_size_x=4, block_size_y=8, block_size_z=16, time=1.38867840767
block_size_x=4, block_size_y=8, block_size_z=32, time=1.37724158764
block_size_x=4, block_size_y=16, block_size_z=2, time=5.34344320297
block_size_x=4, block_size_y=16, block_size_z=4, time=2.26213116646
block_size_x=4, block_size_y=16, block_size_z=8, time=1.38793599606
block_size_x=4, block_size_y=16, block_size_z=16, time=1.3775359869
block_size_x=4, block_size_y=32, block_size_z=2, time=4.74003200531
block_size_x=4, block_size_y=32, block_size_z=4, time=2.13276162148
block_size_x=4, block_size_y=32, block_size_z=8, time=1.37233917713
block_size_x=8, block_size_y=2, block_size_z=2, time=4.18835201263
block_size_x=8, block_size_y=2, block_size_z=4, time=2.15777277946
block_size_x=8, block_size_y=2, block_size_z=8, time=1.40247042179
block_size_x=8, block_size_y=2, block_size_z=16, time=1.39366400242
block_size_x=8, block_size_y=2, block_size_z=32, time=1.39439997673
block_size_x=8, block_size_y=4, block_size_z=2, time=5.23719043732
block_size_x=8, block_size_y=4, block_size_z=4, time=2.28542718887
block_size_x=8, block_size_y=4, block_size_z=8, time=1.39207677841
block_size_x=8, block_size_y=4, block_size_z=16, time=1.38956804276
block_size_x=8, block_size_y=4, block_size_z=32, time=1.3778496027
block_size_x=8, block_size_y=8, block_size_z=2, time=5.29814395905
block_size_x=8, block_size_y=8, block_size_z=4, time=2.26398081779
block_size_x=8, block_size_y=8, block_size_z=8, time=1.38625922203
block_size_x=8, block_size_y=8, block_size_z=16, time=1.3754431963
block_size_x=8, block_size_y=16, block_size_z=2, time=4.72981758118
block_size_x=8, block_size_y=16, block_size_z=4, time=2.12483196259
block_size_x=8, block_size_y=16, block_size_z=8, time=1.37322881222
block_size_x=8, block_size_y=32, block_size_z=2, time=4.61618566513
block_size_x=8, block_size_y=32, block_size_z=4, time=2.2194111824
block_size_x=16, block_size_y=2, block_size_z=2, time=5.17600002289
block_size_x=16, block_size_y=2, block_size_z=4, time=2.27082881927
block_size_x=16, block_size_y=2, block_size_z=8, time=1.38787200451
block_size_x=16, block_size_y=2, block_size_z=16, time=1.3835711956
block_size_x=16, block_size_y=2, block_size_z=32, time=1.37543039322
block_size_x=16, block_size_y=4, block_size_z=2, time=5.30227203369
block_size_x=16, block_size_y=4, block_size_z=4, time=2.23127679825
block_size_x=16, block_size_y=4, block_size_z=8, time=1.38627202511
block_size_x=16, block_size_y=4, block_size_z=16, time=1.37677440643
block_size_x=16, block_size_y=8, block_size_z=2, time=4.64358406067
block_size_x=16, block_size_y=8, block_size_z=4, time=2.12255358696
block_size_x=16, block_size_y=8, block_size_z=8, time=1.37474560738
block_size_x=16, block_size_y=16, block_size_z=2, time=4.61655673981
block_size_x=16, block_size_y=16, block_size_z=4, time=2.19179515839
block_size_x=16, block_size_y=32, block_size_z=2, time=4.99912958145
block_size_x=32, block_size_y=2, block_size_z=2, time=5.213971138
block_size_x=32, block_size_y=2, block_size_z=4, time=2.16430072784
block_size_x=32, block_size_y=2, block_size_z=8, time=1.38772480488
block_size_x=32, block_size_y=2, block_size_z=16, time=1.3735104084
block_size_x=32, block_size_y=4, block_size_z=2, time=4.54432649612
block_size_x=32, block_size_y=4, block_size_z=4, time=2.05524477959
block_size_x=32, block_size_y=4, block_size_z=8, time=1.36935677528
block_size_x=32, block_size_y=8, block_size_z=2, time=4.42449922562
block_size_x=32, block_size_y=8, block_size_z=4, time=2.10455036163
block_size_x=32, block_size_y=16, block_size_z=2, time=4.67516155243
best performing configuration: block_size_x=2, block_size_y=4, block_size_z=32, time=0.849235200882

The tune_kernel function explores all possible combinations of the tunable parameters (here only the block sizes). For each kernel configuration, the tuner compiles the code and measures its execution time (by default using 7 iterations). At the end of the run, tune_kernel prints the optimal combination of the tunable parameters. The measured execution times of all benchmarked kernels are also returned by tune_kernel for programmatic access to the data.
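
For example, we can extract the fastest configuration from the returned data ourselves. This is a minimal sketch; it assumes the benchmark results are a list of dictionaries with a 'time' entry, returned either directly or as the first element of a tuple, depending on the Kernel Tuner version:


In [ ]:
# locate the fastest configuration in the data returned by tune_kernel
results = result[0] if isinstance(result, tuple) else result
best = min(results, key=lambda cfg: cfg['time'])
print('fastest configuration:', best)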

As you can see, the range of performance is quite large. With our GPU (a GeForce GTX 1080 Ti) we obtained a maximum time of about 5.35 ms and a minimum of about 0.85 ms: the performance of the kernel varies by more than a factor of 6 depending only on the thread block size!

Using the optimized parameters

Now that we have determined which parameters are best suited for our application, we can specify them in our kernel and run it. In our case, the optimal block size determined by the tuner was block_size_x = 2, block_size_y = 4, block_size_z = 32, so we use these values here to define the block size. The grid size is simply obtained by dividing each dimension of the problem by the corresponding block size, rounding up.


In [6]:
from pycuda import driver, compiler, gpuarray, tools
import pycuda.autoinit

# optimal values of the block size
block = [2, 4, 32]

# corresponding grid size (problem size divided by block size, rounded up)
grid_dim = [int(np.ceil(dim/size)) for size, dim in zip(block, problem_size)]
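
With block = [2, 4, 32] and a $256^3$ problem, this yields grid_dim = [128, 64, 8].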

Before using the kernel, we need to specify the block sizes in its definition. There are different ways of doing this; here we simply replace block_size_x, block_size_y and block_size_z by the values determined by the tuner. To do that, we create a dictionary that maps the parameter names to their values and perform the substitution. Once the block sizes are specified, we can compile the kernel ourselves and extract the function.


In [7]:
# change the values of the block sizes in the kernel
fixed_params = OrderedDict()
fixed_params['block_size_x'] = block[0]
fixed_params['block_size_y'] = block[1]
fixed_params['block_size_z'] = block[2]
fixed_params['nx'] = n
fixed_params['ny'] = n
fixed_params['nz'] = n

for k,v in fixed_params.items():
    kernel_code = kernel_code.replace(k,str(v))
 
# compile the kernel_code and extract the function
mod = compiler.SourceModule(kernel_code)
addgrid = mod.get_function('AddGrid')

We now have to manually create the gpuarrays that correspond to the vectors x, y and z, as well as the 3D grid. Once these are defined, we can call the addgrid function, passing the gpuarrays and the block and grid sizes as arguments. We also time the execution to compare it with the timing reported by the tuner. Note that we explicitly synchronize the CPU and GPU to obtain an accurate timing.


In [8]:
# create the gpu arrays
xgpu = gpuarray.to_gpu(x)
ygpu = gpuarray.to_gpu(y)
zgpu = gpuarray.to_gpu(z)
grid_gpu = gpuarray.zeros((n,n,n), np.float32)

# compute the grid and time the performance
t0 = time()
for xyz in center:
    x0,y0,z0 = xyz
    addgrid(x0,y0,z0,xgpu,ygpu,zgpu,grid_gpu,block = tuple(block),grid=tuple(grid_dim))
driver.Context.synchronize()
print('Final GPU time : %f ms' %((time()-t0)*1000))


Final GPU time : 80.133200 ms

As you can see, the GPU execution time is much lower than the CPU execution time obtained above. In our case it went from roughly 52 seconds to just 80 ms!
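
As a final sanity check, we can verify that the GPU grid matches the CPU grid computed at the beginning of the tutorial. This is a sketch; thanks to the index ordering discussed above the two arrays are directly comparable, although single-precision round-off calls for a loose tolerance:


In [ ]:
# compare GPU and CPU results; the loose tolerance accounts for
# single-precision round-off and different exp/sqrt implementations
gpu_result = grid_gpu.get()
print('max abs difference :', np.abs(gpu_result - cpu_grid).max())
print('results agree      :', np.allclose(gpu_result, cpu_grid, atol=1e-4))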