Getting the hang of WASM
Almar Klein
EuroScipy 2017
Erlangen, Germany
In [ ]:
import wasmfun
In [ ]:
# Instructions for the main function: push the constant 42, call print_ln
# (which is declared below as taking one f64 argument), and then call
# make_background_blue.
instructions = [('f64.const', 42),
('call', 'print_ln'),
('call', 'make_background_blue')]
In [ ]:
# Assemble the module: a $main function (no params, no return values, no
# locals) holding the instructions above, plus the two functions it calls.
# Both are declared as imports; 'js' presumably names the import namespace that
# the host has to provide -- verify against the wasmfun documentation.
m = wasmfun.Module(
wasmfun.Function('$main', params=[], returns=[], locals=[], instructions=instructions),
wasmfun.ImportedFuncion('print_ln', ['f64'], [], 'js', 'print_ln'),
wasmfun.ImportedFuncion('make_background_blue', [], [], 'js', 'make_background_blue'),
)
In [ ]:
m.show()
In [ ]:
print(m.to_bytes())
print(len(m.to_bytes()))
In [ ]:
# JavaScript glue code, kept as a Python string. It defines:
# - print_ln(x): appends String(x) followed by '<br>' to the innerHTML of the
#   element with id 'wasm_output'.
# - make_background_blue(): assigns 'background:#48f;' to document.body.style.
# - compile_my_wasm(wasm_data): creates a WebAssembly.Module from the data and
#   instantiates it, passing the two functions above as the 'js' imports.
JS = """
function print_ln(x) {
var el = document.getElementById('wasm_output');
el.innerHTML += String(x).replace('\\n', '<br>') + '<br>';
}
function make_background_blue () {
document.body.style = 'background:#48f;'
}
function compile_my_wasm(wasm_data) {
var m = new WebAssembly.Module(wasm_data);
var i = new WebAssembly.Instance(m, {js: {print_ln, make_background_blue}});
}
"""
In [ ]:
from IPython.display import display, HTML, Javascript
import uuid
def run_wasm(m):
    """ Show an output div in the notebook and run the given module in it.

    A fresh uuid is used to give the div a unique element id. The JS glue code
    is pointed at that id, extended with a call to compile_my_wasm() on the
    bytes of the module, and then displayed (i.e. executed) as JavaScript.

    Parameters:
        m: object with a ``to_bytes()`` method that produces the binary module.
    """
    unique = uuid.uuid1().hex
    element_id = 'wasm_output_' + unique
    # Make print_ln() write into the div that is created below
    script = JS.replace('wasm_output', element_id)
    byte_values = str(list(m.to_bytes()))
    script += "compile_my_wasm(new Uint8Array(%s));" % byte_values
    placeholder = HTML("<div style='border: 2px solid blue;' id='wasm_output_%s'>WASM output goes here<br></div>" % unique)
    display(placeholder)
    display(Javascript(script))
In [ ]:
run_wasm(m)
In [ ]:
# A counting loop: print local 0, add 1 to it, and branch back to the start of
# the loop for as long as the new value is smaller than 10.
instructions = [
('loop', 'emptyblock'),
# write iter
('get_local', 0), ('call', 'print_ln'),
# Increase iter
('f64.const', 1), ('get_local', 0), ('f64.add'),
# tee_local stores the sum in local 0 and keeps it on the stack for the test
('tee_local', 0), ('f64.const', 10),
('f64.lt'), ('br_if', 0),
('end'),
]
In [ ]:
m = wasmfun.Module(
wasmfun.Function('$main', params=[], returns=[], locals=['f64'], instructions=instructions),
wasmfun.ImportedFuncion('print_ln', ['f64'], [], 'js', 'print_ln'),
)
In [ ]:
m.to_bytes()
In [ ]:
wasmfun.run_wasm_in_notebook(m)
In [ ]:
wasmfun.run_wasm_in_node(m)
Brainfuck is an esoteric language that consists of only eight simple commands:
> increment the data pointer
< decrement the data pointer
+ increment (increase by one) the byte at the data pointer
- decrement (decrease by one) the byte at the data pointer
. output the byte at the data pointer.
, accept one byte of input
[ jump forward to the matching ] if the byte at the data pointer is zero
] jump backward to the matching [ if the byte at the data pointer is nonzero
Brainfuck is a simple language, but that does not mean that programming in Brainfuck is easy!
In [ ]:
def bf2instructions(commands):
    """ Compile brainfuck commands to WASM instructions (as tuples).

    Local 0 is used as the data pointer, and the cell it points at is read and
    written with byte-sized loads and stores. Output is emitted as a call to
    function 0. Input is not supported: ',' stores zero in the current cell.
    Characters that are not brainfuck commands are ignored. An unmatched ']'
    ends the compilation; the commands that follow it are not compiled.

    Parameters:
        commands: an iterable of one-character strings, e.g. a str or a list.
            It is only read, a list that is passed in is left unchanged.

    Returns:
        list: the instructions for the given commands.
    """
    pointer = ('get_local', 0)
    cell_is_zero = [pointer, ('i32.load8_u', 0, 0), ('i32.const', 0)]

    # Commands that translate to a fixed sequence of instructions
    simple = {
        '>': [pointer, ('i32.const', 1), ('i32.add'), ('set_local', 0)],
        '<': [pointer, ('i32.const', 1), ('i32.sub'), ('set_local', 0)],
        '+': [pointer, pointer,  # once for the read, once for the write
              ('i32.load8_u', 0, 0),
              ('i32.const', 1), ('i32.add'), ('i32.store8', 0, 0)],
        '-': [pointer, pointer,  # once for the read, once for the write
              ('i32.load8_u', 0, 0),
              ('i32.const', 1), ('i32.sub'), ('i32.store8', 0, 0)],
        '.': [pointer, ('i32.load8_u', 0, 0), ('call', 0)],
        # We don't support input, just set to zero
        ',': [pointer, ('i32.const', 0), ('i32.store8', 0, 0)],
    }

    def compile_until_close(stream):
        # All (nested) calls consume the same iterator, so that when a call
        # returns on ']', its caller continues right after that bracket. This
        # avoids list.pop(0), which costs O(n) for each command.
        instructions = []
        for c in stream:
            if c == '[':
                instructions += [('block', 'emptyblock')]
                # if current data point == 0 goto end of block
                instructions += cell_is_zero + [('i32.eq'), ('br_if', 0)]
                instructions += [('loop', 'emptyblock')]
                instructions += compile_until_close(stream)
                # if current data point != 0 goto start of loop
                instructions += cell_is_zero + [('i32.ne'), ('br_if', 0)]
                instructions += [('end'), ('end')]
            elif c == ']':
                break
            else:
                instructions += simple.get(c, [])  # ignore anything else
        return instructions

    return compile_until_close(iter(commands))
In [ ]:
BF_HELLO = """
[This program prints "Hello World!" and a newline to the screen]
++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.
>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++.
"""
In [ ]:
instructions = bf2instructions(list(BF_HELLO))
In [ ]:
# Module for the compiled brainfuck program. print_charcode is listed first;
# the compiled code emits ('call', 0) for output, which presumably refers to
# this first function by index -- verify against wasmfun. The single i32 local
# is the data pointer that the compiled code accesses as local 0.
m = wasmfun.Module(
wasmfun.ImportedFuncion('print_charcode', ['i32'], [], 'js', 'print_charcode'),
wasmfun.Function('$main', [], [], ['i32'], instructions),
# NOTE(review): (1, 1) is presumably the (initial, maximum) memory size in
# pages -- confirm in the wasmfun documentation.
wasmfun.MemorySection((1, 1)),
wasmfun.DataSection(),
)
In [ ]:
wasmfun.run_wasm_in_notebook(m)
In [ ]:
BF_FIBONACCI = """
[Generate the fibonacci number sequence, (for numbers under 100). Taken from
http://esoteric.sange.fi/brainfuck/bf-source/prog/fibonacci.txt
]
+++++++++++>+>>>>++++++++++++++++++++++++++++++++++++++++++++>
++++++++++++++++++++++++++++++++<<<<<<[>[>>>>>>+>+<<<<<<<-]>>>>>>>
[<<<<<<<+>>>>>>>-]<[>++++++++++[-<-[>>+>+<<<-]>>>[<<<+>>>-]+<[>[-]
<[-]]>[<<[>>>+<<<-]>>[-]]<<]>>>[>>+>+<<<-]>>>[<<<+>>>-]+<[>[-]<[-]]>
[<<+>>[-]]<<<<<<<]>>>>>[++++++++++++++++++++++++++++++++++++++++++++++++.
[-]]++++++++++<[->-<]>++++++++++++++++++++++++++++++++++++++++++++++++.[-]
<<<<<<<<<<<<[>>>+>+<<<<-]>>>>[<<<<+>>>>-]<-[>>.>.<<<[-]]<<[>>+>+<<<-]>>>
[<<<+>>>-]<<[<+>-]>[<+>-]<<<-]
"""
In [ ]:
m = wasmfun.Module(
wasmfun.ImportedFuncion('print_charcode', ['i32'], [], 'js', 'print_charcode'),
wasmfun.Function('$main', [], [], ['i32'], bf2instructions(list(BF_FIBONACCI))),
wasmfun.MemorySection((1, 1)),
wasmfun.DataSection(),
)
In [ ]:
wasmfun.run_wasm_in_notebook(m)
In [ ]:
FIB_CODE = """
a = 0
b = 1
for i in range(10):
print(a)
c = b
b = a + b
a = c
"""
In [ ]:
exec(FIB_CODE)
In [ ]:
import ast
tree = ast.parse(FIB_CODE)
tree.body
In [ ]:
# Binary operators that translate to a single f64 instruction
_BINOP_OPCODES = {
    ast.Add: 'f64.add',
    ast.Sub: 'f64.sub',
    ast.Mult: 'f64.mul',
    ast.Div: 'f64.div',
}

# Comparison operators and the f64 instruction that implements each
_COMPARE_OPCODES = {
    ast.Eq: 'f64.eq',
    ast.NotEq: 'f64.ne',
    ast.Gt: 'f64.gt',
    ast.Lt: 'f64.lt',
    ast.GtE: 'f64.ge',
    ast.LtE: 'f64.le',
}


def _compile_binop(node, ctx, push_stack):
    """ Compile a binary operation; both operands are pushed first. """
    emit = ctx.instructions.append
    _compile_expr(node.left, ctx, True)
    _compile_expr(node.right, ctx, True)
    op_type = type(node.op)
    if op_type in _BINOP_OPCODES:
        emit(_BINOP_OPCODES[op_type])
    elif op_type is ast.Mod:
        # left - right * floor(left / right)
        # todo: this is fragile. E.g. for negative numbers
        _compile_expr(node.left, ctx, True)  # push again
        _compile_expr(node.right, ctx, True)
        emit('f64.div')
        emit('f64.floor')
        emit('f64.mul')  # consumes last right
        emit('f64.sub')  # consumes last left
    elif op_type is ast.FloorDiv:
        emit('f64.div')
        emit('f64.floor')  # not trunc
    else:
        raise SyntaxError('Unsuppored binary op: %s' % node.op.__class__.__name__)
    if not push_stack:
        emit('drop')


def _compile_for(node, ctx):
    """ Compile a ``for <name> in range(...)`` loop. """
    range_call = node.iter
    # Check whether this is the kind of simple for-loop that we support
    if not (isinstance(range_call, ast.Call) and
            isinstance(range_call.func, ast.Name) and
            range_call.func.id == 'range'):
        raise SyntaxError('For-loops are limited to range().')
    if node.orelse:
        raise SyntaxError('For-loops do not support orelse.')
    if not isinstance(node.target, ast.Name):
        raise SyntaxError('For-loops support just one iterable.')
    emit = ctx.instructions.append
    # Prepare start, stop, step
    start_stub = ctx.new_stub()
    end_stub = ctx.new_stub()
    step_stub = ctx.new_stub()
    args = range_call.args
    if not 1 <= len(args) <= 3:
        raise SyntaxError('range() should have 1, 2, or 3 args')
    if len(args) == 1:
        emit(('f64.const', 0))  # implicit start
    for arg in args:
        _compile_expr(arg, ctx, True)
    if len(args) < 3:
        emit(('f64.const', 1))  # implicit step
    emit(('set_local', step_stub))  # reversed order, pop from stack
    emit(('set_local', end_stub))
    emit(('set_local', start_stub))
    # Body. The target is advanced at the *start* of each iteration, so that a
    # `continue` (a branch to the start of the loop) moves on to the next
    # value. Advancing at the end of the body would be skipped by `continue`
    # and loop forever. To compensate, the target starts at start - step
    # (exact for the integer values that range() is meant for).
    target = ctx.name_idx(node.target.id)
    ctx.push_block('for')
    ctx.instructions.extend([
        ('get_local', start_stub), ('get_local', step_stub), 'f64.sub',
        ('set_local', target),  # Init target
        ('block', 'emptyblock'), ('loop', 'emptyblock'),  # enter loop
        ('get_local', target), ('get_local', step_stub), 'f64.add',
        ('tee_local', target),  # next iter
        ('get_local', end_stub), 'f64.ge', ('br_if', 1),  # break (level 2)
    ])
    for subnode in node.body:
        _compile_expr(subnode, ctx, False)
    ctx.instructions.extend([
        ('br', 0),  # loop
        'end', 'end',  # end of loop and outer block
    ])
    ctx.pop_block('for')


def _compile_while(node, ctx):
    """ Compile a while-loop; the condition is tested before each iteration. """
    if node.orelse:
        raise SyntaxError('While-loops do not support orelse.')
    ctx.push_block('while')
    ctx.instructions.extend([
        ('block', 'emptyblock'), ('loop', 'emptyblock'),  # outer block for break
    ])
    # Leave the loop when the condition is false. Testing up front means that
    # the body is skipped when the condition is false from the start, and that
    # a `continue` (a branch to the start of the loop) re-evaluates it.
    _compile_expr(node.test, ctx, True)
    ctx.instructions.extend([
        ('i32.const', 0), 'i32.eq',  # invert the condition
        ('br_if', 1),
    ])
    for subnode in node.body:
        _compile_expr(subnode, ctx, False)
    ctx.instructions.extend([
        ('br', 0),  # loop
        'end', 'end',  # end of loop and outer block
    ])
    ctx.pop_block('while')


def _compile_call(node, ctx, push_stack):
    """ Compile a call to one of the two supported functions. """
    if not isinstance(node.func, ast.Name):
        raise SyntaxError('Only support simple function names')
    if node.keywords:
        raise SyntaxError('No support for keyword args')
    name = node.func.id
    if name == 'print':
        assert len(node.args) == 1, 'print() accepts exactly one argument'
        if push_stack:
            raise SyntaxError('print() does not produce a value')
        _compile_expr(node.args[0], ctx, True)
        ctx.instructions.append(('call', 0))
    elif name == 'perf_counter':
        assert len(node.args) == 0, 'perf_counter() accepts exactly zero arguments'
        ctx.instructions.append(('call', 1))
        if not push_stack:
            ctx.instructions.append('drop')
    else:
        raise SyntaxError('Not a supported function: %s' % name)


def _compile_expr(node, ctx, push_stack):
    """ Compile an AST node to WASM instructions, which are added to ctx.

    All values are treated as f64. Supported are: assignment to a single name,
    names, numbers, unary minus, the binary operators + - * / % //, single
    comparisons, if/else, for-loops over range(), while-loops, continue,
    break, and calls to print() and perf_counter().

    Parameters:
        node: the ast node (statement or expression) to compile.
        ctx: the compile context. It must provide ``instructions`` (a list),
            ``name_idx()``, ``new_stub()``, ``push_block()``, ``pop_block()``
            and ``get_block_level()``.
        push_stack (bool): whether the value of the node must be left on the
            stack. Statements are compiled with False.

    Raises:
        SyntaxError: if the node uses syntax that is not supported.
    """
    emit = ctx.instructions.append

    if isinstance(node, ast.Expr):
        _compile_expr(node.value, ctx, push_stack)

    elif isinstance(node, ast.Assign):
        if not (len(node.targets) == 1 and isinstance(node.targets[0], ast.Name)):
            # The line number goes into the message: a non-tuple second
            # argument makes SyntaxError itself fail with a TypeError.
            raise SyntaxError('Unsupported assignment at line %i' % node.lineno)
        idx = ctx.name_idx(node.targets[0].id)
        _compile_expr(node.value, ctx, True)
        emit(('set_local', idx))
        assert not push_stack

    elif isinstance(node, ast.Name):
        assert push_stack
        emit(('get_local', ctx.name_idx(node.id)))

    elif isinstance(node, ast.Constant) or node.__class__.__name__ == 'Num':
        # Literals are ast.Constant since Python 3.8. The deprecated ast.Num is
        # matched by name, so that the deprecated attribute is never accessed.
        value = node.value if isinstance(node, ast.Constant) else node.n
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise SyntaxError('Unsupported constant: %r' % (value, ))
        emit(('f64.const', value))
        if not push_stack:
            emit('drop')

    elif isinstance(node, ast.UnaryOp):
        _compile_expr(node.operand, ctx, True)
        if isinstance(node.op, ast.USub):
            emit('f64.neg')
        else:
            raise SyntaxError('Unsupported unary operator: %s' % node.op.__class__.__name__)
        if not push_stack:
            emit('drop')

    elif isinstance(node, ast.BinOp):
        _compile_binop(node, ctx, push_stack)

    elif isinstance(node, ast.Compare):
        if len(node.ops) != 1:
            raise SyntaxError('Only supports binary comparators (one operand).')
        _compile_expr(node.left, ctx, True)
        _compile_expr(node.comparators[0], ctx, True)
        op = node.ops[0]
        if type(op) not in _COMPARE_OPCODES:
            raise SyntaxError('Unsupported operand: %s' % op)
        emit(_COMPARE_OPCODES[type(op)])
        if not push_stack:
            emit('drop')

    elif isinstance(node, ast.If):
        _compile_expr(node.test, ctx, True)
        assert not push_stack  # Python is not an expression lang
        ctx.push_block('if')
        emit(('if', 'emptyblock'))
        for e in node.body:
            _compile_expr(e, ctx, False)
        if node.orelse:
            emit(('else', ))
            for e in node.orelse:
                _compile_expr(e, ctx, False)
        emit(('end', ))
        ctx.pop_block('if')

    elif isinstance(node, ast.For):
        _compile_for(node, ctx)

    elif isinstance(node, ast.While):
        _compile_while(node, ctx)

    elif isinstance(node, (ast.Continue, ast.Break)):
        level = ctx.get_block_level()
        if level is None:
            raise SyntaxError('%s outside of a loop' % node.__class__.__name__)
        # `level` is the label of the loop itself (its start), the block that
        # wraps the loop is one level further out.
        emit(('br', level if isinstance(node, ast.Continue) else level + 1))

    elif isinstance(node, ast.Call):
        _compile_call(node, ctx, push_stack)

    else:
        raise SyntaxError('Unsupported syntax: %s' % node.__class__.__name__)
class Context:
    """ The state that is built up while compiling a piece of code.

    Attributes:
        instructions: list of the instructions that were generated so far.
        names: dict that maps each variable name to the index of its local.
    """

    def __init__(self):
        self.instructions = []
        self.names = {}
        self._name_counter = 0  # index that the next new name will get
        self._block_stack = []  # kinds of the blocks that are currently open

    def name_idx(self, name):
        """ Get the index of the local for the given name, registering the
        name if it was not seen before.
        """
        if name not in self.names:
            self.names[name] = self._name_counter
            self._name_counter += 1
        return self.names[name]

    def new_stub(self):
        """ Register a new local for internal use and return its index.
        """
        # The '$' cannot occur in a Python identifier. Without it, a variable
        # that happens to be called e.g. "stub3" could be handed out as a stub.
        name = '$stub' + str(self._name_counter)
        return self.name_idx(name)

    def push_block(self, kind):
        """ Register that a block of the given kind ('if', 'for' or 'while')
        is entered.
        """
        assert kind in ('if', 'for', 'while')
        self._block_stack.append(kind)

    def pop_block(self, kind):
        """ Register that the innermost block, which must be of the given
        kind, is left.
        """
        # Pop outside of the assert: asserts are removed when Python runs with
        # -O, and then the block would never be popped.
        popped = self._block_stack.pop(-1)
        assert popped == kind

    def get_block_level(self):
        """ Get the number of open blocks between the current position and the
        innermost for- or while-loop (0 if a loop is the innermost block).
        Returns None when not inside a loop.
        """
        for i, kind in enumerate(reversed(self._block_stack)):
            if kind in ('for', 'while'):
                return i
        return None
In [ ]:
def py2wasm(python_code):
    """ Compile a string of Python code to a wasmfun module.

    Each top-level statement of the code is compiled in turn. The resulting
    instructions become the body of a single $main function that has one f64
    local for every name that was registered while compiling.
    """
    statements = ast.parse(python_code).body
    context = Context()
    for statement in statements:
        _compile_expr(statement, context, False)

    local_types = ['f64'] * len(context.names)
    main = wasmfun.Function('$main', [], [], local_types, context.instructions)
    # The compiled code calls these as function 0 and 1, so presumably the
    # order of the two imports matters -- verify against wasmfun.
    print_ln = wasmfun.ImportedFuncion('print_ln', ['f64'], [], 'js', 'print_ln')
    perf_counter = wasmfun.ImportedFuncion('perf_counter', [], ['f64'], 'js', 'perf_counter')
    return wasmfun.Module(main, print_ln, perf_counter)
In [ ]:
m = py2wasm("""
a = 0
b = 1
for i in range(10):
print(a)
c = b
b = a + b
a = c
""")
In [ ]:
len(m.to_bytes())
In [ ]:
wasmfun.run_wasm_in_notebook(m)
In [ ]:
PRIMES_CODE = """
max = 4000
n = 0
i = -1
t0 = perf_counter()
while n < max:
i = i + 1
if i <= 1:
continue # nope
elif i == 2:
n = n + 1
else:
gotit = 1
for j in range(2, i//2 + 1):
if i % j == 0:
gotit = 0
break
if gotit == 1:
n = n + 1
print(perf_counter() - t0)
print(i)
"""
In [ ]:
from time import perf_counter
exec(PRIMES_CODE)
In [ ]:
wasmfun.run_wasm_in_notebook(py2wasm(PRIMES_CODE))
In [ ]:
wasmfun.run_wasm_in_node(py2wasm(PRIMES_CODE))