In [1]:
from devito import SpaceDimension, TimeDimension

# Three spatial dimensions (i, j, k) followed by two time dimensions (t0, t1),
# each keyed by its own name so later cells can write dims['i'], dims['t0'], ...
dims = {name: SpaceDimension(name=name) for name in ('i', 'j', 'k')}
dims.update({name: TimeDimension(name=name) for name in ('t0', 't1')})
dims
Out[1]:
Elements such as `Scalar`s, `Constant`s and `Function`s are used to build SymPy equations.
In [2]:
from devito import Grid, Constant, Function, TimeFunction
from devito.types import Array, Scalar

# A 10x10 grid; among the symbols below only the TimeFunction `f` uses it.
grid = Grid(shape=(10, 10))

# One symbol of each kind, keyed 'a'..'f'. The array-like objects are passed
# through `.indexify()` (presumably yielding their indexed form, e.g. c[i]) so
# they can be dropped straight into equations.
symbs = dict(
    a=Scalar(name='a'),
    b=Constant(name='b'),
    c=Array(name='c', shape=(3,), dimensions=(dims['i'],)).indexify(),
    d=Array(name='d', shape=(3, 3),
            dimensions=(dims['j'], dims['k'])).indexify(),
    e=Function(name='e', shape=(3, 3, 3),
               dimensions=(dims['t0'], dims['t1'], dims['i'])).indexify(),
    f=TimeFunction(name='f', grid=grid).indexify(),
)
symbs
Out[2]:
An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of `sympy.Eq` with some metadata attached. What, when and how metadata are attached is irrelevant here.
In [3]:
from devito.ir.iet import Expression
from devito.ir.equations import DummyEq
from devito.tools import pprint


def get_exprs(a, b, c, d, e, f):
    """
    Build the four IET Expressions used throughout this notebook.

    Parameters
    ----------
    a, b, c, d, e, f
        The symbols (or indexed accesses) combined into the equations.

    Returns
    -------
    list of Expression
        In order:
            exprs[0]: a = b + c + 5.
            exprs[1]: d = e - f
            exprs[2]: a = 4 * (b * a)
            exprs[3]: a = (6. / b) + (8. * a)
    """
    assignments = [(a, b + c + 5.),
                   (d, e - f),
                   (a, 4 * (b * a)),
                   (a, (6. / b) + (8. * a))]
    return [Expression(DummyEq(lhs, rhs)) for lhs, rhs in assignments]


# The keys of `symbs` are exactly get_exprs' parameter names, in order.
exprs = get_exprs(*(symbs[key] for key in 'abcdef'))

pprint(exprs)
An `Iteration` typically wraps one or more `Expression`s.
In [4]:
from devito.ir.iet import Iteration


def get_iters(dims):
    """
    Return one Iteration factory per dimension, in the order i, j, k, t0, t1.

    Each factory takes a loop body ``ex`` (a single node or a list of nodes)
    and wraps it in an Iteration over the matching entry of ``dims``, using
    the hard-coded limits below (presumably (start, end, step) -- TODO confirm
    against the Iteration documentation).
    """
    def _loop_over(name, limits):
        # dims[name] is resolved when the factory is called, not when it is
        # built, so the lookup happens at wrapping time.
        return lambda ex: Iteration(ex, dims[name], limits)

    return [_loop_over('i', (0, 3, 1)),
            _loop_over('j', (0, 5, 1)),
            _loop_over('k', (0, 7, 1)),
            _loop_over('t0', (0, 4, 1)),
            _loop_over('t1', (0, 4, 1))]


iters = get_iters(dims)
Here, we can see how blocks of `Iteration`s over `Expression`s can be used to build loop nests.
In [5]:
# In all three builders, `iters` is expected to be the list returned by
# get_iters, so iters[0..4] loop over i, j, k, t0, t1 respectively.


def get_block1(exprs, iters):
    """
    Perfect loop nest:

        for i
          for j
            for k
              expr0
    """
    return iters[0](iters[1](iters[2](exprs[0])))


def get_block2(exprs, iters):
    """
    Non-perfect simple loop nest:

        for i
          expr0
          for j
            for k
              expr1
    """
    return iters[0]([exprs[0], iters[1](iters[2](exprs[1]))])


def get_block3(exprs, iters):
    """
    Non-perfect non-trivial loop nest:

        for i
          for t0
            expr0
          for j
            for k
              expr1
              expr2
          for t1
            expr3
    """
    return iters[0]([iters[3](exprs[0]),
                     iters[1](iters[2]([exprs[1], exprs[2]])),
                     iters[4](exprs[3])])


block1 = get_block1(exprs, iters)
block2 = get_block2(exprs, iters)
block3 = get_block3(exprs, iters)

# print('\n') emits a blank separator line between consecutive trees.
pprint(block1)
print('\n')
pprint(block2)
print('\n')
pprint(block3)
And, finally, we can build `Callable` kernels that will be used to generate C code. Note that `Operator` is a subclass of `Callable`.
In [6]:
from devito.ir.iet import Callable

# Wrap each loop nest into a kernel named `foo` returning `void`; the trailing
# `()` is presumably the (empty) parameter list.
kernels = [Callable('foo', body, 'void', ()) for body in (block1, block2, block3)]

# One loop instead of three copy-pasted prints; output is unchanged.
for no, kernel in enumerate(kernels, start=1):
    print('kernel no.%d:\n%s\n' % (no, kernel.ccode))
An IET is immutable. It can be "transformed" by replacing or dropping some of its inner nodes, but what this actually means is that a new IET is created. IETs are transformed by `Transformer` visitors. A `Transformer` takes as input a dictionary encoding replacement rules.
In [7]:
from devito.ir.iet import Transformer

# Replaces a Callable's body with another: block1 (the body of kernels[0]) is
# mapped to block2. IETs are immutable, so visit() builds a new tree and
# kernels[0] itself is left untouched.
transformer = Transformer({block1: block2})
kernel_alt = transformer.visit(kernels[0])
print(kernel_alt)
Specific `Expression`s within the loop nest can also be substituted.
In [8]:
# Replaces an expression with another: the occurrence of exprs[0] inside
# block1 is swapped for exprs[1]. visit() returns a new loop nest; block1
# itself is not modified, since IETs are immutable.
transformer = Transformer({exprs[0]: exprs[1]})
newblock = transformer.visit(block1)
# Render the transformed nest as C source text.
newcode = str(newblock.ccode)
print(newcode)
In [9]:
from devito.ir.iet import Block
import cgen as c
# Creates a replacer for replacing an expression: exprs[1] (the statement in
# block2's inner j/k nest) is replaced by a Block holding a single raw C line,
# which here is just a comment.
line1 = '// Replaced expression'
replacer = Block(c.Line(line1))
transformer = Transformer({exprs[1]: replacer})
newblock = transformer.visit(block2)
# Render the transformed nest as C source text.
newcode = str(newblock.ccode)
print(newcode)
In [10]:
# Wraps an expression in comments
line1 = '// This is the opening comment'
line2 = '// This is the closing comment'


def wrapper(n):
    """Return a Block that places node `n` between the two comment lines.

    `line1` and `line2` are read from the notebook namespace at call time.
    """
    return Block(c.Line(line1), n, c.Line(line2))


# Map exprs[0] to a Block containing exprs[0] itself, so the expression is
# kept but surrounded by the opening/closing comments.
transformer = Transformer({exprs[0]: wrapper(exprs[0])})
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)