A common pattern used extensively in Magma is to write general functions that construct new circuits from other circuits. These circuit constructors are analogous to the higher-order functions in functional programming languages. Examples in functional programming languages include `map` and `fold`. The corresponding Magma operators are `join`, `fork`, `fold`, and `scan`.
Magma generalizes these functions with `braid`, which allows one to construct systolic arrays. Systolic arrays cannot be described using pure functions because they wire together stateful elements.
```python
import magma as m
```
Perhaps the simplest example is the Mantle function that constructs an n-bit `Register`.
To do this we use the Magma `col` and `join` functions. `col` takes a `Circuit` constructor and the number `n` of circuit instances to create, and returns a list of circuit instances. `join` combines the `n` circuit instances into a single circuit.
```python
from mantle import DFF

class Register(m.Generator):
    """
    Generate an n-bit register

    Interface
    ---------
    I : In(Bits[width]), O : Out(Bits[width])
    """
    @staticmethod
    def generate(width: int):
        T = m.Bits[width]

        class _Register(m.Circuit):
            name = f'Register{width}'
            io = m.IO(I=m.In(T), O=m.Out(T)) + m.ClockIO()
            # col builds `width` DFF instances; join places them side by side
            reg = m.join(m.col(lambda y: DFF(name=f"reg{y}"), width))
            m.wire(reg(io.I), io.O)

        return _Register

print(repr(Register.generate(4)))
```
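To see what `join` contributes, here is a pure-Python, cycle-level model of the `Register` generator above. It is a sketch, not Magma's implementation: `DFF`, `Join`, and the `clock` method are hypothetical stand-ins, each DFF is assumed to reset to 0, and `col` is modeled as calling the constructor once per column position `y`, as described above.

```python
# Pure-Python, cycle-level model of col and join (a sketch, NOT Magma's code).
# DFF, Join and the .clock() method are hypothetical names for this model only.
from typing import Callable, List, TypeVar

T = TypeVar("T")


class DFF:
    """One-bit D flip-flop; O shows the value latched on the last clock edge."""

    def __init__(self, name: str):
        self.name = name
        self.state = 0  # assumed reset value

    @property
    def O(self) -> int:
        return self.state

    def clock(self, I: int) -> None:
        self.state = I  # rising edge: latch the input


def col(f: Callable[[int], T], n: int) -> List[T]:
    """Create n instances; f receives the instance's position y in the column."""
    return [f(y) for y in range(n)]


class Join:
    """Side-by-side composition: element k of the input drives instance k."""

    def __init__(self, instances: List[DFF]):
        self.instances = instances

    @property
    def O(self) -> List[int]:
        return [inst.O for inst in self.instances]

    def clock(self, I: List[int]) -> None:
        assert len(I) == len(self.instances), "input width must match instance count"
        for inst, bit in zip(self.instances, I):
            inst.clock(bit)


def join(instances: List[DFF]) -> Join:
    return Join(instances)


# A 4-bit register: four DFFs joined into one 4-bit-wide circuit.
reg = join(col(lambda y: DFF(name=f"reg{y}"), 4))
print(reg.O)             # [0, 0, 0, 0] before any clock edge
reg.clock([1, 0, 1, 1])
print(reg.O)             # [1, 0, 1, 1] one clock later
```

The whole word moves through the register in one clock edge because each bit has its own independent DFF; nothing is wired between neighboring instances.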
```python
class Decode(m.Generator):
    @staticmethod
    def generate(value: int, width: int):
        class _Decode(m.Circuit):
            name = f"Decode{width}_{value}"
            io = m.IO(I=m.In(m.Bits[width]),
                      O=m.Out(m.Bit))
            # O is high exactly when the input equals `value`
            io.O @= io.I == value
        return _Decode


class Decoder(m.Generator):
    @staticmethod
    def generate(width: int):
        class _Decoder(m.Circuit):
            io = m.IO(I=m.In(m.Bits[width]),
                      O=m.Out(m.Bits[1 << width]))
            # One Decode per output bit; fork feeds io.I to all of them
            io.O @= m.fork(m.col(lambda y: Decode(y, width), 1 << width))(io.I)
        return _Decoder

print(repr(Decoder.generate(2)))
```
There is a lot going on in this function. One subtlety is the lambda function passed to `col`. The argument `y` is the position of the created instance in the column; we pass this value to the `Decode` circuit, which computes whether the input equals that index. The `fork` function sends the same input value to every `Decode` instance and joins their outputs. The result is a one-hot encoding: bit `y` of the output is high exactly when the input equals `y`.
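The decoder's behavior can be checked without hardware using a pure-Python model, assuming each `Decode` instance is combinational and can be modeled as a function of the input value. The helpers `col`, `decode`, `fork`, and `decoder` below are hypothetical names for this sketch, not Magma functions; `fork` is modeled as sending one input to every instance and collecting the outputs in column order.

```python
# Pure-Python model of the Decoder (a sketch; not Magma's fork or Decode).
from typing import Callable, List

Cell = Callable[[int], int]


def col(f: Callable[[int], Cell], n: int) -> List[Cell]:
    return [f(y) for y in range(n)]


def decode(value: int, width: int) -> Cell:
    """Combinational Decode: 1 when the width-bit input equals value, else 0."""
    mask = (1 << width) - 1
    return lambda I: int((I & mask) == value)


def fork(instances: List[Cell]) -> Callable[[int], List[int]]:
    """Send the same input to every instance; collect outputs in column order."""
    return lambda I: [inst(I) for inst in instances]


def decoder(width: int) -> Callable[[int], List[int]]:
    return fork(col(lambda y: decode(y, width), 1 << width))


dec2 = decoder(2)
for v in range(4):
    print(v, dec2(v))
# 0 [1, 0, 0, 0]
# 1 [0, 1, 0, 0]
# 2 [0, 0, 1, 0]
# 3 [0, 0, 0, 1]
```

Exactly one output is high for every input, which is the one-hot property described above.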
`fold` is a classic higher-order function. The Magma `fold` takes a list of circuits and wires the output of each circuit to the input of the next. The input of the first circuit instance becomes the input of the resulting circuit, and the output of the last instance becomes its output. The remaining inputs and outputs are joined. By convention, the output `O` of one instance is wired to the input `I` of the next.
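For purely combinational circuits, the chaining that `fold` performs can be modeled as ordinary function composition. In this sketch each "circuit" is a hypothetical Python function from `I` to `O`, and the side inputs and outputs that Magma joins are left out; it is a model of the wiring, not Magma's code.

```python
# Pure-Python model of fold over combinational circuits (a sketch, not Magma's code).
# Each hypothetical "circuit" is a function I -> O.
from typing import Callable, List

Circuit = Callable[[int], int]


def col(f: Callable[[int], Circuit], n: int) -> List[Circuit]:
    return [f(y) for y in range(n)]


def fold(instances: List[Circuit]) -> Circuit:
    def folded(I: int) -> int:
        x = I              # the external input drives the first instance
        for inst in instances:
            x = inst(x)    # O of this instance feeds I of the next one
        return x           # O of the last instance is the external output
    return folded


# Instance y adds y + 1, so three chained instances add 1 + 2 + 3.
chain = fold(col(lambda y: (lambda I: I + y + 1), 3))
print(chain(10))  # 16
```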
A good example of `fold` in action is combining `n` DFFs into a serial-in, serial-out (SISO) shift register. The output of each DFF is connected to the input of the next DFF.
```python
class SISO(m.Generator):
    """
    Generate Serial-In, Serial-Out shift register with `n` cycles of delay.

    I : In(Bit), O : Out(Bit)
    """
    @staticmethod
    def generate(n: int):
        class _SISO(m.Circuit):
            name = f'SISO{n}'
            io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()
            # fold chains the DFFs: O of reg{y} drives I of reg{y+1}
            reg = m.fold(m.col(lambda y: DFF(name=f"reg{y}"), n))
            m.wire(reg(io.I), io.O)
        return _SISO

print(repr(SISO.generate(4)))
```
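A cycle-level model makes the SISO's `n` cycles of delay concrete. This is a hypothetical Python sketch (`SISOModel` is not part of Magma or Mantle) that assumes every DFF resets to 0 and all DFFs latch on the same clock edge; `outputs[k]` is the value of `O` during cycle `k`, sampled before that cycle's edge.

```python
# Cycle-level model of the SISO shift register (hypothetical; not Magma/Mantle code).
from typing import List


class SISOModel:
    def __init__(self, n: int):
        self.regs: List[int] = [0] * n   # regs[0] is fed by I; regs[-1] drives O

    @property
    def O(self) -> int:
        return self.regs[-1]

    def clock(self, I: int) -> None:
        # All DFFs latch on the same edge: DFF k takes the old value of DFF k-1.
        self.regs = [I] + self.regs[:-1]


siso = SISOModel(4)
inputs = [1, 0, 1, 1, 0, 0, 0, 0]
outputs = []
for bit in inputs:
    outputs.append(siso.O)   # sample O during this cycle, before the clock edge
    siso.clock(bit)
print(outputs)  # [0, 0, 0, 0, 1, 0, 1, 1]: the input reappears 4 cycles later
```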
`scan` can be used to combine `n` DFFs into a serial-in, parallel-out (SIPO) shift register. The output of each DFF is connected to the input of the next DFF. In addition, all the outputs are joined to form an array of bits.
```python
class SIPO(m.Generator):
    """
    Generate Serial-In, Parallel-Out shift register.

    I : In(Bit), O : Out(Bits[n])
    """
    @staticmethod
    def generate(n: int):
        T = m.Bits[n]
        class _SIPO(m.Circuit):
            name = f'SIPO{n}'
            io = m.IO(I=m.In(m.Bit), O=m.Out(T)) + m.ClockIO()
            # scan chains the DFFs like fold, but also joins every DFF output
            reg = m.scan(m.col(lambda y: DFF(name=f"reg{y}"), n))
            m.wire(reg(io.I), io.O)
        return _SIPO

print(repr(SIPO.generate(4)))
```
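The SIPO differs from the SISO only in that every DFF output is exposed. In this hypothetical Python model (not Magma code), `O[k]` is the output of DFF `k`, assuming `scan` joins the outputs in instance order with DFF 0 fed by `I`. Under that assumption, after four clocks the most recent bit is in `O[0]` and the first bit shifted in is in `O[3]`.

```python
# Cycle-level model of the SIPO shift register (hypothetical; not Magma/Mantle code).
from typing import List


class SIPOModel:
    def __init__(self, n: int):
        self.regs: List[int] = [0] * n   # regs[0] is fed by I

    @property
    def O(self) -> List[int]:
        return list(self.regs)           # every DFF output is exposed, as with scan

    def clock(self, I: int) -> None:
        # All DFFs latch on the same edge: DFF k takes the old value of DFF k-1.
        self.regs = [I] + self.regs[:-1]


sipo = SIPOModel(4)
for bit in [1, 1, 0, 1]:
    sipo.clock(bit)
print(sipo.O)  # [1, 0, 1, 1]: O[0] is the newest bit, O[3] the first one shifted in
```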
These higher-order circuit construction operators can be precisely controlled using `braid`, which can be used to construct general systolic circuits. `braid` takes a list of circuit instances as input and simultaneously wires up their inputs and outputs in the desired way. The advantage of `braid` is that inputs and outputs can be selected by name, and different methods can be used to wire up different inputs and outputs.
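The idea of selecting ports by name can be sketched in plain Python. `braid_model` below is a hypothetical, combinational simplification, not Magma's `braid`: each circuit is a function from named inputs to a dict of named outputs, `joinargs` inputs take one list element per instance, `forkargs` inputs go to every instance unchanged, `foldargs` and `scanargs` map an input name to the output name that drives it, and every output that is not folded is joined into a list. The example braids four full adders into a ripple-carry adder by joining `a` and `b` and folding the carry.

```python
# Hypothetical, combinational model of braid-style wiring by port name.
# This is NOT Magma's braid; it only illustrates join/fork/fold/scan selected by name.
from typing import Any, Callable, Dict, List, Optional, Sequence

Circuit = Callable[..., Dict[str, Any]]


def braid_model(circuits: List[Circuit],
                joinargs: Sequence[str] = (),
                forkargs: Sequence[str] = (),
                foldargs: Optional[Dict[str, str]] = None,
                scanargs: Optional[Dict[str, str]] = None) -> Circuit:
    foldargs = dict(foldargs or {})
    scanargs = dict(scanargs or {})
    chains = {**foldargs, **scanargs}              # input name -> output that drives it

    def run(**inputs: Any) -> Dict[str, Any]:
        carry = {i: inputs[i] for i in chains}     # external value enters instance 0
        result: Dict[str, Any] = {}
        for k, circ in enumerate(circuits):
            args = {name: inputs[name] for name in forkargs}           # same value to all
            args.update({name: inputs[name][k] for name in joinargs})  # k-th element
            args.update(carry)                                         # chained inputs
            outs = circ(**args)
            for i, o in chains.items():
                carry[i] = outs[o]                 # output of k drives input of k+1
            for name, val in outs.items():
                if name not in foldargs.values():  # folded outputs are not joined
                    result.setdefault(name, []).append(val)
        for i, o in foldargs.items():
            result[o] = carry[i]                   # last instance's folded output
        return result

    return run


def full_adder(a: int, b: int, cin: int) -> Dict[str, int]:
    return {"s": a ^ b ^ cin, "cout": (a & b) | (cin & (a ^ b))}


# Four full adders: a and b are joined (one bit each, LSB first), the carry is folded.
adder4 = braid_model([full_adder] * 4, joinargs=["a", "b"], foldargs={"cin": "cout"})
print(adder4(a=[1, 1, 0, 0], b=[1, 0, 1, 0], cin=0))  # {'s': [0, 0, 0, 1], 'cout': 0}, i.e. 3 + 5 = 8
```

Because ports are chosen by name, the same cell could instead have its carry scanned (exposing every intermediate carry) or a control input forked, without changing the cell itself.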
```python
def braid(circuits,
          joinargs=[],
          flatargs=[],
          forkargs=['RESET', 'SET', 'CE', 'CLK'],
          foldargs={}, rfoldargs={},
          scanargs={}, rscanargs={}):
    ...  # body omitted; only the signature is shown here
```
Note that by default the clock and control signals (`RESET`, `SET`, `CE`, and `CLK`) are forked, so every instance receives the same signal.
[Figure from Kung and Leiserson.]
For example, `braid(circuits, foldargs={'I': 'O'})` is equivalent to `fold`.
Similarly, `braid(circuits, scanargs={'I': 'O'})` is equivalent to `scan`. To scan in the other direction, use `rscanargs` (for right-scan): `braid(circuits, rscanargs={'I': 'O'})`.
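The difference between the two directions can be seen in a small model. This sketch assumes, rather than confirms from Magma's source, that `rscan` chains from the last instance toward the first and that output `k` of the result still comes from instance `k`; `scan_model`, `rscan_model`, and `cell` are hypothetical names.

```python
# Hypothetical model of chaining direction in scan vs. rscan (not Magma's code).
from typing import Any, Callable, List

Cell = Callable[[Any], Any]


def scan_model(instances: List[Cell], I: Any) -> List[Any]:
    outs: List[Any] = [None] * len(instances)
    x = I
    for k in range(len(instances)):            # instance 0 receives the external input
        x = instances[k](x)
        outs[k] = x                            # every intermediate output is kept
    return outs


def rscan_model(instances: List[Cell], I: Any) -> List[Any]:
    outs: List[Any] = [None] * len(instances)
    x = I
    for k in reversed(range(len(instances))):  # last instance receives the external input
        x = instances[k](x)
        outs[k] = x                            # output k still comes from instance k
    return outs


def cell(k: int) -> Cell:
    """Hypothetical cell that appends its own index, so the visiting order is visible."""
    return lambda x: x + [k]


cells = [cell(k) for k in range(3)]
print(scan_model(cells, []))   # [[0], [0, 1], [0, 1, 2]]
print(rscan_model(cells, []))  # [[2, 1, 0], [2, 1], [2]]
```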