Type Inference

Class definitions for the abstract syntax tree nodes which comprise the little language for which types will be inferred.

In [1]:
class Lambda(object):
    """Lambda abstraction: ``lambda v. body``.

    Attributes:
        v: Name of the single bound parameter.
        body: AST node forming the function body, in which ``v`` is in scope.
    """

    def __init__(self, v, body):
        self.v, self.body = v, body

    def __str__(self):
        template = "(lambda ({v}) {body})"
        return template.format(v=self.v, body=self.body)

class Identifier(object):
    """Identifier: a reference to a name.

    Its type is looked up by ``name`` in the type environment; names that are
    integer literals (e.g. "5") are typed as integers instead.
    """

    def __init__(self, name):
        # The raw identifier text, also used verbatim as the printed form.
        self.name = name

    def __str__(self):
        return self.name

class Apply(object):
    """Function application: apply ``fn`` to the single argument ``arg``.

    Multi-argument calls are written by nesting applications (currying),
    e.g. ``Apply(Apply(f, a), b)``; the ``Curry`` alias below exists for that.
    """

    def __init__(self, fn, arg):
        self.fn, self.arg = fn, arg

    def __str__(self):
        template = "(apply {fn} {arg})"
        return template.format(fn=self.fn, arg=self.arg)


# Alias so curried (multi-argument) applications read naturally in examples.
Curry = Apply

class Let(object):
    """Non-recursive let binding: ``let v = defn in body``.

    ``v`` is only in scope in ``body``, not inside ``defn``; use Letrec for
    recursive definitions.
    """

    def __init__(self, v, defn, body):
        self.v, self.defn, self.body = v, defn, body

    def __str__(self):
        template = "(let {v} = {defn} in {body})"
        return template.format(v=self.v, defn=self.defn, body=self.body)

class Letrec(object):
    """Recursive let binding: ``letrec v = defn in body``.

    Unlike Let, ``v`` is in scope inside ``defn`` as well as in ``body``,
    which is what allows recursive functions such as factorial.
    """

    def __init__(self, v, defn, body):
        self.v, self.defn, self.body = v, defn, body

    def __str__(self):
        template = "(letrec {v} = {defn} in {body})"
        return template.format(v=self.v, defn=self.defn, body=self.body)

Exception types

In [2]:
class InferenceError(Exception):
    """Signals that the type inference algorithm could not assign a type.

    The message is exposed through the read-only ``message`` property and is
    also what ``str(err)`` returns.
    """

    def __init__(self, message):
        self.__message = message

    @property
    def message(self):
        return self.__message

    def __str__(self):
        return str(self.message)

class ParseError(Exception):
    """Raised if the supplied type environment is incomplete.

    For example, this is raised when an identifier is not defined in the
    environment. The message is exposed through the read-only ``message``
    property and is also what ``str(err)`` returns.
    """

    def __init__(self, message):
        self.__message = message

    @property
    def message(self):
        return self.__message

    def __str__(self):
        return str(self.message)

Types and type constructors

In [3]:
class TypeVariable(object):
    """A type variable standing for an arbitrary type.

    All type variables have a unique id, but names are only assigned lazily,
    when required.

    Attributes:
        id: Unique integer taken from a class-wide counter.
        instance: The type this variable has been bound to (set by
            unification), or None while it is still unbound.
    """

    next_variable_id = 0

    def __init__(self):
        self.id = TypeVariable.next_variable_id
        TypeVariable.next_variable_id += 1
        self.instance = None
        self.__name = None

    next_variable_name = 'a'

    @property
    def name(self):
        """The variable's display name, allocated on first access.

        Names are allocated to TypeVariables lazily, so that only TypeVariables
        that are actually displayed consume names.
        """
        # NOTE(review): names advance by code point, so after 'z' they continue
        # with '{', '|', ... rather than multi-letter names.
        if self.__name is None:
            self.__name = TypeVariable.next_variable_name
            TypeVariable.next_variable_name = chr(ord(TypeVariable.next_variable_name) + 1)
        return self.__name

    def __str__(self):
        # A bound variable prints as whatever it has been bound to.
        if self.instance is not None:
            return str(self.instance)
        return self.name

    def __repr__(self):
        return "TypeVariable(id = {0})".format(self.id)

class TypeOperator(object):
    """An n-ary type constructor which builds a new type from old

    Attributes:
        name: The constructor's name, e.g. "int", "->" or "*".
        types: Sequence of argument types (empty for basic types).
    """

    def __init__(self, name, types):
        self.name = name
        self.types = types

    def __str__(self):
        num_types = len(self.types)
        if num_types == 0:
            # Nullary constructors (basic types) print as their bare name.
            return self.name
        if num_types == 2:
            # Binary constructors such as "->" and "*" print infix.
            return "({0} {1} {2})".format(str(self.types[0]), self.name, str(self.types[1]))
        # Every other arity prints prefix. The argument types are type objects,
        # not strings, so each one is converted before joining.
        return "{0} {1}".format(self.name, ' '.join(str(t) for t in self.types))

class Function(TypeOperator):
    """A binary type constructor which builds function types

    ``Function(a, b)`` is the type ``a -> b``: a TypeOperator named "->"
    whose two argument types are the parameter type and the result type.
    """

    def __init__(self, from_type, to_type):
        argument_types = [from_type, to_type]
        super().__init__("->", argument_types)

In [4]:
# Basic types are nullary type constructors: operators with no argument types.
Integer = TypeOperator(name="int", types=[])  # Basic integer
Bool = TypeOperator(name="bool", types=[])  # Basic bool

Type inference machinery

In [5]:
def analyse(node, type_map, non_generic=None):
    """Computes the type of the expression given by node.

    The type of the node is computed in the context of the supplied type
    environment type_map. Data types can be introduced into the language
    simply by having a predefined set of identifiers in the initial
    environment; this way there is no need to change the syntax or, more
    importantly, the type-checking program when extending the language.

    Args:
        node: The root of the abstract syntax tree.
        type_map: The type environment is a mapping of expression identifier names
            to type assignments.
        non_generic: A set of non-generic variables, or None

    Returns:
        The computed type of the expression.

    Raises:
        InferenceError: The type of the expression could not be inferred, for example
            if it is not possible to unify two types such as Integer and Bool
        ParseError: The abstract syntax tree rooted at node could not be parsed,
            e.g. it refers to an undefined symbol or contains an unknown node kind
    """
    if non_generic is None:
        non_generic = set()

    if isinstance(node, Identifier):
        return get_type(node.name, type_map, non_generic)
    elif isinstance(node, (Apply, Curry)):
        # The function must have type (arg_type -> result_type) for some new result_type.
        fun_type = analyse(node.fn, type_map, non_generic)
        arg_type = analyse(node.arg, type_map, non_generic)
        result_type = TypeVariable()
        unify(Function(arg_type, result_type), fun_type)
        return result_type
    elif isinstance(node, Lambda):
        arg_type = TypeVariable()
        new_type_map = type_map.copy()
        new_type_map[node.v] = arg_type
        # A lambda-bound variable must have one type throughout the body, so it
        # is non-generic there: its uses are not freshened by get_type.
        new_non_generic = non_generic.copy()
        new_non_generic.add(arg_type)
        result_type = analyse(node.body, new_type_map, new_non_generic)
        return Function(arg_type, result_type)
    elif isinstance(node, Let):
        # Variables of defn_type that are not non-generic stay generic, so each
        # use of node.v in the body gets a fresh copy (let-polymorphism).
        defn_type = analyse(node.defn, type_map, non_generic)
        new_type_map = type_map.copy()
        new_type_map[node.v] = defn_type
        return analyse(node.body, new_type_map, non_generic)
    elif isinstance(node, Letrec):
        new_type = TypeVariable()
        new_type_map = type_map.copy()
        new_type_map[node.v] = new_type
        # Recursive uses inside the definition must all share one type, so the
        # binding is non-generic while the definition itself is analysed.
        new_non_generic = non_generic.copy()
        new_non_generic.add(new_type)
        defn_type = analyse(node.defn, new_type_map, new_non_generic)
        unify(new_type, defn_type)
        return analyse(node.body, new_type_map, non_generic)
    # An explicit raise (unlike assert) still fires when Python runs with -O.
    raise ParseError("Unhandled syntax node {0}".format(type(node)))

In [6]:
def get_type(name, type_map, non_generic):
    """Get the type of identifier name from the type environment type_map.

    Args:
        name: The identifier name
        type_map: The type environment mapping from identifier names to types
        non_generic: A set of non-generic TypeVariables

    Returns:
        A fresh copy of the bound type (generic variables replaced by new
        ones), or Integer if name is an integer literal.

    Raises:
        ParseError: Raised if name is an undefined symbol in the type
            environment.
    """
    if name in type_map:
        return fresh(type_map[name], non_generic)
    if is_integer_literal(name):
        return Integer
    raise ParseError("Undefined symbol {0}".format(name))

def fresh(t, non_generic):
    """Makes a copy of a type expression.

    The type t is copied. The generic variables are duplicated and the
    non_generic variables are shared.

    Args:
        t: A type to be copied.
        non_generic: A set of non-generic TypeVariables

    Returns:
        A copy of t in which every generic TypeVariable is replaced by a new
        TypeVariable. Repeated occurrences of one variable map to the same copy.
    """
    mappings = {}  # A mapping of TypeVariables to TypeVariables

    def freshrec(tp):
        p = find(tp)
        if isinstance(p, TypeVariable):
            if is_generic(p, non_generic):
                if p not in mappings:
                    mappings[p] = TypeVariable()
                return mappings[p]
            # Non-generic variables are shared, not copied.
            return p
        elif isinstance(p, TypeOperator):
            return TypeOperator(p.name, [freshrec(x) for x in p.types])

    return freshrec(t)
unify(ta, tb):
  ta = find(ta)
  tb = find(tb)
  if both ta,tb are terms of the form D p1..pn with identical D,n then
    unify(ta[i],tb[i]) for each corresponding ith parameter
  else if at least one of ta,tb is a type variable then
    bind that variable to the other term (failing if it occurs in that term)
  else
    error 'types do not match'

In [7]:
def unify(t1, t2):
    """Unify the two types t1 and t2.

    Makes the types t1 and t2 the same.

    Args:
        t1: The first type to be made equivalent
        t2: The second type to be made equivalent

    Returns:
        None. Unification works by binding TypeVariable.instance in place.

    Raises:
        InferenceError: Raised if the types cannot be unified.
    """
    a = find(t1)
    b = find(t2)
    if isinstance(a, TypeVariable):
        if a != b:
            # Binding a to a type that contains a would create an infinite type.
            if occurs_in_type(a, b):
                raise InferenceError("recursive unification")
            a.instance = b
    elif isinstance(a, TypeOperator) and isinstance(b, TypeVariable):
        unify(b, a)
    elif isinstance(a, TypeOperator) and isinstance(b, TypeOperator):
        if a.name != b.name or len(a.types) != len(b.types):
            raise InferenceError("Type mismatch: {0} != {1}".format(str(a), str(b)))
        for p, q in zip(a.types, b.types):
            unify(p, q)
    else:
        raise InferenceError('types do not match')

In [8]:
def find(t):
    """Returns the currently defining instance of t.

    As a side effect, collapses the list of type instances. The function find
    is used whenever a type expression has to be inspected: it will always
    return a type expression which is either an uninstantiated type variable or
    a type operator; i.e. it will skip instantiated variables, and will
    actually remove them from expressions to shorten long chains of
    instantiated variables.

    Args:
        t: The type to be found

    Returns:
        An uninstantiated TypeVariable or a TypeOperator
    """
    if isinstance(t, TypeVariable):
        if t.instance is not None:
            # Point t straight at the end of its chain (path compression) so
            # later lookups skip the intermediate variables.
            t.instance = find(t.instance)
            return t.instance
    return t

def is_generic(v, non_generic):
    """Checks whether a given variable occurs in a list of non-generic variables

    Note that a variable in such a list may be instantiated to a type term,
    in which case the variables contained in the type term are considered
    non-generic.

    Note: Must be called with v pre-found

    Args:
        v: The TypeVariable to be tested for genericity
        non_generic: A set of non-generic TypeVariables

    Returns:
        True if v is a generic variable, otherwise False
    """
    return not occurs_in(v, non_generic)

def occurs_in_type(v, type2):
    """Checks whether a type variable occurs in a type expression.

    Note: Must be called with v pre-found

    Args:
        v:  The TypeVariable to be tested for
        type2: The type in which to search

    Returns:
        True if v occurs in type2, otherwise False
    """
    found_type2 = find(type2)
    if found_type2 == v:
        return True
    elif isinstance(found_type2, TypeOperator):
        # Search the operator's argument types recursively.
        return occurs_in(v, found_type2.types)
    return False

def occurs_in(t, types):
    """Checks whether a type variable occurs in any other types.

    Args:
        t:  The TypeVariable to be tested for
        types: The sequence of types in which to search

    Returns:
        True if t occurs in any of types, otherwise False
    """
    return any(occurs_in_type(t, t2) for t2 in types)

def is_integer_literal(name):
    """Checks whether name is an integer literal string.

    Args:
        name: The identifier to check

    Returns:
        True if name is an integer literal, otherwise False
    """
    # NOTE(review): int() also accepts forms like " 5", "+5" and "1_000";
    # confirm that is acceptable for identifiers in this language.
    try:
        int(name)
    except ValueError:
        return False
    return True

In [9]:
# ==================================================================#
# Example code to exercise the above

def try_analyse(node, type_map):
    """Try to infer the type of node, printing the result or reporting errors.

    Prints "<expression> :  <type>" on success. On failure it prints the
    error message in place of the type.

    Args:
        node: The root node of the abstract syntax tree of the expression.
        type_map: The type environment in which to evaluate the expression.
    """
    print(str(node) + " : ", end=' ')
    try:
        t = analyse(node, type_map)
        print(str(t))
    except (ParseError, InferenceError) as e:
        # Expected failures (ill-typed or undefined) are shown as the result.
        print(e)

In [10]:
var1 = TypeVariable()
var2 = TypeVariable()
var3 = TypeVariable()

# The product type var1 * var2 built by "pair".
pair_type = TypeOperator("*", (var1, var2))

# Initial type environment: the language's "built-ins" are just predefined
# identifiers with polymorphic or basic types.
type_map = {
    "pair": Function(var1, Function(var2, pair_type)),
    "#t": Bool,
    "if": Function(Bool, Function(var3, Function(var3, var3))),
    "=": Function(Integer, Bool),  # unary test, used as "zero n" in the factorial example
    "pred": Function(Integer, Integer),
    "*": Function(Integer, Function(Integer, Integer))
}

In [11]:
import metakernel; metakernel.register_ipython_magics()

In [12]:

; NOTE(review): this cell is Scheme; presumably it needs the metakernel
; Scheme cell magic as its first line to run — confirm.
(define pred
 (lambda (n) (- n 1)))

(letrec ((factorial 
          (lambda (n)
            (if (zero? n) 
                1                              ; base case
                (* n (factorial (pred n)))))))
 (factorial 5))


In [13]:
# factorial
# letrec factorial = fn n => if (zero n) 1 (n * factorial (pred n)) in factorial 5
e1 = Letrec("factorial",                                  # letrec factorial =
            Lambda("n",                                   #   fn n =>
                   Curry(
                       Curry(                             #     if (zero n) 1
                           Apply(Identifier("if"),        #       if (zero n)
                                 Apply(Identifier("="), Identifier("n"))),
                           Identifier("1")),
                       Curry(                             #     (* n) (factorial (pred n))
                           Apply(Identifier("*"), Identifier("n")),
                           Apply(Identifier("factorial"),
                                 Apply(Identifier("pred"), Identifier("n")))))),
            Apply(Identifier("factorial"), Identifier("5")))  # in factorial 5

In [14]:
try_analyse(node=e1, type_map=type_map)

(letrec factorial = (lambda (n) (apply (apply (apply if (apply = n)) 1) (apply (apply * n) (apply factorial (apply pred n))))) in (apply factorial 5)) :  int

In [15]:
# Should fail: a lambda-bound x cannot be used at both int and bool.
# fn x => pair (x 3) (x true)
fail = Lambda("x",
              Apply(
                  Apply(Identifier("pair"),
                        Apply(Identifier("x"), Identifier("3"))),
                  Apply(Identifier("x"), Identifier("#t"))))

In [16]:
try_analyse(node=fail, type_map=type_map)

(lambda (x) (apply (apply pair (apply x 3)) (apply x #t))) :  Type mismatch: bool != int

In [17]:
# pair (f 4) (f true) -- f is not bound in the environment, so this
# is expected to report an undefined symbol.
e3 = Apply(Apply(Identifier("pair"),
                 Apply(Identifier("f"), Identifier("4"))),
           Apply(Identifier("f"), Identifier("#t")))

In [18]:
try_analyse(node=e3, type_map=type_map)

(apply (apply pair (apply f 4)) (apply f #t)) :  Undefined symbol f

In [19]:
# pair (f 4) (f true) -- used as the body of the let in e4
pair = Apply(Apply(Identifier("pair"),
                   Apply(Identifier("f"), Identifier("4"))),
             Apply(Identifier("f"), Identifier("#t")))

In [20]:
# let f = (fn x => x) in ((pair (f 4)) (f true))
# A let-bound f is generic, so it may be used at both int and bool.
e4 = Let(v="f", defn=Lambda(v="x", body=Identifier("x")), body=pair)

In [21]:
try_analyse(node=e4, type_map=type_map)

(let f = (lambda (x) x) in (apply (apply pair (apply f 4)) (apply f #t))) :  (int * bool)

In [22]:
# fn f => f f  (should fail: self-application needs an infinite type)
e5 = Lambda(v="f", body=Apply(fn=Identifier("f"), arg=Identifier("f")))

In [23]:
try_analyse(node=e5, type_map=type_map)

(lambda (f) (apply f f)) :  recursive unification

In [24]:
# let g = fn f => 5 in g g
# Succeeds because the let-bound g is generic, unlike the lambda-bound f in e5.
e6 = Let(v="g",
         defn=Lambda(v="f", body=Identifier("5")),
         body=Apply(fn=Identifier("g"), arg=Identifier("g")))

In [25]:
try_analyse(node=e6, type_map=type_map)

(let g = (lambda (f) 5) in (apply g g)) :  int

In [30]:
# example that demonstrates generic and non-generic variables:
# fn g => let f = fn x => g in pair (f 3, f true)
# f is generic in x, but g (lambda-bound) is non-generic, so both
# components of the pair share g's type.
e7 = Lambda("g",
            Let("f",
                Lambda("x", Identifier("g")),
                Apply(
                    Apply(Identifier("pair"),
                          Apply(Identifier("f"), Identifier("3"))),
                    Apply(Identifier("f"), Identifier("#t")))))

In [31]:
try_analyse(node=e7, type_map=type_map)

(lambda (g) (let f = (lambda (x) g) in (apply (apply pair (apply f 3)) (apply f #t)))) :  (e -> (e * e))

In [28]:
# Function composition:
# fn f => fn g => fn arg => g (f arg)
e8 = Lambda(v="f",
            body=Lambda(v="g",
                        body=Lambda(v="arg",
                                    body=Apply(fn=Identifier("g"),
                                               arg=Apply(fn=Identifier("f"),
                                                         arg=Identifier("arg"))))))

In [29]:
try_analyse(node=e8, type_map=type_map)

(lambda (f) (lambda (g) (lambda (arg) (apply g (apply f arg))))) :  ((b -> c) -> ((c -> d) -> (b -> d)))

In [ ]: