In this tutorial we are going to see how to map a series of Gaussian functions, each centered at a different point of a 3D grid. We are going to optimize the GPU code and compare its performance with the CPU implementation.
Before delving into the GPU implementation, let's start with a simple CPU implementation of the problem. The problem at hand is to compute the values of the following function
\begin{equation} \nonumber f = \sum_{i=1}^{N}\exp\left(-\beta \sqrt{(x-x_i)^2+(y-y_i)^2+(z-z_i)^2}\right) \end{equation} on a 3D grid. The $x$, $y$ and $z$ vectors contain the coordinates of the grid points in Cartesian space. We can define a simple Python function that computes the value of the function $f$ for one given Gaussian. Don't forget to execute all the code cells, like the one below, as you read through this notebook by selecting the cell and pressing Shift+Enter.
In [1]:
import numpy as np
import numpy.linalg as la
from time import time
def compute_grid(center,xgrid,ygrid,zgrid):
    x0,y0,z0 = center
    beta = -0.1
    f = np.sqrt( (xgrid-x0)**2 + (ygrid-y0)**2 + (zgrid-z0)**2 )
    f = np.exp(beta*f)
    return f
For a given center, this function returns the values of the corresponding Gaussian function mapped on the 3D grid. The grid points are here defined by the variables xgrid, ygrid and zgrid. These variables are themselves 3D grids obtained, as we will see in an instant, using the numpy.meshgrid function.
To use this function we simply have to create the grid, defined by the vectors x, y, and z. Since we want to send these vectors to the GPU later on, we define them as 32-bit floats. For simplicity, we select the interval $[-1,1]$ to define our grid. We use $n=256$ grid points in each dimension, which gives a sufficiently large problem without requiring overly long calculations. We then create the meshgrids to be passed to the function above. Finally, we define 100 Gaussian centers that are randomly distributed within the 3D space.
In [2]:
# dimension of the problem
n = 256
# define the vectors
x = np.linspace(-1,1,n).astype(np.float32)
y = np.linspace(-1,1,n).astype(np.float32)
z = np.linspace(-1,1,n).astype(np.float32)
# create meshgrids
xgrid,ygrid,zgrid = np.meshgrid(x,y,z)
cpu_grid = np.zeros_like(xgrid)
# centers
npts = 100
center = (-1 + 2*np.random.rand(npts,3)).astype(np.float32)
# compute the grid and time the operation
t0 = time()
for xyz in center:
    cpu_grid += compute_grid(xyz,xgrid,ygrid,zgrid)
print('CPU Execution time %f ms' %( (time()-t0)*1000) )
Depending on your hardware it might take a few seconds for the calculations above to finish.
Let's now see how that looks on the GPU. We first write a kernel that does the same calculation as the function above. As you can see below, the variables block_size_x, block_size_y and block_size_z are not yet defined here. These variables set the number of threads per thread block on the GPU and are the main parameters that we will optimize in this tutorial. During tuning, Kernel Tuner will automatically insert #define statements for these parameters at the top of the kernel code, so for now we don't have to specify their values.
The dimensions of the problem nx, ny, and nz are the number of grid points in the x, y, and z dimensions. We can again use Kernel Tuner to insert these parameters into the code.
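For instance, for a configuration with block_size_x = 4, block_size_y = 2 and block_size_z = 16, the code that is actually compiled would start with lines such as #define block_size_x 4, #define block_size_y 2, and so on, followed by the kernel source below.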
In [3]:
# define a kernel template
# several parameters are available
# block sizes : block_size_x, block_size_y, block_size_z
# dimensions : nx, ny, nz
kernel_code = """
#include <math.h>
// a simple gaussian function
__host__ __device__ float f(float d){
    float b = 0.1;
    float x = exp(-b*d);
    return x;
}
// the main function called below
__global__ void AddGrid(float x0, float y0, float z0, float *xvect, float *yvect, float *zvect, float *out)
{
    // 3D thread
    int x = threadIdx.x + block_size_x * blockIdx.x;
    int y = threadIdx.y + block_size_y * blockIdx.y;
    int z = threadIdx.z + block_size_z * blockIdx.z;
    if ( ( x < nx ) && (y < ny) && (z < nz) )
    {
        float dx = xvect[x]-x0;
        float dy = yvect[y]-y0;
        float dz = zvect[z]-z0;
        float d = sqrt(dx*dx + dy*dy + dz*dz);
        out[y * nx * nz + x * nz + z] += f(d);
    }
}
"""
We can now use the tuner to optimize the thread block dimensions on our GPU. To do so we define the tunable parameters of our kernel in the tune_params dictionary, which assigns to each block size the values we want the tuner to explore. We also use the tunable parameters to insert the domain dimensions nx, ny, and nz.
We also define a list containing the arguments of the CUDA function (AddGrid) above. Since we only want to optimize the performance of the kernel, we consider here a single center in the middle of the grid. Note that Kernel Tuner needs either numpy.ndarray objects or numpy scalars as kernel arguments. Hence we need to be explicit about the types of the Gaussian center coordinates.
In [ ]:
from collections import OrderedDict
from kernel_tuner import tune_kernel
# create the dictionary containing the tune parameters
tune_params = OrderedDict()
tune_params['block_size_x'] = [2,4,8,16,32]
tune_params['block_size_y'] = [2,4,8,16,32]
tune_params['block_size_z'] = [2,4,8,16,32]
tune_params['nx'] = [n]
tune_params['ny'] = [n]
tune_params['nz'] = [n]
# define the final grid
grid = np.zeros_like(xgrid)
# arguments of the CUDA function
x0,y0,z0 = np.float32(0),np.float32(0),np.float32(0)
args = [x0,y0,z0,x,y,z,grid]
# dimensionality
problem_size = (n,n,n)
As mentioned earlier, the tuner will automatically insert #define statements at the top of the kernel to define the block sizes and domain dimensions, so we don't need to specify them here. Then, we simply call the tune_kernel function.
In [5]:
# call the kernel tuner
result = tune_kernel('AddGrid', kernel_code, problem_size, args, tune_params)
The tune_kernel function explores all possible combinations of the tunable parameters (here only the block sizes). For each kernel configuration, the tuner compiles the code and measures its execution time (by default using 7 iterations). At the end of the run, tune_kernel reports the optimal combination of the tunable parameters. The measured execution times of all benchmarked kernels are also returned by tune_kernel for programmatic access to the data.
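For example, here is a minimal sketch of how the returned data could be inspected programmatically. It assumes a recent version of Kernel Tuner, in which tune_kernel returns a tuple containing the list of benchmarked configurations and an environment dictionary, with each configuration holding its tunable parameters together with a 'time' entry.
In [ ]:
# unpack the list of benchmarked configurations and the environment info
# (assumes tune_kernel returns a (results, env) tuple)
results, env = result
# pick the configuration with the lowest measured execution time
best = min(results, key=lambda conf: conf['time'])
print('best configuration :', best)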
As you can see, the range of performance is quite large. With our GPU (a GeForce GTX 1080 Ti) we obtained a maximum time of 5.30 ms and a minimum of 0.84 ms. The performance of the kernel varies by roughly a factor of 6 depending on the thread block size!
Now that we have determined which parameters are best suited for our application, we can specify them in our kernel and run it. In our case, the optimal block size determined by the tuner was block_size_x = 4, block_size_y = 2, block_size_z = 16. We therefore use these values to define the block size here. The grid size is simply obtained by dividing each dimension of the problem by the corresponding block size.
In [6]:
from pycuda import driver, compiler, gpuarray, tools
import pycuda.autoinit
# optimal values of the block size
block = [4, 2, 16]
# corresponding grid size
grid_dim = [int(np.ceil(n/b)) for b, n in zip(block, problem_size)]
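With $n = 256$ grid points in each dimension and this block size, the grid dimensions evaluate to $256/4 = 64$, $256/2 = 128$ and $256/16 = 16$, i.e. grid_dim = [64, 128, 16].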
Before using the kernel we need to specify the block size in its definition. There are different ways of doing this; here we simply replace block_size_x, block_size_y and block_size_z by the values determined by the tuner. To do so, we create a dictionary that associates each parameter name with its value and perform the substitution. Once the block sizes are specified, we can compile the kernel ourselves and get the function.
In [7]:
# change the values of the block sizes in the kernel
fixed_params = OrderedDict()
fixed_params['block_size_x'] = block[0]
fixed_params['block_size_y'] = block[1]
fixed_params['block_size_z'] = block[2]
fixed_params['nx'] = n
fixed_params['ny'] = n
fixed_params['nz'] = n
for k,v in fixed_params.items():
    kernel_code = kernel_code.replace(k,str(v))
# compile the kernel_code and extract the function
mod = compiler.SourceModule(kernel_code)
addgrid = mod.get_function('AddGrid')
We now have to manually create the gpuarrays that correspond to the vectors x, y and z, as well as the 3D grid. Once these are defined, we can call the addgrid function with the gpuarrays and the block and grid sizes as arguments. We also time the execution to compare it with the timing reported by Kernel Tuner. Note that we explicitly synchronize the CPU and GPU to obtain an accurate timing.
In [8]:
# create the gpu arrays
xgpu = gpuarray.to_gpu(x)
ygpu = gpuarray.to_gpu(y)
zgpu = gpuarray.to_gpu(z)
grid_gpu = gpuarray.zeros((n,n,n), np.float32)
# compute the grid and time the performance
t0 = time()
for xyz in center:
    x0,y0,z0 = xyz
    addgrid(x0,y0,z0,xgpu,ygpu,zgpu,grid_gpu,block = tuple(block),grid=tuple(grid_dim))
driver.Context.synchronize()
print('Final GPU time : %f ms' %((time()-t0)*1000))
As you can see, the GPU execution time is much lower than the CPU execution time obtained above. In our case it went from roughly 40000 ms to just 80 ms!
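As a quick sanity check, here is a minimal sketch of how the GPU result could be compared with the CPU reference (assuming the kernel accumulates into the grid with +=, as in the version above; the tolerance is kept loose because the single-precision GPU sum over 100 Gaussians will not match the CPU result bit for bit):
In [ ]:
# copy the GPU result back to host memory
gpu_result = grid_gpu.get()
# compare with the CPU reference computed earlier
print('max abs difference :', np.abs(gpu_result - cpu_grid).max())
print('results match      :', np.allclose(gpu_result, cpu_grid, atol=1e-4))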