Constructing Circuits from other Circuits

A common pattern that is used extensively in Magma is to write general functions that construct new circuits from other circuits. These circuit constructors are analogous to higher-order functions in functional programming languages, such as map and fold. The corresponding Magma operators are join, fork, fold, and scan. Magma generalizes these functions with braid, which allows one to construct systolic arrays. Systolic arrays cannot be described using functions, since they wire together stateful elements.
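For readers who know the functional-programming side of this analogy, here is a minimal plain-Python sketch of the three list operations involved. It uses only the standard library (`map`, `functools.reduce`, `itertools.accumulate`) and is an illustration of the analogy, not Magma code: map applies a function to each element independently, fold threads an accumulator and keeps only the final value, and scan keeps every intermediate value.

```python
from functools import reduce
from itertools import accumulate

xs = [1, 2, 3, 4]

# map: apply the same function to every element independently
squares = list(map(lambda x: x * x, xs))

# fold: thread an accumulator through the list and keep only the final result
total = reduce(lambda acc, x: acc + x, xs, 0)

# scan: like fold, but keep every intermediate accumulator value
partial_sums = list(accumulate(xs, lambda acc, x: acc + x))

print(squares)       # [1, 4, 9, 16]
print(total)         # 10
print(partial_sums)  # [1, 3, 6, 10]
```

In the circuit setting the "accumulator" becomes a wire: fold passes each instance's output to the next instance's input, and scan additionally exposes every intermediate output.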


In [1]:
import magma as m

Register - col and join

Perhaps the simplest example is the Mantle function that constructs an n-bit register.

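To do this we use the Magma col and join functions. col takes a function that creates a circuit instance (it is called with the instance's index y) and the number n of instances to create. It returns a list of circuit instances. join combines the n instances into a single circuit: input bit k is wired to instance k, and the instance outputs are concatenated into the output array.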
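The wiring behavior of col and join can be mimicked in plain Python. The sketch below is a hypothetical behavioral model, not Magma's implementation: `col_model` and `join_model` are illustrative names, and each "instance" is assumed to be a combinational function from one input bit to one output bit (a DFF would need state, which this model omits).

```python
def col_model(make_instance, n):
    """Hypothetical stand-in for col: build n instances, passing each its index y."""
    return [make_instance(y) for y in range(n)]


def join_model(instances):
    """Hypothetical stand-in for join: instance k reads I[k] and drives O[k]."""
    def joined(bits):
        if len(bits) != len(instances):
            raise ValueError("input width must equal the number of instances")
        # Inputs are split bit by bit; outputs are concatenated in the same order.
        return [inst(bit) for inst, bit in zip(instances, bits)]
    return joined


# Each "instance" is modeled as a 1-bit inverter (a plain Python function).
inverters = col_model(lambda y: (lambda b: 1 - b), 4)
print(join_model(inverters)([1, 0, 0, 1]))  # [0, 1, 1, 0]
```

The Register example has the same shape, with DFFs in place of inverters, which is why its printed wiring connects `I[k]` to `regk.I` and `regk.O` to `O[k]`.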


In [2]:
from mantle import DFF

class Register(m.Generator):
    """
    Generate an n-bit register

    Interface
    ---------
        I : In(Bits[width]), O : Out(Bits[width])
    """
    @staticmethod
    def generate(width: int):
        T = m.Bits[width]
        class _Register(m.Circuit):
            name = f'Register{width}'
            io = m.IO(I=m.In(T), O=m.Out(T)) + m.ClockIO()
            reg = m.join(m.col(lambda y: DFF(name=f"reg{y}"), width))
            m.wire(reg(io.I), io.O)
        return _Register

print(repr(Register.generate(4)))


Register4 = DefineCircuit("Register4", "I", In(Bits[4]), "O", Out(Bits[4]), "CLK", In(Clock))
reg0 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg0")
reg1 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg1")
reg2 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg2")
reg3 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg3")
wire(Register4.I[0], reg0.I)
wire(Register4.I[1], reg1.I)
wire(Register4.I[2], reg2.I)
wire(Register4.I[3], reg3.I)
wire(reg0.O, Register4.O[0])
wire(reg1.O, Register4.O[1])
wire(reg2.O, Register4.O[2])
wire(reg3.O, Register4.O[3])
EndCircuit()

fork

fork(list) creates a single circuit from a list of instances. Just as with join, the outputs of the instances are concatenated to form an array. However, the inputs are not concatenated; they are forked. That is, the same input is wired to every instance.
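A minimal plain-Python model of the difference from join: with fork, every instance sees the whole input, and only the outputs are concatenated. `fork_model` is a hypothetical name, and instances are again assumed to be plain functions rather than Magma circuits.

```python
def fork_model(instances):
    """Hypothetical stand-in for fork: every instance receives the same input,
    and their outputs are concatenated into a single output list."""
    def forked(value):
        return [inst(value) for inst in instances]
    return forked


# Three instances, each testing one property of the shared input.
checks = [
    lambda v: int(v > 0),       # is the input positive?
    lambda v: int(v % 2 == 0),  # is the input even?
    lambda v: int(v == 3),      # is the input equal to 3?
]
print(fork_model(checks)(3))  # [1, 0, 1]
```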


In [3]:
class Decode(m.Generator):
    @staticmethod
    def generate(value: int, width: int):
        class _Decode(m.Circuit):
            name = f"Decode{width}_{value}"
            io = m.IO(I=m.In(m.Bits[width]), 
                      O=m.Out(m.Bit))
            io.O @= io.I == value
        return _Decode


class Decoder(m.Generator):
    @staticmethod
    def generate(width: int):
        class _Decoder(m.Circuit):
            io = m.IO(I=m.In(m.Bits[width]), 
                      O=m.Out(m.Bits[1 << width]))

            io.O @= m.fork(m.col(lambda y: Decode(y, width), 1 << width))(io.I)
        return _Decoder

print(repr(Decoder.generate(2)))


_Decoder = DefineCircuit("_Decoder", "I", In(Bits[2]), "O", Out(Bits[4]))
Decode2_0_inst0 = Decode2_0()
Decode2_1_inst0 = Decode2_1()
Decode2_2_inst0 = Decode2_2()
Decode2_3_inst0 = Decode2_3()
wire(_Decoder.I[0], Decode2_0_inst0.I[0])
wire(_Decoder.I[1], Decode2_0_inst0.I[1])
wire(_Decoder.I[0], Decode2_1_inst0.I[0])
wire(_Decoder.I[1], Decode2_1_inst0.I[1])
wire(_Decoder.I[0], Decode2_2_inst0.I[0])
wire(_Decoder.I[1], Decode2_2_inst0.I[1])
wire(_Decoder.I[0], Decode2_3_inst0.I[0])
wire(_Decoder.I[1], Decode2_3_inst0.I[1])
wire(Decode2_0_inst0.O, _Decoder.O[0])
wire(Decode2_1_inst0.O, _Decoder.O[1])
wire(Decode2_2_inst0.O, _Decoder.O[2])
wire(Decode2_3_inst0.O, _Decoder.O[3])
EndCircuit()

There is a lot going on in this function.

One subtlety in this code is the lambda function passed to col. The argument y is the position of the created instance in the column. We feed this value into the Decode generator, so instance y computes whether the input is equal to y.

The fork function sends the input value to every Decode instance and joins their outputs. The result is a one-hot encoding: output bit i is high exactly when the input equals i.
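The expected behavior of `Decoder.generate(width)` can be written as a small plain-Python golden model, which may be handy when checking the generated circuit's outputs. `decoder_reference` is a hypothetical helper, not part of Magma or Mantle; list element i corresponds to output bit `O[i]`, which the printout above shows is driven by `Decode2_i`.

```python
def decoder_reference(value, width):
    """One-hot decode: element i is 1 exactly when value == i."""
    if not 0 <= value < (1 << width):
        raise ValueError("value does not fit in width bits")
    return [int(value == i) for i in range(1 << width)]


for v in range(4):
    print(v, decoder_reference(v, 2))
# 0 [1, 0, 0, 0]
# 1 [0, 1, 0, 0]
# 2 [0, 0, 1, 0]
# 3 [0, 0, 0, 1]
```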

fold

fold is a classic higher-order function. The Magma fold takes a list of circuits and wires the output of each circuit to the input of the next. The input of the first instance becomes the input of the combined circuit, and the output of the last instance becomes its output. The remaining inputs and outputs are joined.

By convention, the output O of each instance is wired to the input I of the next.

A good example of this in action is to combine n DFFs into a serial-in serial-out (SISO) shift register. The output of each DFF is connected to the input of the next DFF.
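To see why folding n DFFs gives n cycles of delay, here is a cycle-level plain-Python simulation of the chain. It is a behavioral sketch, not Magma code: `simulate_siso` is a hypothetical name, and the DFFs are assumed to start at 0 (matching the `init0` in the printed DFF names) and to update all at once on each clock edge.

```python
def simulate_siso(inputs, n):
    """Simulate n chained DFFs; return the serial output observed each cycle."""
    regs = [0] * n  # regs[k] models the output of reg k; all DFFs start at 0
    outputs = []
    for bit in inputs:
        outputs.append(regs[-1])   # O is driven by the last DFF in the chain
        regs = [bit] + regs[:-1]   # clock edge: reg0 latches I, reg k latches reg k-1
    return outputs


print(simulate_siso([1, 0, 1, 1, 0, 0, 0, 0], 4))  # [0, 0, 0, 0, 1, 0, 1, 1]
```

The output stream is the input stream shifted by four cycles, with the first four cycles showing the DFFs' initial value.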


In [4]:
class SISO(m.Generator):
    """
    Generate Serial-In, Serial-Out shift register with `n` cycles of delay.

    I : In(Bit), O : Out(Bit)
    """
    @staticmethod
    def generate(n: int):
        class _SISO(m.Circuit):
            name = f'SISO{n}'
            io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) + m.ClockIO()

            reg = m.fold(m.col(lambda y: DFF(name=f"reg{y}"), n))
            m.wire(reg(io.I), io.O)
        return _SISO

print(repr(SISO.generate(4)))


SISO4 = DefineCircuit("SISO4", "I", In(Bit), "O", Out(Bit), "CLK", In(Clock))
reg0 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg0")
reg1 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg1")
reg2 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg2")
reg3 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg3")
wire(SISO4.I, reg0.I)
wire(reg0.O, reg1.I)
wire(reg1.O, reg2.I)
wire(reg2.O, reg3.I)
wire(reg3.O, SISO4.O)
EndCircuit()

scan

scan is closely related to fold: it also takes a list of circuits and wires the output of each circuit to the input of the next, and the input of the first instance becomes the input of the combined circuit. The difference is that scan joins the outputs of every instance, not just the last one, to form the output. As with fold, the convention is that the output O is wired to the input I.

scan can therefore be used to combine n DFFs into a serial-in parallel-out (SIPO) shift register. The output of each DFF is connected to the input of the next DFF, and all the DFF outputs are joined to form an array of bits.
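A cycle-level plain-Python sketch of the resulting SIPO behavior: the chain is the same as a serial shift register, but every DFF output is visible each cycle. `simulate_sipo` is a hypothetical name; the model assumes DFFs initialized to 0 that all update together on each clock edge, and list element k stands for `O[k]`, which the printed wiring shows is driven by `regk`.

```python
def simulate_sipo(inputs, n):
    """Simulate n chained DFFs; return the parallel output O observed each cycle."""
    regs = [0] * n  # regs[k] models the output of reg k; all DFFs start at 0
    outputs = []
    for bit in inputs:
        outputs.append(list(regs))  # O[k] is driven by reg k (a snapshot this cycle)
        regs = [bit] + regs[:-1]    # clock edge: shift the serial input into reg0
    return outputs


for o in simulate_sipo([1, 0, 1, 1, 0], 4):
    print(o)
# [0, 0, 0, 0]
# [1, 0, 0, 0]
# [0, 1, 0, 0]
# [1, 0, 1, 0]
# [1, 1, 0, 1]
```

Note that `O[0]` holds the most recently shifted-in bit.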


In [5]:
class SIPO(m.Generator):
    """
    Generate Serial-In, Parallel-Out shift register.

    I : In(Bit), O : Out(Bits[n])
    """
    @staticmethod
    def generate(n: int):
        T = m.Bits[n]
        class _SIPO(m.Circuit):
            name = f'SIPO{n}'
            io = m.IO(I=m.In(m.Bit), O=m.Out(T)) + m.ClockIO()

            reg = m.scan(m.col(lambda y: DFF(name=f"reg{y}"), n))
            m.wire(reg(io.I), io.O)
        return _SIPO

print(repr(SIPO.generate(4)))


SIPO4 = DefineCircuit("SIPO4", "I", In(Bit), "O", Out(Bits[4]), "CLK", In(Clock))
reg0 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg0")
reg1 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg1")
reg2 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg2")
reg3 = DFF_init0_has_ceFalse_has_resetFalse_has_async_resetFalse(name="reg3")
wire(SIPO4.I, reg0.I)
wire(reg0.O, reg1.I)
wire(reg1.O, reg2.I)
wire(reg2.O, reg3.I)
wire(reg0.O, SIPO4.O[0])
wire(reg1.O, SIPO4.O[1])
wire(reg2.O, SIPO4.O[2])
wire(reg3.O, SIPO4.O[3])
EndCircuit()

braid

These higher-order circuit construction operators can be precisely controlled using braid. braid can be used to construct general systolic circuits.

braid takes a list of circuit instances as input and simultaneously wires up their inputs and outputs in the desired way. The advantage of braid is that inputs and outputs can be selected by name, and different methods can be used to wire up different inputs and outputs.

def braid(circuits,
  joinargs=[],
  flatargs=[],
  forkargs=['RESET','SET','CE','CLK'],
  foldargs={}, rfoldargs={},
  scanargs={}, rscanargs={}):

Note that by default, the clock signals are forked.

Figure from Kung and Leiserson.

For example,

braid(circuits, foldargs={'I': 'O'})

is equivalent to fold.

Similarly, the following is equivalent to scan:

braid(circuits, scanargs={'I': 'O'})

If you want to do a scan in a different direction, use

braid(circuits, rscanargs={'I': 'O'})

The name rscanargs is short for right-scan.
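As a rough illustration of how named port selection drives the wiring, here is a toy plain-Python model that lists the (source, sink) wires produced for forkargs, foldargs, and scanargs. Everything in it is an assumption for illustration: `braid_wires`, the instance names `inst0 ... inst{n-1}`, and the enclosing-circuit name `top` are hypothetical, and it does not model joinargs, flatargs, rfoldargs, rscanargs, or how Magma actually implements braid. Each dict maps an input port name to the output port name that feeds it, following the convention above.

```python
def braid_wires(n, forkargs=(), foldargs=None, scanargs=None, top="top"):
    """Toy model: return the (source, sink) wires for n instances named inst0..inst{n-1}."""
    if n < 1:
        raise ValueError("need at least one instance")
    foldargs = foldargs or {}
    scanargs = scanargs or {}
    insts = [f"inst{k}" for k in range(n)]
    wires = []
    # forkargs: one top-level port drives the same-named port on every instance.
    for port in forkargs:
        wires += [(f"{top}.{port}", f"{i}.{port}") for i in insts]
    # foldargs/scanargs: chain output `out` of instance k into input `inp` of k+1.
    for mapping, expose_all in ((foldargs, False), (scanargs, True)):
        for inp, out in mapping.items():
            wires.append((f"{top}.{inp}", f"{insts[0]}.{inp}"))
            for a, b in zip(insts, insts[1:]):
                wires.append((f"{a}.{out}", f"{b}.{inp}"))
            if expose_all:  # scan: every instance's output is joined into top.out
                wires += [(f"{i}.{out}", f"{top}.{out}[{k}]") for k, i in enumerate(insts)]
            else:           # fold: only the last instance's output is exposed
                wires.append((f"{insts[-1]}.{out}", f"{top}.{out}"))
    return wires


for w in braid_wires(4, foldargs={"I": "O"}):
    print(w)
```

With `foldargs={"I": "O"}` the model reproduces the shape of the SISO4 wiring printed earlier, and with `scanargs={"I": "O"}` it reproduces the shape of the SIPO4 wiring.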

