Getting the hang of WASM



Almar Klein
EuroScipy 2017
Erlangen, Germany

This talk ...

  • Is me fooling around with Web Assembly
  • Is a Jupyter notebook (using RISE for presentation mode)
  • Code is executed live
  • Makes use of a small library called wasmfun (https://github.com/almarklein/wasmfun)
  • pip install wasmfun

In [ ]:
import wasmfun

What is Web Assembly?

Web Assembly is a low-level representation of code

Designed to be:

  • Easy to convert to machine code
  • Fast

In [ ]:
instructions = [('f64.const', 42),
                ('call', 'print_ln'),
                ('call', 'make_background_blue')]
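
WASM code runs on a stack machine: each instruction pops its operands from a stack and pushes its result. To get a feel for what the tuples above mean, here is a minimal sketch of an interpreter for a tiny subset of such instructions. It is not part of wasmfun; `run_stack_code` and the `imports` mapping are hypothetical names, with Python callables standing in for functions imported from the host.

```python
def run_stack_code(instructions, imports):
    """Execute a tiny subset of WASM-like instruction tuples (illustration only)."""
    stack = []
    for instr in instructions:
        opcode, *args = instr
        if opcode == 'f64.const':
            stack.append(float(args[0]))
        elif opcode == 'f64.add':
            b, a = stack.pop(), stack.pop()
            stack.append(a + b)
        elif opcode == 'call':
            # imports maps a name to (callable, number of parameters)
            func, n_params = imports[args[0]]
            call_args = [stack.pop() for _ in range(n_params)][::-1]
            result = func(*call_args)
            if result is not None:
                stack.append(result)
        else:
            raise ValueError('Unsupported opcode: %s' % opcode)
    return stack

printed = []  # stands in for the browser's output element
imports = {'print_ln': (printed.append, 1)}
leftover = run_stack_code(
    [('f64.const', 40), ('f64.const', 2), ('f64.add',), ('call', 'print_ln')],
    imports)
print(printed, leftover)
```

After the call, the stack is empty again: `print_ln` consumed the one value that was pushed.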

Instructions are packed into functions ...

... which are packed into modules


In [ ]:
m = wasmfun.Module(    
    wasmfun.Function('$main', params=[], returns=[], locals=[], instructions=instructions),    
    wasmfun.ImportedFuncion('print_ln', ['f64'], [], 'js', 'print_ln'),
    wasmfun.ImportedFuncion('make_background_blue', [], [], 'js', 'make_background_blue'),
    )

In [ ]:
m.show()

Web Assembly modules have a compact binary format


In [ ]:
print(m.to_bytes())
print(len(m.to_bytes()))
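
Part of the compactness comes from how the binary format stores integers: as variable-length LEB128, so small numbers take a single byte. Every binary module also starts with the same 8-byte header (the magic bytes `\0asm` followed by version 1). The sketch below illustrates both; `encode_uleb128` and `WASM_HEADER` are hypothetical helper names, not wasmfun API.

```python
import struct

def encode_uleb128(value):
    """Encode a non-negative int as unsigned LEB128 bytes."""
    if value < 0:
        raise ValueError('unsigned LEB128 needs a non-negative value')
    out = bytearray()
    while True:
        byte = value & 0x7F       # take the lowest 7 bits
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow: set the high bit
        else:
            out.append(byte)
            return bytes(out)

# Magic number plus little-endian 32-bit version field
WASM_HEADER = b'\x00asm' + struct.pack('<I', 1)

print(encode_uleb128(42), encode_uleb128(128), WASM_HEADER)
```

If wasmfun emits a standard module, `m.to_bytes()[:8]` should equal `WASM_HEADER`.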

Web Assembly is safe

  • Functionality is imported from host environment
    • Access to DOM when in the browser
    • Access to file system etc. when on desktop

In [ ]:
JS = """
function print_ln(x) {
    var el = document.getElementById('wasm_output');
    el.innerHTML += String(x).replace('\\n', '<br>') + '<br>';
}

function make_background_blue () {
    document.body.style = 'background:#48f;'
}

function compile_my_wasm(wasm_data) {
    var m = new WebAssembly.Module(wasm_data);
    var i = new WebAssembly.Instance(m, {js: {print_ln, make_background_blue}});
}
"""

In [ ]:
from IPython.display import display, HTML, Javascript
import uuid

def run_wasm(m):
    id = uuid.uuid1().hex
    js = JS.replace('wasm_output', 'wasm_output_' + id)
    js += "compile_my_wasm(new Uint8Array(%s));" % str(list(m.to_bytes()))
    
    display(HTML("<div style='border: 2px solid blue;' id='wasm_output_%s'>WASM output goes here<br></div>" % id))
    display(Javascript(js))

Let's run our module in the browser!


In [ ]:
run_wasm(m)

Again, now with a for-loop


In [ ]:
instructions = [
    ('loop', 'emptyblock'),
        # write iter
        ('get_local', 0), ('call', 'print_ln'),
        # Increase iter
        ('f64.const', 1), ('get_local', 0), ('f64.add'),
        ('tee_local', 0), ('f64.const', 10),
        ('f64.lt'), ('br_if', 0),
    ('end'),
    ]
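
For readers less used to stack code, the loop above does roughly what the following Python does. This is only a sketch to aid reading; `wasm_loop_equivalent` is a hypothetical name, and the returned list stands in for the calls to `print_ln`.

```python
def wasm_loop_equivalent():
    i = 0.0   # local 0, an f64 that starts at zero
    out = []
    while True:
        out.append(i)        # get_local 0; call print_ln
        i = 1.0 + i          # f64.const 1; get_local 0; f64.add; tee_local 0
        if not (i < 10.0):   # f64.const 10; f64.lt; br_if 0 (jump back while true)
            break
    return out

print(wasm_loop_equivalent())
```

Note that the body runs once before the condition is checked, like a do-while loop.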

In [ ]:
m = wasmfun.Module(
    wasmfun.Function('$main', params=[], returns=[], locals=['f64'], instructions=instructions),
    wasmfun.ImportedFuncion('print_ln', ['f64'], [], 'js', 'print_ln'),    
    )

In [ ]:
m.to_bytes()

In [ ]:
wasmfun.run_wasm_in_notebook(m)

Web Assembly runs outside the browser too!


In [ ]:
wasmfun.run_wasm_in_node(m)

Before moving on ...

  • Web Assembly is an open standard
  • All major browser vendors are on board

Brainfuck

Let's compile Brainfuck to WASM ...

Brainfuck is an esoteric language that consists of only eight simple commands:

>    increment the data pointer
<    decrement the data pointer
+    increment (increase by one) the byte at the data pointer
-    decrement (decrease by one) the byte at the data pointer
.    output the byte at the data pointer.
,    accept one byte of input
[    jump forward to the matching ] if the byte at the data pointer is zero
]    jump backward to the matching [ if the byte at the data pointer is nonzero

Brainfuck is a simple language, but that does not mean that programming in Brainfuck is easy!
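
Before compiling, it can help to have a plain interpreter as a reference for what a program should print. The following is a minimal sketch that assumes 8-bit wrapping cells and a 30000-cell tape; `run_bf` is a hypothetical helper, and like the compiler in this talk it treats input as zero.

```python
def run_bf(source, tape_size=30000):
    """Interpret Brainfuck source and return its output as a string."""
    code = [c for c in source if c in '><+-.,[]']
    # Pre-compute the positions of matching brackets
    jumps, stack = {}, []
    for pos, c in enumerate(code):
        if c == '[':
            stack.append(pos)
        elif c == ']':
            start = stack.pop()
            jumps[start], jumps[pos] = pos, start
    tape = bytearray(tape_size)
    ptr = pc = 0
    out = []
    while pc < len(code):
        c = code[pc]
        if c == '>':
            ptr += 1
        elif c == '<':
            ptr -= 1
        elif c == '+':
            tape[ptr] = (tape[ptr] + 1) % 256
        elif c == '-':
            tape[ptr] = (tape[ptr] - 1) % 256
        elif c == '.':
            out.append(chr(tape[ptr]))
        elif c == ',':
            tape[ptr] = 0  # no input support
        elif c == '[' and tape[ptr] == 0:
            pc = jumps[pc]  # skip the loop body
        elif c == ']' and tape[ptr] != 0:
            pc = jumps[pc]  # jump back to the loop start
        pc += 1
    return ''.join(out)

# 8 * 8 + 1 = 65, the character code of 'A'
print(run_bf('++++++++[>++++++++<-]>+.'))
```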


In [ ]:
def bf2instructions(commands):
    """ Compile brainfuck commands to WASM instructions (as tuples).
    """    
    instructions = []
    while commands:
        c = commands.pop(0)
        if c == '>':
            instructions += [('get_local', 0), ('i32.const', 1), ('i32.add'), ('set_local', 0)]
        elif c == '<':
            instructions += [('get_local', 0), ('i32.const', 1), ('i32.sub'), ('set_local', 0)]
        elif c == '+':
            instructions += [('get_local', 0), ('get_local', 0),  # once for the read, once for the write
                             ('i32.load8_u', 0, 0),
                             ('i32.const', 1), ('i32.add'), ('i32.store8', 0, 0)]
        elif c == '-':
            instructions += [('get_local', 0), ('get_local', 0),  # once for the read, once for the write
                             ('i32.load8_u', 0, 0),
                             ('i32.const', 1), ('i32.sub'), ('i32.store8', 0, 0)]
        elif c == '.':
            instructions += [('get_local', 0), ('i32.load8_u', 0, 0), ('call', 0)]
        elif c == ',':
            # We don't support input, just set to zero
            instructions += [('get_local', 0), ('i32.const', 0), ('i32.store8', 0, 0)]
        elif c == '[':
            instructions += [('block', 'emptyblock'),
                                # if the current data cell == 0, jump to the end of the block
                                ('get_local', 0), ('i32.load8_u', 0, 0), ('i32.const', 0), ('i32.eq'), ('br_if', 0),
                                ('loop', 'emptyblock'),
                                    ] + bf2instructions(commands) + [
                                    # if the current data cell != 0, jump back to the start of the loop
                                    ('get_local', 0), ('i32.load8_u', 0, 0), ('i32.const', 0), ('i32.ne'), ('br_if', 0),
                                ('end'),
                             ('end')]
        elif c == ']':
            break
        else:
            pass  # ignore
    
    return instructions

In [ ]:
BF_HELLO = """
[This program prints "Hello World!" and a newline to the screen]
++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.
>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++.
"""

In [ ]:
instructions = bf2instructions(list(BF_HELLO))

In [ ]:
m = wasmfun.Module(
    wasmfun.ImportedFuncion('print_charcode', ['i32'], [], 'js', 'print_charcode'),
    wasmfun.Function('$main', [], [], ['i32'], instructions),
    wasmfun.MemorySection((1, 1)),
    wasmfun.DataSection(),
    )

In [ ]:
wasmfun.run_wasm_in_notebook(m)

In [ ]:
BF_FIBONACCI = """
[Generate the fibonacci number sequence, (for numbers under 100). Taken from
http://esoteric.sange.fi/brainfuck/bf-source/prog/fibonacci.txt
]
+++++++++++>+>>>>++++++++++++++++++++++++++++++++++++++++++++>
++++++++++++++++++++++++++++++++<<<<<<[>[>>>>>>+>+<<<<<<<-]>>>>>>>
[<<<<<<<+>>>>>>>-]<[>++++++++++[-<-[>>+>+<<<-]>>>[<<<+>>>-]+<[>[-]
<[-]]>[<<[>>>+<<<-]>>[-]]<<]>>>[>>+>+<<<-]>>>[<<<+>>>-]+<[>[-]<[-]]>
[<<+>>[-]]<<<<<<<]>>>>>[++++++++++++++++++++++++++++++++++++++++++++++++.
[-]]++++++++++<[->-<]>++++++++++++++++++++++++++++++++++++++++++++++++.[-]
<<<<<<<<<<<<[>>>+>+<<<<-]>>>>[<<<<+>>>>-]<-[>>.>.<<<[-]]<<[>>+>+<<<-]>>>
[<<<+>>>-]<<[<+>-]>[<+>-]<<<-]
"""

In [ ]:
m = wasmfun.Module(
    wasmfun.ImportedFuncion('print_charcode', ['i32'], [], 'js', 'print_charcode'),
    wasmfun.Function('$main', [], [], ['i32'], bf2instructions(list(BF_FIBONACCI))),
    wasmfun.MemorySection((1, 1)),
    wasmfun.DataSection(),
    )

In [ ]:
wasmfun.run_wasm_in_notebook(m)

Python

Let's compile Python to WASM ...

Writing a full Python-to-WASM compiler would be a lot of work, and would not necessarily result in a faster interpreter than CPython. But we can implement a subset of Python in which all values are floats.

Example code - Fibonacci sequence


In [ ]:
FIB_CODE = """
a = 0
b = 1
for i in range(10):
    print(a)
    c = b
    b = a + b
    a = c
"""

In [ ]:
exec(FIB_CODE)

From code to AST

Python's built-in ast module can parse Python code into an Abstract Syntax Tree


In [ ]:
import ast
tree = ast.parse(FIB_CODE)
tree.body
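
To see what such a tree looks like, it helps to inspect a single statement. The sketch below parses one assignment and collects the node types the compiler has to handle; the variable names are only for illustration.

```python
import ast

tree = ast.parse("b = a + b")
stmt = tree.body[0]  # the only statement in the module

node_summary = (
    type(stmt).__name__,           # kind of statement
    stmt.targets[0].id,            # name being assigned to
    type(stmt.value).__name__,     # kind of expression on the right
    type(stmt.value.op).__name__,  # operator of that expression
)
print(node_summary)
```

So an assignment is an `Assign` node whose value is a `BinOp` with an `Add` operator and two `Name` operands.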

From AST to WASM

In a nutshell: walk over the tree and generate WASM instructions.
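
A stripped-down illustration of this idea, for arithmetic expressions only, might look as follows. It is a sketch rather than the compiler used in this talk: `compile_arith` and `BINOPS` are hypothetical names, and it assumes Python 3.8 or later, where number literals are parsed as `ast.Constant`.

```python
import ast

# Map AST operator classes to the corresponding f64 instructions
BINOPS = {ast.Add: 'f64.add', ast.Sub: 'f64.sub',
          ast.Mult: 'f64.mul', ast.Div: 'f64.div'}

def compile_arith(node, out):
    """Append stack instructions for an arithmetic expression to out."""
    if isinstance(node, ast.Expression):
        compile_arith(node.body, out)
    elif isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        out.append(('f64.const', float(node.value)))
    elif isinstance(node, ast.BinOp) and type(node.op) in BINOPS:
        # Operands first, then the operator: post-order traversal
        compile_arith(node.left, out)
        compile_arith(node.right, out)
        out.append((BINOPS[type(node.op)],))
    else:
        raise SyntaxError('Unsupported syntax: %s' % node.__class__.__name__)
    return out

code = compile_arith(ast.parse('1 + 2 * 3', mode='eval'), [])
print(code)
```

Because operands are emitted before their operator, operator precedence is already resolved by the tree: the multiplication ends up before the addition.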


In [ ]:
def _compile_expr(node, ctx, push_stack):
    
    if isinstance(node, ast.Expr):
        _compile_expr(node.value, ctx, push_stack)
    
    elif isinstance(node, ast.Assign):
        if not (len(node.targets) == 1 and isinstance(node.targets[0], ast.Name)):
            raise SyntaxError('Unsupported assignment at line %i' % node.lineno)
        idx = ctx.name_idx(node.targets[0].id)
        _compile_expr(node.value, ctx, True)
        ctx.instructions.append(('set_local', idx))
        assert not push_stack
    
    elif isinstance(node, ast.Name):
        assert push_stack
        ctx.instructions.append(('get_local', ctx.name_idx(node.id)))
    
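    elif isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        # Number literal (ast.Constant replaces the deprecated ast.Num since Python 3.8)
        ctx.instructions.append(('f64.const', node.value))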
    
    elif isinstance(node, ast.UnaryOp):
        _compile_expr(node.operand, ctx, True)
        if isinstance(node.op, ast.USub):
            ctx.instructions.append(('f64.neg'))
        else:
            raise SyntaxError('Unsupported unary operator: %s' % node.op.__class__.__name__)
    
    elif isinstance(node, ast.BinOp):
        _compile_expr(node.left, ctx, True)
        _compile_expr(node.right, ctx, True)
        if isinstance(node.op, ast.Add):
            ctx.instructions.append(('f64.add'))
        elif isinstance(node.op, ast.Sub):
            ctx.instructions.append(('f64.sub'))
        elif isinstance(node.op, ast.Mult):
            ctx.instructions.append(('f64.mul'))
        elif isinstance(node.op, ast.Div):
            ctx.instructions.append(('f64.div'))
        elif isinstance(node.op, ast.Mod):
            # todo: this is fragile. E.g. for negative numbers
            _compile_expr(node.left, ctx, True)  # push again
            _compile_expr(node.right, ctx, True)
            ctx.instructions.append(('f64.div'))
            ctx.instructions.append(('f64.floor'))
            ctx.instructions.append(('f64.mul'))  # consumes last right
            ctx.instructions.append(('f64.sub'))  # consumes last left
        elif isinstance(node.op, ast.FloorDiv):
            ctx.instructions.append(('f64.div'))
            ctx.instructions.append(('f64.floor'))  # not trunc
        else:
            raise SyntaxError('Unsupported binary op: %s' % node.op.__class__.__name__)
        if not push_stack:
            ctx.instructions.append(('drop'))
    
    elif isinstance(node, ast.Compare):
        if len(node.ops) != 1:
            raise SyntaxError('Chained comparisons are not supported (only one comparison operator).')
        _compile_expr(node.left, ctx, True)
        _compile_expr(node.comparators[0], ctx, True)
        op = node.ops[0]
        if isinstance(op, ast.Eq):
            ctx.instructions.append(('f64.eq'))
        elif isinstance(op, ast.NotEq):
            ctx.instructions.append(('f64.ne'))
        elif isinstance(op, ast.Gt):
            ctx.instructions.append(('f64.gt'))
        elif isinstance(op, ast.Lt):
            ctx.instructions.append(('f64.lt'))
        elif isinstance(op, ast.GtE):
            ctx.instructions.append(('f64.ge'))
        elif isinstance(op, ast.LtE):
            ctx.instructions.append(('f64.le'))
        else:
            raise SyntaxError('Unsupported comparison operator: %s' % op.__class__.__name__)
    
    elif isinstance(node, ast.If):
        _compile_expr(node.test, ctx, True)
        assert not push_stack  # Python is not an expression lang
        ctx.push_block('if')
        ctx.instructions.append(('if', 'emptyblock'))
        for e in node.body:
            _compile_expr(e, ctx, False)
        if node.orelse:
            ctx.instructions.append(('else', ))
            for e in node.orelse:
                _compile_expr(e, ctx, False)
        ctx.instructions.append(('end', ))
        ctx.pop_block('if')
    
    elif isinstance(node, ast.For):
        # Check whether this is the kind of simple for-loop that we support
        if not (isinstance(node.iter, ast.Call) and node.iter.func.id == 'range'):
            raise SyntaxError('For-loops are limited to range().')
        if node.orelse:
            raise SyntaxError('For-loops do not support orelse.')
        if not isinstance(node.target, ast.Name):
            raise SyntaxError('For-loops support just one target variable.')
        # Prepare start, stop, step
        start_stub = ctx.new_stub()
        end_stub = ctx.new_stub()
        step_stub = ctx.new_stub()
        if len(node.iter.args) == 1:
            ctx.instructions.append(('f64.const', 0))
            _compile_expr(node.iter.args[0], ctx, True)
            ctx.instructions.append(('f64.const', 1))
        elif len(node.iter.args) == 2:
            _compile_expr(node.iter.args[0], ctx, True)
            _compile_expr(node.iter.args[1], ctx, True)
            ctx.instructions.append(('f64.const', 1))
        elif len(node.iter.args) == 3:
            _compile_expr(node.iter.args[0], ctx, True)
            _compile_expr(node.iter.args[1], ctx, True)
            _compile_expr(node.iter.args[2], ctx, True)
        else:
            raise SyntaxError('range() should have 1, 2, or 3 args')
        ctx.instructions.append(('set_local', step_stub))  # reversed order, pop from stack
        ctx.instructions.append(('set_local', end_stub))
        ctx.instructions.append(('set_local', start_stub))
        # Body
        target = ctx.name_idx(node.target.id)
        ctx.push_block('for')
        for i in [('get_local', start_stub), ('set_local', target), # Init target
                  ('block', 'emptyblock'), ('loop', 'emptyblock'),  # enter loop
                  ('get_local', target), ('get_local', end_stub), ('f64.ge'), ('br_if', 1),  # break (level 2)
                  ]:
            ctx.instructions.append(i)
        for subnode in node.body:
            _compile_expr(subnode, ctx, False)
        for i in [('get_local', target), ('get_local', step_stub), ('f64.add'), ('set_local', target),  # next iter
                  ('br', 0),  # loop
                  ('end'), ('end'),  # end of loop and outer block
                  ]:
            ctx.instructions.append(i)
        ctx.pop_block('for')
    
    elif isinstance(node, ast.While):
        # Check whether this is the kind of simple while-loop that we support
        if node.orelse:
            raise SyntaxError('While-loops do not support orelse.')
        # Body
        ctx.push_block('while')
        for i in [('block', 'emptyblock'), ('loop', 'emptyblock'),  # enter loop (outer block for break)
                  ]:
            ctx.instructions.append(i)
        _compile_expr(node.test, ctx, True)  # test before each iteration, as Python does
        for i in [('i32.eqz'), ('br_if', 1),  # leave the outer block when the test fails
                  ]:
            ctx.instructions.append(i)
        for subnode in node.body:
            _compile_expr(subnode, ctx, False)
        for i in [('br', 0),  # loop
                  ('end'), ('end'),  # end of loop and outer block
                  ]:
            ctx.instructions.append(i)
        ctx.pop_block('while')
    
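    elif isinstance(node, ast.Continue):
        # Branch to the start of the enclosing loop. Note: in a for-loop this skips the increment of the target.
        ctx.instructions.append(('br', ctx.get_block_level()))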
    
    elif isinstance(node, ast.Break):
        ctx.instructions.append(('br', ctx.get_block_level() + 1))
    
    elif isinstance(node, ast.Call):
        if not isinstance(node.func, ast.Name):
            raise SyntaxError('Only support simple function names')
        if node.keywords:
            raise SyntaxError('No support for keyword args')
        name = node.func.id
        if name == 'print':
            assert len(node.args) == 1, 'print() accepts exactly one argument'
            _compile_expr(node.args[0], ctx, True)
            ctx.instructions.append(('call', 0))
        elif name == 'perf_counter':
            assert len(node.args) == 0, 'perf_counter() accepts exactly zero arguments'
            ctx.instructions.append(('call', 1))
        else:
            raise SyntaxError('Not a supported function: %s' % name)
    else:
        raise SyntaxError('Unsupported syntax: %s' % node.__class__.__name__)

class Context:
    
    def __init__(self):
        self.instructions = []
        self.names = {}
        self._name_counter = 0
        self._block_stack = []
    
    def name_idx(self, name):
        if name not in self.names:
            self.names[name] = self._name_counter
            self._name_counter += 1
        return self.names[name]
    
    def new_stub(self):
        name = 'stub' + str(self._name_counter)
        return self.name_idx(name)
    
    def push_block(self, kind):
        assert kind in ('if', 'for', 'while')
        self._block_stack.append(kind)
    
    def pop_block(self, kind):
        assert self._block_stack.pop(-1) == kind
    
    def get_block_level(self):
        for i, kind in enumerate(reversed(self._block_stack)):
            if kind in ('for', 'while'):
                return i
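
The `%` branch in `_compile_expr` above has no single WASM instruction to rely on, so it computes the remainder as a - floor(a / b) * b. A quick check in plain Python suggests that this matches Python's own `%` for simple cases, including negative operands; `floored_mod` is a hypothetical name, and floating point rounding may still cause differences for other inputs.

```python
import math

def floored_mod(a, b):
    """Remainder computed the way the generated WASM does it."""
    return a - math.floor(a / b) * b

# Compare against Python's % for a few operand pairs
cases = [(7.0, 3.0), (-7.0, 3.0), (7.0, -3.0), (5.5, 2.0)]
results = [(floored_mod(a, b), a % b) for a, b in cases]
print(results)
```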

In [ ]:
def py2wasm(python_code):
    
    # Convert to AST
    tree = ast.parse(python_code)
    
    # Compile to instructions
    ctx = Context()
    for node in tree.body:
        _compile_expr(node, ctx, False)
    
    # Produce wasm module
    return wasmfun.Module(
        wasmfun.Function('$main', [], [], ['f64' for i in ctx.names], ctx.instructions),
        wasmfun.ImportedFuncion('print_ln', ['f64'], [], 'js', 'print_ln'),
        wasmfun.ImportedFuncion('perf_counter', [], ['f64'], 'js', 'perf_counter'),        
        )

In [ ]:
m = py2wasm("""
a = 0
b = 1
for i in range(10):
    print(a)
    c = b
    b = a + b
    a = c
""")

In [ ]:
len(m.to_bytes())

In [ ]:
wasmfun.run_wasm_in_notebook(m)

Prime numbers example


In [ ]:
PRIMES_CODE = """
max = 4000
n = 0
i = -1
t0 = perf_counter()

while n < max:
    i = i + 1
    
    if i <= 1:
        continue  # nope
    elif i == 2:
        n = n + 1
    else:
        gotit = 1
        for j in range(2, i//2 + 1):
            if i % j == 0:
                gotit = 0
                break
        if gotit == 1:
            n = n + 1

print(perf_counter() - t0)
print(i)
"""

Find primes in Python


In [ ]:
from time import perf_counter
exec(PRIMES_CODE)

Find primes in the browser


In [ ]:
wasmfun.run_wasm_in_notebook(py2wasm(PRIMES_CODE))

Find primes on desktop


In [ ]:
wasmfun.run_wasm_in_node(py2wasm(PRIMES_CODE))

Summarizing

  • Web Assembly is fast, compact, safe and open
  • Will give rise to exciting things
  • Can also have implications for Python

Thanks!

(Wanna learn more? Talk to me!)