Class definitions for the abstract syntax tree nodes which comprise the little language for which types will be inferred.
In [1]:
class Lambda(object):
    """
    AST node for a lambda abstraction (a one-argument anonymous function).

    Attributes:
        v: The name of the bound parameter.
        body: The AST node (or printable value) forming the function body.
    """

    def __init__(self, v, body):
        self.v = v
        self.body = body

    def __str__(self):
        # Rendered in a Lisp-like prefix form: (lambda (<param>) <body>)
        template = "(lambda ({0}) {1})"
        return template.format(self.v, self.body)
class Identifier(object):
    """
    AST node for a reference to a named symbol.

    Attributes:
        name: The symbol's name; it is returned unchanged by ``str()``.
    """

    def __init__(self, name):
        self.name = name

    def __str__(self):
        # An identifier prints as its bare name, with no decoration.
        symbol = self.name
        return symbol
class Apply(object):
    """
    AST node for applying a function to a single argument.

    Attributes:
        fn: The AST node (or printable value) in function position.
        arg: The AST node (or printable value) in argument position.
    """

    def __init__(self, fn, arg):
        self.fn = fn
        self.arg = arg

    def __str__(self):
        # Rendered in a Lisp-like prefix form: (apply <fn> <arg>)
        template = "(apply {0} {1})"
        return template.format(self.fn, self.arg)


# Curry is a second name for the very same class; it is used in the examples
# to mark applications that supply a further argument to a curried function.
Curry = Apply
class Let(object):
    """
    AST node for a (non-recursive) let binding.

    Attributes:
        v: The name being bound.
        defn: The AST node (or printable value) defining the bound name.
        body: The AST node (or printable value) in which the name is visible.
    """

    def __init__(self, v, defn, body):
        self.v = v
        self.defn = defn
        self.body = body

    def __str__(self):
        # Rendered as: (let <name> = <definition> in <body>)
        template = "(let {0} = {1} in {2})"
        return template.format(self.v, self.defn, self.body)
class Letrec(object):
    """
    AST node for a recursive let binding.

    Attributes:
        v: The name being bound.
        defn: The AST node (or printable value) defining the bound name.
        body: The AST node (or printable value) evaluated with the binding.
    """

    def __init__(self, v, defn, body):
        self.v = v
        self.defn = defn
        self.body = body

    def __str__(self):
        # Rendered as: (letrec <name> = <definition> in <body>)
        template = "(letrec {0} = {1} in {2})"
        return template.format(self.v, self.defn, self.body)
Exception types
In [2]:
class InferenceError(Exception):
    """
    Raised if the type inference algorithm cannot infer types successfully.

    Args:
        message: A description of why inference failed. It is exposed through
            the read-only ``message`` property and returned by ``str()``.
    """

    def __init__(self, message):
        # Forward the message to Exception so that ``args``, ``repr()`` and
        # pickling all carry it; without this ``args`` would be empty.
        super(InferenceError, self).__init__(message)
        self.__message = message

    @property
    def message(self):
        """The message supplied when the error was constructed (read-only)."""
        return self.__message

    def __str__(self):
        return str(self.message)
class ParseError(Exception):
    """
    Raised if the type environment supplied for analysis is incomplete.

    Args:
        message: A description of the problem. It is exposed through the
            read-only ``message`` property and returned by ``str()``.
    """

    def __init__(self, message):
        # Forward the message to Exception so that ``args``, ``repr()`` and
        # pickling all carry it; without this ``args`` would be empty.
        super(ParseError, self).__init__(message)
        self.__message = message

    @property
    def message(self):
        """The message supplied when the error was constructed (read-only)."""
        return self.__message

    def __str__(self):
        return str(self.message)
Types and type constructors
In [3]:
class TypeVariable(object):
    """
    A type variable standing for an arbitrary type.

    All type variables have a unique id, but names are only assigned lazily,
    when required.

    Attributes:
        id: A number unique to this variable, taken from the class-wide
            counter ``next_variable_id``.
        instance: The type this variable has been bound to, or None while the
            variable is still unbound.
    """

    # Class-wide counters shared by every TypeVariable.
    next_variable_id = 0
    next_variable_name = 'a'

    def __init__(self):
        self.id = TypeVariable.next_variable_id
        TypeVariable.next_variable_id += 1
        self.instance = None
        self.__name = None

    @staticmethod
    def _successor(name):
        """
        Return the name that follows ``name``.

        The last character is advanced by one code point; a trailing 'z'
        wraps to 'a' and carries to the left, so the sequence runs
        a, b, ..., z, aa, ab, ... instead of running off the alphabet.
        """
        chars = list(name)
        for i in range(len(chars) - 1, -1, -1):
            if chars[i] != 'z':
                chars[i] = chr(ord(chars[i]) + 1)
                return ''.join(chars)
            chars[i] = 'a'
        # Every position was 'z' (or the name was empty): grow by one letter.
        return 'a' + ''.join(chars)

    @property
    def name(self):
        """
        The display name of this variable.

        Names are allocated to TypeVariables lazily, so that only
        TypeVariables whose name is actually requested consume a name.
        """
        if self.__name is None:
            self.__name = TypeVariable.next_variable_name
            TypeVariable.next_variable_name = TypeVariable._successor(
                TypeVariable.next_variable_name)
        return self.__name

    def __str__(self):
        # A bound variable prints as the type it is bound to.
        if self.instance is not None:
            return str(self.instance)
        else:
            return self.name

    def __repr__(self):
        return "TypeVariable(id = {0})".format(self.id)
class TypeOperator(object):
    """
    An n-ary type constructor which builds a new type from old.

    Attributes:
        name: The name of the constructor, e.g. "int" or "->".
        types: The sequence of argument types; empty for a basic type.
    """

    def __init__(self, name, types):
        self.name = name
        self.types = types

    def __str__(self):
        num_types = len(self.types)
        if num_types == 0:
            # Nullary constructor: a basic type such as "int".
            return self.name
        if num_types == 2:
            # Binary constructors are printed infix: (left name right).
            return "({0} {1} {2})".format(str(self.types[0]), self.name, str(self.types[1]))
        # Any other arity is printed prefix. Each argument must be converted
        # to a string first, because str.join rejects non-string items.
        return "{0} {1}".format(self.name, ' '.join(str(t) for t in self.types))
class Function(TypeOperator):
    """
    A binary type constructor which builds function types.

    The resulting operator is named "->" and holds the argument type followed
    by the result type.

    Args:
        from_type: The type of the function's argument.
        to_type: The type of the function's result.
    """

    def __init__(self, from_type, to_type):
        operands = [from_type, to_type]
        super().__init__("->", operands)
In [4]:
# Basic types are constructed with a nullary type constructor, i.e. a
# TypeOperator with an empty list of argument types. Each is created once
# here and referenced by name elsewhere in the file.
Integer = TypeOperator("int", []) # Basic integer; printed as "int"
Bool = TypeOperator("bool", []) # Basic bool; printed as "bool"
Type inference machinery
In [5]:
def analyse(node, type_map, non_generic=None):
    """
    Computes the type of the expression given by node.

    The type of the node is computed in the context of the supplied type
    environment type_map. Data types can be introduced into the language
    simply by having a predefined set of identifiers in the initial
    environment; this way there is no need to change the syntax or, more
    importantly, the type-checking program when extending the language.

    The supplied type_map and non_generic are never modified: whenever a
    binding has to be added, a copy is extended and passed down instead.

    Args:
        node: The root of the abstract syntax tree.
        type_map: The type environment is a mapping of expression identifier
            names to type assignments.
        non_generic: A set of non-generic variables, or None

    Returns:
        The computed type of the expression.

    Raises:
        InferenceError: The type of the expression could not be inferred, for
            example if it is not possible to unify two types such as Integer
            and Bool
        ParseError: The abstract syntax tree rooted at node could not be
            parsed, either because an identifier is undefined or because a
            node is not one of the known syntax node classes
    """
    if non_generic is None:
        non_generic = set()

    if isinstance(node, Identifier):
        return get_type(node.name, type_map, non_generic)
    elif isinstance(node, (Apply, Curry)):
        # The function must have type (arg_type -> result_type) for some
        # as-yet-unknown result_type, which unification will pin down.
        fun_type = analyse(node.fn, type_map, non_generic)
        arg_type = analyse(node.arg, type_map, non_generic)
        result_type = TypeVariable()
        unify(Function(arg_type, result_type), fun_type)
        return result_type
    elif isinstance(node, Lambda):
        # The parameter gets a new type variable which is non-generic while
        # the body is analysed.
        arg_type = TypeVariable()
        new_type_map = type_map.copy()
        new_type_map[node.v] = arg_type
        new_non_generic = non_generic.copy()
        new_non_generic.add(arg_type)
        result_type = analyse(node.body, new_type_map, new_non_generic)
        return Function(arg_type, result_type)
    elif isinstance(node, Let):
        # The definition is analysed first, in the outer environment, and
        # its type is then bound to the name for the body.
        defn_type = analyse(node.defn, type_map, non_generic)
        new_type_map = type_map.copy()
        new_type_map[node.v] = defn_type
        return analyse(node.body, new_type_map, non_generic)
    elif isinstance(node, Letrec):
        # The name is bound to a new type variable before its own definition
        # is analysed, so the definition may refer to itself. The variable is
        # non-generic only while the definition is analysed, not for the body.
        new_type = TypeVariable()
        new_type_map = type_map.copy()
        new_type_map[node.v] = new_type
        new_non_generic = non_generic.copy()
        new_non_generic.add(new_type)
        defn_type = analyse(node.defn, new_type_map, new_non_generic)
        unify(new_type, defn_type)
        return analyse(node.body, new_type_map, non_generic)
    # An explicit raise rather than an assert: asserts are removed when Python
    # runs with -O, which would let this function silently return None.
    raise ParseError("Unhandled syntax node {0}".format(type(node)))
In [6]:
def get_type(name, type_map, non_generic):
    """
    Look up the type of the identifier ``name``.

    A name present in type_map yields a fresh copy of its recorded type (see
    fresh). A name that is absent but reads as an integer literal has type
    Integer.

    Args:
        name: The identifier name
        type_map: The type environment mapping from identifier names to types
        non_generic: A set of non-generic TypeVariables

    Returns:
        The type associated with name.

    Raises:
        ParseError: Raised if name is an undefined symbol in the type
            environment.
    """
    # The environment takes precedence over the integer-literal fallback.
    if name in type_map:
        recorded = type_map[name]
        return fresh(recorded, non_generic)

    if is_integer_literal(name):
        return Integer

    raise ParseError("Undefined symbol {0}".format(name))
def fresh(t, non_generic):
    """
    Makes a copy of a type expression.

    The type t is copied. Generic variables are duplicated, each one being
    replaced consistently by the same new TypeVariable throughout the copy,
    and non-generic variables are shared with the original.

    Args:
        t: A type to be copied.
        non_generic: A set of non-generic TypeVariables

    Returns:
        The copied type expression.
    """
    # Generic TypeVariable -> its replacement, so repeated occurrences of one
    # variable map to one and the same new variable.
    replacements = {}

    def copy_type(tp):
        pruned = find(tp)
        if isinstance(pruned, TypeVariable):
            if not is_generic(pruned, non_generic):
                return pruned
            if pruned not in replacements:
                replacements[pruned] = TypeVariable()
            return replacements[pruned]
        if isinstance(pruned, TypeOperator):
            copied_args = [copy_type(arg) for arg in pruned.types]
            return TypeOperator(pruned.name, copied_args)
        # Neither a variable nor an operator: nothing to copy.
        return None

    return copy_type(t)
unify(ta,tb):
ta = find(ta)
tb = find(tb)
if both ta,tb are terms of the form D p1..pn with identical D,n then
unify(ta[i],tb[i]) for each corresponding ith parameter
else
if at least one of ta,tb is a type variable then
union(ta,tb)
else
error 'types do not match'
In [7]:
def unify(t1, t2):
    """
    Unify the two types t1 and t2.

    Makes the types t1 and t2 the same, by binding type variables.

    Args:
        t1: The first type to be made equivalent
        t2: The second type to be made equivalent

    Returns:
        None

    Raises:
        InferenceError: Raised if the types cannot be unified.
    """
    left = find(t1)
    right = find(t2)

    if isinstance(left, TypeVariable):
        if left != right:
            # Binding a variable to a type that contains it would create an
            # infinite type, so that is rejected.
            if occurs_in_type(left, right):
                raise InferenceError("recursive unification")
            left.instance = right
        return

    if not isinstance(left, TypeOperator):
        raise InferenceError('types do not match')

    if isinstance(right, TypeVariable):
        # Swap the arguments so that the variable case above handles it.
        unify(right, left)
        return

    if isinstance(right, TypeOperator):
        same_constructor = left.name == right.name
        same_arity = len(left.types) == len(right.types)
        if not (same_constructor and same_arity):
            raise InferenceError("Type mismatch: {0} != {1}".format(str(left), str(right)))
        # Same constructor and arity: unify the arguments pairwise.
        for left_arg, right_arg in zip(left.types, right.types):
            unify(left_arg, right_arg)
        return

    raise InferenceError('types do not match')
In [8]:
def find(t):
    """
    Returns the currently defining instance of t.

    As a side effect, collapses the list of type instances. The function find
    is used whenever a type expression has to be inspected: it will always
    return a type expression which is either an uninstantiated type variable
    or a type operator; i.e. it will skip instantiated variables, and will
    actually prune them from expressions to remove long chains of
    instantiated variables.

    Args:
        t: The type to be found

    Returns:
        An uninstantiated TypeVariable or a TypeOperator
    """
    # Anything that is not a bound variable already is its own representative.
    if not isinstance(t, TypeVariable) or t.instance is None:
        return t

    # Point the variable straight at the end of its chain.
    representative = find(t.instance)
    t.instance = representative
    return representative
def is_generic(v, non_generic):
    """
    Checks whether a given variable is generic, i.e. does not occur in the
    collection of non-generic variables.

    Note that a variable in that collection may be instantiated to a type
    term, in which case the variables contained in the type term are
    considered non-generic.

    Note: Must be called with v pre-found

    Args:
        v: The TypeVariable to be tested for genericity
        non_generic: A set of non-generic TypeVariables

    Returns:
        True if v is a generic variable, otherwise False
    """
    if occurs_in(v, non_generic):
        return False
    return True
def occurs_in_type(v, type2):
    """
    Checks whether a type variable occurs in a type expression.

    Note: Must be called with v pre-found

    Args:
        v: The TypeVariable to be tested for
        type2: The type in which to search

    Returns:
        True if v occurs in type2, otherwise False
    """
    candidate = find(type2)
    if candidate == v:
        return True
    if isinstance(candidate, TypeOperator):
        # Search the operator's argument types.
        return occurs_in(v, candidate.types)
    return False
def occurs_in(t, types):
    """
    Checks whether a type variable occurs in any of several types.

    Args:
        t: The TypeVariable to be tested for
        types: The sequence of types in which to search

    Returns:
        True if t occurs in any of types, otherwise False
    """
    # Stop at the first type that contains t.
    for candidate in types:
        if occurs_in_type(t, candidate):
            return True
    return False
def is_integer_literal(name):
    """
    Checks whether name is an integer literal string.

    The decision is delegated to int(): name is a literal exactly when int()
    accepts it without raising ValueError.

    Args:
        name: The identifier to check

    Returns:
        True if name is an integer literal, otherwise False
    """
    try:
        int(name)
    except ValueError:
        return False
    return True
In [9]:
# ==================================================================#
# Example code to exercise the above
def try_analyse(node, type_map):
    """
    Try to evaluate a type, printing the result or reporting errors.

    The expression is printed first, followed on the same line by either its
    inferred type or the message of the ParseError / InferenceError raised
    while analysing it.

    Args:
        node: The root node of the abstract syntax tree of the expression.
        type_map: The type environment in which to evaluate the expression.

    Returns:
        None
    """
    label = str(node) + " : "
    print(label, end=' ')
    try:
        rendered = str(analyse(node, type_map))
    except (ParseError, InferenceError) as error:
        print(error)
    else:
        print(rendered)
In [10]:
# Type variables used to give polymorphic types to the primitives below.
var1 = TypeVariable()
var2 = TypeVariable()
var3 = TypeVariable()
# The product (pair) type of var1 and var2; a binary operator, so it is
# printed infix as "(var1 * var2)".
pair_type = TypeOperator("*", (var1, var2))
# The initial type environment: every primitive is an ordinary identifier
# with a curried function type, so the language needs no special syntax.
type_map = {
# pair : var1 -> (var2 -> (var1 * var2))
"pair": Function(var1, Function(var2, pair_type)),
# #t : bool
"#t": Bool,
# if : bool -> (var3 -> (var3 -> var3)); both branches share one type
"if": Function(Bool, Function(var3, Function(var3, var3))),
# = : int -> bool; note that this is a one-argument predicate here
"=": Function(Integer, Bool),
# pred : int -> int
"pred": Function(Integer, Integer),
# * : int -> (int -> int)
"*": Function(Integer, Function(Integer, Integer))
}
In [11]:
import metakernel; metakernel.register_ipython_magics()
In [12]:
%%scheme
(define pred
(lambda (n) (- n 1)))
(letrec ((factorial
(lambda (n)
(if (zero? n)
1
(* n (factorial (pred n)))))))
(factorial 5))
Out[12]:
In [13]:
# Factorial, built from the curried "if", "=", "*" and "pred" primitives:
#   letrec factorial = fn n => if (= n) 1 (* n (factorial (pred n)))
#   in factorial 5
e1 = Letrec("factorial", # letrec factorial =
Lambda("n", # fn n =>
Curry( # ((if (= n)) 1) <else-branch>
Curry( # (if (= n)) 1
Apply(Identifier("if"), # if (= n)
Apply(Identifier("="), Identifier("n"))), # "=" is applied to n alone
Identifier("1")),
Curry( # (* n) (factorial (pred n))
Apply(Identifier("*"), Identifier("n")),
Apply(Identifier("factorial"),
Apply(Identifier("pred"), Identifier("n")))
)
)
), # in
Apply(Identifier("factorial"), Identifier("5"))
)
In [14]:
# Print e1 followed by its inferred type, or by the error message.
try_analyse(e1, type_map)
In [15]:
# Should fail: the lambda-bound x is applied to both an integer and a bool.
# fn x => ((pair (x 3)) (x true))
fail = Lambda("x",
Apply(
Apply(Identifier("pair"),
Apply(Identifier("x"), Identifier("3"))),
Apply(Identifier("x"), Identifier("#t"))))
In [16]:
# Print the expression followed by its type, or by the error message.
try_analyse(fail, type_map)
In [17]:
# (pair (f 4)) (f true)
# "f" is not bound by this expression and is not a key of type_map, so
# looking it up raises ParseError ("Undefined symbol f").
e3 = Apply(
Apply(Identifier("pair"), Apply(Identifier("f"), Identifier("4"))),
Apply(Identifier("f"), Identifier("#t")))
In [18]:
# Print the expression followed by its type, or by the error message.
try_analyse(e3, type_map)
In [19]:
# (pair (f 4)) (f true), with f left free so that an enclosing expression
# can bind it.
pair = Apply(Apply(Identifier("pair"),
Apply(Identifier("f"),
Identifier("4"))),
Apply(Identifier("f"),
Identifier("#t")))
In [20]:
# let f = (fn x => x) in ((pair (f 4)) (f true))
# Here f is let-bound to the identity function and then used at both an
# integer and a bool argument.
e4 = Let("f", Lambda("x", Identifier("x")), pair)
In [21]:
# Print the expression followed by its type, or by the error message.
try_analyse(e4, type_map)
In [22]:
# fn f => f f (fail)
# Applying the lambda-bound f to itself requires unifying f's type variable
# with a function type containing that same variable, which unify rejects
# with "recursive unification".
e5 = Lambda("f", Apply(Identifier("f"), Identifier("f")))
In [23]:
# Print the expression followed by its type, or by the error message.
try_analyse(e5, type_map)
In [24]:
# let g = fn f => 5 in g g
# Unlike e5, the function applied to itself is let-bound rather than
# lambda-bound.
e6 = Let("g",
Lambda("f", Identifier("5")),
Apply(Identifier("g"), Identifier("g")))
In [25]:
# Print the expression followed by its type, or by the error message.
try_analyse(e6, type_map)
In [30]:
# Example that demonstrates generic and non-generic variables:
# fn g => let f = fn x => g in ((pair (f 3)) (f true))
# The type variable of the lambda-bound g is in the non-generic set while
# the let is analysed.
e7 = Lambda("g",
Let("f",
Lambda("x", Identifier("g")),
Curry(
Apply(Identifier("pair"),
Apply(Identifier("f"), Identifier("3"))
),
Apply(Identifier("f"), Identifier("#t")))))
In [31]:
# Print the expression followed by its type, or by the error message.
try_analyse(e7, type_map)
In [28]:
# Function composition
# fn f => fn g => fn arg => g (f arg)
e8 = Lambda("f", Lambda("g", Lambda("arg", Apply(Identifier("g"), Apply(Identifier("f"), Identifier("arg"))))))
In [29]:
# Print the expression followed by its type, or by the error message.
try_analyse(e8, type_map)
In [ ]: