This next example shows how to build stateful hardware with registers and more complex hardware structures with functions. We build a 3-bit ripple-carry adder on top of the 1-bit adder from the prior example, then connect it to a register so that it counts up modulo 8.
```python
import pyrtl

# Start from an empty working block so hardware built in earlier cells is discarded
pyrtl.reset_working_block()
```
A function in PyRTL is nothing special: it is an ordinary Python function, and the statements it encapsulates just happen to tell PyRTL to build some hardware.
```python
def one_bit_add(a, b, carry_in):
    assert len(a) == len(b) == 1  # len returns the bitwidth
    sum = a ^ b ^ carry_in
    # carry_out is 1 whenever at least two of the three inputs are 1
    carry_out = a & b | a & carry_in | b & carry_in
    return sum, carry_out
```
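Because sum and carry_out are plain Boolean expressions, they can be sanity-checked outside PyRTL. The sketch below is a plain-Python reference model (the name `full_add_model` is hypothetical, not part of PyRTL) that evaluates the same two expressions on 0/1 integers and compares them with ordinary integer addition for all eight input combinations. The carry_out expression is the three-input majority function.

```python
from itertools import product


def full_add_model(a, b, carry_in):
    """Hypothetical plain-Python model of one_bit_add on 0/1 integers."""
    sum_bit = a ^ b ^ carry_in
    # Majority function: 1 when at least two of the three inputs are 1
    carry_out = a & b | a & carry_in | b & carry_in
    return sum_bit, carry_out


# Exhaustive check: a + b + carry_in must equal 2 * carry_out + sum_bit
for a, b, c in product((0, 1), repeat=3):
    s, co = full_add_model(a, b, c)
    assert a + b + c == 2 * co + s, (a, b, c)
print("full adder equations hold for all 8 input combinations")
```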
If we call one_bit_add with the arguments x, y, and z, it builds a one-bit adder for those values and returns the sum and carry_out wires of that adder. If we call it again on i, j, and k, it builds a second, separate one-bit adder for those inputs and returns that adder's sum and carry_out.
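To see why each call produces a separate adder, it can help to model the idea without PyRTL. In the sketch below, `Netlist`, `Wire`, and `one_bit_add_toy` are hypothetical stand-ins invented for illustration, not a description of how PyRTL is implemented internally: the overloaded operators append a new gate to a shared netlist instead of computing a value, so every call to the builder function appends a fresh set of gates.

```python
class Netlist:
    """Hypothetical toy container that records every gate created."""

    def __init__(self):
        self.gates = []  # tuples of (op, input_name, input_name, output_name)

    def new_gate(self, op, x, y):
        out = Wire(self, f"t{len(self.gates)}")
        self.gates.append((op, x.name, y.name, out.name))
        return out


class Wire:
    """Hypothetical toy wire: operators build gates instead of computing values."""

    def __init__(self, netlist, name):
        self.netlist = netlist
        self.name = name

    def __xor__(self, other):
        return self.netlist.new_gate("^", self, other)

    def __and__(self, other):
        return self.netlist.new_gate("&", self, other)

    def __or__(self, other):
        return self.netlist.new_gate("|", self, other)


def one_bit_add_toy(a, b, carry_in):
    # Same expressions as one_bit_add; each operator application records one gate
    sum_bit = a ^ b ^ carry_in
    carry_out = a & b | a & carry_in | b & carry_in
    return sum_bit, carry_out


net = Netlist()
x, y, z = (Wire(net, n) for n in "xyz")
i, j, k = (Wire(net, n) for n in "ijk")

one_bit_add_toy(x, y, z)
gates_after_first = len(net.gates)
one_bit_add_toy(i, j, k)
print(gates_after_first, len(net.gates))
```

Running it prints `7 14`: the second call adds seven new gates (two XOR, three AND, two OR) rather than reusing the first adder.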
PyRTL actually provides a "+" operator for WireVectors that generates adders, but a ripple-carry adder is easy to understand while still having enough structure to be mildly interesting. Let's define an adder of arbitrary length recursively and (hopefully) Pythonically.
```python
def ripple_add(a, b, carry_in=0):
    # match_bitwidth makes multiple wires the same bitwidth;
    # by default it zero-extends the shorter ones
    a, b = pyrtl.match_bitwidth(a, b)
    if len(a) == 1:
        # Base case: a single full adder
        sumbits, carry_out = one_bit_add(a, b, carry_in)
    else:
        # Add the least-significant bits, then ripple the carry into the rest
        lsbit, ripplecarry = one_bit_add(a[0], b[0], carry_in)
        msbits, carry_out = ripple_add(a[1:], b[1:], ripplecarry)
        # concat puts its first argument in the most-significant position
        sumbits = pyrtl.concat(msbits, lsbit)
    return sumbits, carry_out
```
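The recursion adds bit 0 together with carry_in, then recursively adds the remaining upper bits using the resulting carry, and finally concatenates the results with the least-significant bit on the right. The sketch below mirrors that structure in plain Python on lists of 0/1 integers stored least-significant bit first. The helpers `to_bits`, `from_bits`, and `ripple_add_model` are hypothetical, and zero-extending the shorter operand imitates the default behavior of `pyrtl.match_bitwidth` described in the code comment.

```python
def to_bits(value, width):
    """LSB-first list of `width` bits (index 0 is the least-significant bit)."""
    return [(value >> i) & 1 for i in range(width)]


def from_bits(bits):
    """Inverse of to_bits: rebuild an integer from an LSB-first bit list."""
    return sum(bit << i for i, bit in enumerate(bits))


def ripple_add_model(a, b, carry_in=0):
    """Hypothetical plain-Python mirror of ripple_add on LSB-first bit lists."""
    # Zero-extend the shorter operand, like match_bitwidth's default behavior
    width = max(len(a), len(b))
    a = a + [0] * (width - len(a))
    b = b + [0] * (width - len(b))
    # One full adder for the least-significant bit
    lsbit = a[0] ^ b[0] ^ carry_in
    carry = a[0] & b[0] | a[0] & carry_in | b[0] & carry_in
    if width == 1:
        return [lsbit], carry
    # Recurse on the upper bits, rippling the carry into them
    msbits, carry_out = ripple_add_model(a[1:], b[1:], carry)
    # Lists are LSB first here, so concat(msbits, lsbit) becomes [lsbit] + msbits
    return [lsbit] + msbits, carry_out


sum_bits, carry_out = ripple_add_model(to_bits(5, 3), to_bits(1, 1))
print(from_bits(sum_bits), carry_out)  # 5 + 1 fits in 3 bits: 6 0
sum_bits, carry_out = ripple_add_model(to_bits(7, 3), to_bits(1, 1))
print(from_bits(sum_bits), carry_out)  # 7 + 1 overflows 3 bits: 0 1
```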
Now let's build a 3-bit counter from our N-bit ripple carry adder.
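```python
counter = pyrtl.Register(bitwidth=3, name='counter')
# Add 1 to the current count. carry_out is deliberately left unused,
# so the 3-bit sum wraps from 7 back to 0.
sum_bits, carry_out = ripple_add(counter, pyrtl.Const("1'b1"))
counter.next <<= sum_bits
```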
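Because counter.next is driven only by the 3-bit sum and carry_out is never connected, the register keeps just the low three bits of counter + 1. In integer terms the next state is (count + 1) mod 8, and the discarded carry_out is 1 only on the transition from 7 to 0. A quick plain-Python check of that arithmetic follows; `counter_next` and its `width` parameter are illustrative names, not PyRTL functions.

```python
def counter_next(count, width=3):
    """Hypothetical model: next state of a width-bit up-counter and its dropped carry."""
    total = count + 1
    sum_bits = total & ((1 << width) - 1)  # keep only the low `width` bits
    carry_out = total >> width             # the bit that counter.next never sees
    return sum_bits, carry_out


for count in range(8):
    nxt, co = counter_next(count)
    assert nxt == (count + 1) % 8
    print(count, "->", nxt, "carry_out =", co)
```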
Now let's run the bugger. It has no inputs, so we don't need to supply any, but let's throw in an assert to check that it really counts up modulo 8. Finally, we'll print the trace to the screen.
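```python
sim_trace = pyrtl.SimulationTrace()
sim = pyrtl.Simulation(tracer=sim_trace)
for cycle in range(15):
    # The circuit has no inputs, so each step gets an empty input dict
    sim.step({})
    # The counter's value during cycle n should be n mod 8
    assert sim.value[counter] == cycle % 8
sim_trace.render_trace()
```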
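The assert compares against cycle % 8 rather than (cycle + 1) % 8 because the value a register shows during a cycle is the value it held at the start of that cycle: the update from counter.next only becomes visible in the following cycle, and the assert implies the counter reads 0 in cycle 0. The plain-Python sketch below models that read-then-update ordering; `simulate_counter` is a hypothetical helper, not a PyRTL function.

```python
def simulate_counter(cycles, width=3):
    """Hypothetical model of the simulation loop: the counter value seen each cycle."""
    state = 0  # value in cycle 0, as the assert in the loop above expects
    trace = []
    for _ in range(cycles):
        trace.append(state)                       # value visible during this cycle
        state = (state + 1) & ((1 << width) - 1)  # counter.next takes effect next cycle
    return trace


print(simulate_counter(15))
# [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6]
```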