AST - Abstract Syntax Tree

For reasons, I want to parse a Python source file and extract certain elements. The case in point: find all functions that have a given decorator applied and return certain attributes of each function declaration.

Assuming a certain coding style, this could probably be done in a handful of lines, or even with a regex. That becomes problematic if you want to properly parse any and all (valid) Python code: you'll soon find yourself reinventing the (lexer-)wheel that is already available in Python itself.

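Thanks to others, there is a built-in ast module which parses Python source code into an AST. The AST can then be inspected, modified, compiled into a code object with compile(), and even turned back into source code (with the third-party astunparse package, or with ast.unparse() from Python 3.9 on). In our case we are only interested in inspection.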
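To make the "compiled" part concrete, here is a minimal sketch: a parsed tree can be handed straight to the built-in compile() and executed. The one-line module and the names snippet and answer are made up for illustration.

```python
import ast

# A tiny, made-up module to parse
snippet = "answer = 6 * 7"

tree = ast.parse(snippet)                 # source text -> AST (an ast.Module)
code = compile(tree, "<snippet>", "exec") # AST -> code object

namespace = {}
exec(code, namespace)                     # run the compiled module in a fresh namespace
print(namespace["answer"])                # 42
```

On Python 3.9+, ast.unparse(tree) goes the other way and returns source text for the tree.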


In [1]:
import ast
example_module = '''
@my_decorator
def my_function(my_argument):
    """My Docstring"""
    my_value = 420
    return my_value
    
def foo():
    pass
    
@Some_decorator
@Another_decorator
def bar():
    pass
    
@MyClass.subpackage.my_deco_function    
def baz():
    pass'''

The ast module "helps Python applications to process trees of the Python abstract syntax grammar. The abstract syntax itself might change with each Python release; this module helps to find out programmatically what the current grammar looks like."

All node objects in the tree inherit from ast.AST. The node types and their fields are defined in the so-called ASDL: Python's abstract grammar, written in the Zephyr Abstract Syntax Definition Language. The grammar file can be found in the CPython sources at Parser/Python.asdl.
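You don't have to dig up the grammar file to see which fields a node has: each node class lists them in its _fields attribute at runtime. A small sketch; note that the exact tuple for FunctionDef differs between Python versions (for example, type_comment was added in 3.8).

```python
import ast

# Every node class lists its child fields in _fields, mirroring its ASDL entry
print(ast.FunctionDef._fields)
print(ast.Attribute._fields)   # ('value', 'attr', 'ctx')
```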


In [2]:
tree = ast.parse(example_module)
print(tree) # the object


<_ast.Module object at 0x10ec9d668>

In [3]:
# The built-in ast.dump() function shows the content of the entire tree
print(ast.dump(ast.parse(example_module)))


Module(body=[FunctionDef(name='my_function', args=arguments(args=[arg(arg='my_argument', annotation=None)], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Expr(value=Str(s='My Docstring')), Assign(targets=[Name(id='my_value', ctx=Store())], value=Num(n=420)), Return(value=Name(id='my_value', ctx=Load()))], decorator_list=[Name(id='my_decorator', ctx=Load())], returns=None), FunctionDef(name='foo', args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Pass()], decorator_list=[], returns=None), FunctionDef(name='bar', args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Pass()], decorator_list=[Name(id='Some_decorator', ctx=Load()), Name(id='Another_decorator', ctx=Load())], returns=None), FunctionDef(name='baz', args=arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Pass()], decorator_list=[Attribute(value=Attribute(value=Name(id='MyClass', ctx=Load()), attr='subpackage', ctx=Load()), attr='my_deco_function', ctx=Load())], returns=None)])
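On Python 3.9 and later, ast.dump() accepts an indent argument and produces a multi-line layout, similar in spirit to astunparse.dump() below, without a third-party dependency. A sketch assuming Python 3.9+; the one-function source string is made up for the example.

```python
import ast

tree = ast.parse("def foo():\n    pass")

# Python 3.9+ only: indent= makes ast.dump() pretty-print the tree itself
pretty = ast.dump(tree, indent=2)
print(pretty)
```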

The third-party astunparse module can pretty-print the tree, which we rely on heavily during exploration.


In [20]:
import astunparse
print(astunparse.dump(tree))


Module(body=[
  FunctionDef(
    name='my_function',
    args=arguments(
      args=[arg(
        arg='my_argument',
        annotation=None)],
      vararg=None,
      kwonlyargs=[],
      kw_defaults=[],
      kwarg=None,
      defaults=[]),
    body=[
      Expr(value=Str(s='My Docstring')),
      Assign(
        targets=[Name(
          id='my_value',
          ctx=Store())],
        value=Num(n=420)),
      Return(value=Name(
        id='my_value',
        ctx=Load()))],
    decorator_list=[Name(
      id='my_decorator',
      ctx=Load())],
    returns=None),
  FunctionDef(
    name='foo',
    args=arguments(
      args=[],
      vararg=None,
      kwonlyargs=[],
      kw_defaults=[],
      kwarg=None,
      defaults=[]),
    body=[Pass()],
    decorator_list=[],
    returns=None),
  FunctionDef(
    name='bar',
    args=arguments(
      args=[],
      vararg=None,
      kwonlyargs=[],
      kw_defaults=[],
      kwarg=None,
      defaults=[]),
    body=[Pass()],
    decorator_list=[
      Name(
        id='Some_decorator',
        ctx=Load()),
      Name(
        id='Another_decorator',
        ctx=Load())],
    returns=None),
  FunctionDef(
    name='baz',
    args=arguments(
      args=[],
      vararg=None,
      kwonlyargs=[],
      kw_defaults=[],
      kwarg=None,
      defaults=[]),
    body=[Pass()],
    decorator_list=[Attribute(
      value=Attribute(
        value=Name(
          id='MyClass',
          ctx=Load()),
        attr='subpackage',
        ctx=Load()),
      attr='my_deco_function',
      ctx=Load())],
    returns=None)])

We want to look at function definitions, which are aptly named FunctionDef in the ASDL and represented as FunctionDef objects in the tree. Looking at the ASDL, we find the following definition for FunctionDef (reformatted):

FunctionDef(identifier name,
            arguments args,
            stmt* body,
            expr* decorator_list,
            expr? returns,
            string? docstring)

This seems to correspond to the structure of the object in the AST, as shown in the astunparse dump above. Green Tree Snakes, an unofficial companion to the ast documentation, explains the components of the FunctionDef object in more detail.
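One difference from that ASDL excerpt: the dump has no docstring field. The docstring is simply the first statement of body, an Expr holding a string, and the standard library's ast.get_docstring() picks it out. A minimal sketch, reusing a cut-down my_function from the example module:

```python
import ast

source = '''
def my_function(my_argument):
    """My Docstring"""
    my_value = 420
    return my_value
'''

func = ast.parse(source).body[0]

# The docstring is the first statement of the body, an Expr wrapping a string
print(type(func.body[0]).__name__)   # Expr
# ast.get_docstring() checks for that pattern and returns the text (or None)
print(ast.get_docstring(func))       # My Docstring
```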

Traversing and inspecting the tree

There are two ways to work with the tree. The simplest is ast.walk(), which will "Recursively yield all descendant nodes in the tree starting at node (including node itself), in no specified order." In practice it does so breadth first. Alternatively, you can subclass ast.NodeVisitor. This class provides a visit() method that does a depth-first traversal. You can define visit_<ClassName> methods, which are called whenever the traversal hits a node of that class; nodes without such a method are handed to generic_visit().


In [5]:
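class MyVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        # Called for every node, since we define no visit_<ClassName> methods
        print(f'Nodetype: {type(node).__name__:{16}} {node}')
        # Continue the traversal into this node's children
        super().generic_visit(node)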
        

v = MyVisitor()
print('Using NodeVisitor (depth first):')
v.visit(tree)

print('\nWalk()ing the tree breadth first:')
for node in ast.walk(tree):
    print(f'Nodetype: {type(node).__name__:{16}} {node}')


Using NodeVisitor (depth first):
Nodetype: Module           <_ast.Module object at 0x10ec9d668>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d6a0>
Nodetype: arguments        <_ast.arguments object at 0x10ec9d6d8>
Nodetype: arg              <_ast.arg object at 0x10ec9d710>
Nodetype: Expr             <_ast.Expr object at 0x10ec9d748>
Nodetype: Str              <_ast.Str object at 0x10ec9d780>
Nodetype: Assign           <_ast.Assign object at 0x10ec9d7b8>
Nodetype: Name             <_ast.Name object at 0x10ec9d7f0>
Nodetype: Store            <_ast.Store object at 0x10d24e780>
Nodetype: Num              <_ast.Num object at 0x10ec9d828>
Nodetype: Return           <_ast.Return object at 0x10ec9d860>
Nodetype: Name             <_ast.Name object at 0x10ec9d898>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Name             <_ast.Name object at 0x10ec9d8d0>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d908>
Nodetype: arguments        <_ast.arguments object at 0x10ec9d940>
Nodetype: Pass             <_ast.Pass object at 0x10ec9d978>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d9b0>
Nodetype: arguments        <_ast.arguments object at 0x10ec9d9e8>
Nodetype: Pass             <_ast.Pass object at 0x10ec9da20>
Nodetype: Name             <_ast.Name object at 0x10ec9da58>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Name             <_ast.Name object at 0x10ec9da90>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9dac8>
Nodetype: arguments        <_ast.arguments object at 0x10ec9db00>
Nodetype: Pass             <_ast.Pass object at 0x10ec9db38>
Nodetype: Attribute        <_ast.Attribute object at 0x10ec9db70>
Nodetype: Attribute        <_ast.Attribute object at 0x10ec9dba8>
Nodetype: Name             <_ast.Name object at 0x10ec9dbe0>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Load             <_ast.Load object at 0x10d24e668>

Walk()ing the tree breadth first:
Nodetype: Module           <_ast.Module object at 0x10ec9d668>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d6a0>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d908>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d9b0>
Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9dac8>
Nodetype: arguments        <_ast.arguments object at 0x10ec9d6d8>
Nodetype: Expr             <_ast.Expr object at 0x10ec9d748>
Nodetype: Assign           <_ast.Assign object at 0x10ec9d7b8>
Nodetype: Return           <_ast.Return object at 0x10ec9d860>
Nodetype: Name             <_ast.Name object at 0x10ec9d8d0>
Nodetype: arguments        <_ast.arguments object at 0x10ec9d940>
Nodetype: Pass             <_ast.Pass object at 0x10ec9d978>
Nodetype: arguments        <_ast.arguments object at 0x10ec9d9e8>
Nodetype: Pass             <_ast.Pass object at 0x10ec9da20>
Nodetype: Name             <_ast.Name object at 0x10ec9da58>
Nodetype: Name             <_ast.Name object at 0x10ec9da90>
Nodetype: arguments        <_ast.arguments object at 0x10ec9db00>
Nodetype: Pass             <_ast.Pass object at 0x10ec9db38>
Nodetype: Attribute        <_ast.Attribute object at 0x10ec9db70>
Nodetype: arg              <_ast.arg object at 0x10ec9d710>
Nodetype: Str              <_ast.Str object at 0x10ec9d780>
Nodetype: Name             <_ast.Name object at 0x10ec9d7f0>
Nodetype: Num              <_ast.Num object at 0x10ec9d828>
Nodetype: Name             <_ast.Name object at 0x10ec9d898>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Attribute        <_ast.Attribute object at 0x10ec9dba8>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Store            <_ast.Store object at 0x10d24e780>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Name             <_ast.Name object at 0x10ec9dbe0>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
Nodetype: Load             <_ast.Load object at 0x10d24e668>
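The generic_visit() override above sees every node. To react to a single node type instead, define a visit_<ClassName> method; when you do, call generic_visit() yourself if you still want the traversal to reach that node's children. A small sketch; FunctionCollector is a made-up name:

```python
import ast

source = '''
@my_decorator
def my_function(my_argument):
    return my_argument

def foo():
    pass
'''

class FunctionCollector(ast.NodeVisitor):
    """Collect the names of all function definitions, in depth-first order."""

    def __init__(self):
        self.names = []

    def visit_FunctionDef(self, node):
        # Called by visit() for every ast.FunctionDef node
        self.names.append(node.name)
        # Keep descending so nested functions are found too
        self.generic_visit(node)

collector = FunctionCollector()
collector.visit(ast.parse(source))
print(collector.names)   # ['my_function', 'foo']
```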

For our purposes the walk() function should do, and I find it simpler to use for now. Let's see what happens if we grab those FunctionDef objects and inspect them the same way. Using the unparse() function of astunparse, we can turn each one back into source code for extra fun.


In [6]:
for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        print(f'Nodetype: {type(node).__name__:{16}} {node}')
        print(astunparse.unparse(node))


Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d6a0>


@my_decorator
def my_function(my_argument):
    'My Docstring'
    my_value = 420
    return my_value

Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d908>


def foo():
    pass

Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9d9b0>


@Some_decorator
@Another_decorator
def bar():
    pass

Nodetype: FunctionDef      <_ast.FunctionDef object at 0x10ec9dac8>


@MyClass.subpackage.my_deco_function
def baz():
    pass

We only want functions that have a certain decorator, so we need to inspect the decorator_list attribute of each FunctionDef node.


In [7]:
for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        decorators = [d.id for d in node.decorator_list]
        print(node.name, decorators)


my_function ['my_decorator']
foo []
bar ['Some_decorator', 'Another_decorator']
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-666517aaaecf> in <module>()
      1 for node in ast.walk(tree):
      2     if isinstance(node, ast.FunctionDef):
----> 3         decorators = [d.id for d in node.decorator_list]
      4         print(node.name, decorators)

<ipython-input-7-666517aaaecf> in <listcomp>(.0)
      1 for node in ast.walk(tree):
      2     if isinstance(node, ast.FunctionDef):
----> 3         decorators = [d.id for d in node.decorator_list]
      4         print(node.name, decorators)

AttributeError: 'Attribute' object has no attribute 'id'

Looking more closely, a decorator that is a plain name (@function) is represented differently in the AST from a dotted one (@Class.method). Compare the decorator in my_function:

decorator_list=[Name(
      id='my_decorator',
      ctx=Load())]

against the compound decorator in baz:

decorator_list=[Attribute(
      value=Attribute(
        value=Name(
          id='MyClass',
          ctx=Load()),
        attr='subpackage',
        ctx=Load()),
      attr='my_deco_function',
      ctx=Load())]

So we need to modify our tree walk to accommodate this. When the top-level element in the decorator_list is of type Name, we grab its id and are done. If it is of type Attribute, we need to do some extra work. From the ASDL we can see that Attribute is a nested element:

Attribute(expr value, identifier attr, expr_context ctx)

Assuming the decorator is a chain of nested ast.Attribute nodes with an ast.Name at the root, we can define a flattening function.


In [8]:
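def flatten_attr(node):
    """Turn a dotted-name expression (an Attribute/Name chain) into a string.

    Returns None for anything else, e.g. a call or subscript inside the chain.
    """
    if isinstance(node, ast.Attribute):
        base = flatten_attr(node.value)
        return None if base is None else base + '.' + node.attr
    elif isinstance(node, ast.Name):
        return node.id
    return None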

In [9]:
for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        found_decorators = []
        for decorator in node.decorator_list:
            if isinstance(decorator, ast.Name):
                found_decorators.append(decorator.id)
            elif isinstance(decorator, ast.Attribute):
                found_decorators.append(flatten_attr(decorator))
        print(node.name, found_decorators)


my_function ['my_decorator']
foo []
bar ['Some_decorator', 'Another_decorator']
baz ['MyClass.subpackage.my_deco_function']
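On Python 3.9+ there is a shortcut: ast.unparse() renders any expression node, including a Name/Attribute chain, back to source text, so a dotted decorator comes out as a string without a hand-written flattening function. It also copes with called or subscripted decorators, although you then get the whole expression back, arguments included. A sketch assuming Python 3.9+, using a cut-down copy of the example module:

```python
import ast

source = '''
@my_decorator
def my_function():
    pass

@MyClass.subpackage.my_deco_function
def baz():
    pass
'''

found = {}
for node in ast.walk(ast.parse(source)):
    if isinstance(node, ast.FunctionDef):
        # Python 3.9+: ast.unparse() turns an expression node back into source text
        found[node.name] = [ast.unparse(d) for d in node.decorator_list]

print(found)
```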

The actual sources I want to parse have an additional complication: the decorator functions have arguments passed into them, and I want to know what's in those as well. So let's switch to some actual source code and see how to do that. I have removed the body of the function, as we are only interested in the decorator now.


In [10]:
source = """
@Route.get(
    r"/projects/{project_id}/snapshots",
    description="List snapshots of a project",
    parameters={
        "project_id": "Project UUID",
    },
    status_codes={
        200: "Snasphot list returned",
        404: "The project doesn't exist"
    })
def list(request, response):
    pass"""

print(astunparse.dump(ast.parse(source)))


Module(body=[FunctionDef(
  name='list',
  args=arguments(
    args=[
      arg(
        arg='request',
        annotation=None),
      arg(
        arg='response',
        annotation=None)],
    vararg=None,
    kwonlyargs=[],
    kw_defaults=[],
    kwarg=None,
    defaults=[]),
  body=[Pass()],
  decorator_list=[Call(
    func=Attribute(
      value=Name(
        id='Route',
        ctx=Load()),
      attr='get',
      ctx=Load()),
    args=[Str(s='/projects/{project_id}/snapshots')],
    keywords=[
      keyword(
        arg='description',
        value=Str(s='List snapshots of a project')),
      keyword(
        arg='parameters',
        value=Dict(
          keys=[Str(s='project_id')],
          values=[Str(s='Project UUID')])),
      keyword(
        arg='status_codes',
        value=Dict(
          keys=[
            Num(n=200),
            Num(n=404)],
          values=[
            Str(s='Snasphot list returned'),
            Str(s="The project doesn't exist")]))])],
  returns=None)])

This time the decorator_list contains an ast.Call object rather than a Name or Attribute: the decorator is called, and the Call node records the called expression (func) together with its positional arguments (args) and keyword arguments (keywords). I am interested in the first positional argument as well as the keyword arguments. Let's grab element [0] of the decorator list to simplify.


In [11]:
complex_decorator = ast.parse(source).body[0].decorator_list[0]
print(astunparse.dump(complex_decorator))


Call(
  func=Attribute(
    value=Name(
      id='Route',
      ctx=Load()),
    attr='get',
    ctx=Load()),
  args=[Str(s='/projects/{project_id}/snapshots')],
  keywords=[
    keyword(
      arg='description',
      value=Str(s='List snapshots of a project')),
    keyword(
      arg='parameters',
      value=Dict(
        keys=[Str(s='project_id')],
        values=[Str(s='Project UUID')])),
    keyword(
      arg='status_codes',
      value=Dict(
        keys=[
          Num(n=200),
          Num(n=404)],
        values=[
          Str(s='Snasphot list returned'),
          Str(s="The project doesn't exist")]))])

In [21]:
decorator_name = flatten_attr(complex_decorator.func)
decorator_path = complex_decorator.args[0].s
for kw in complex_decorator.keywords:
    if kw.arg == 'description':
        decorator_description = kw.value.s
    if kw.arg == 'parameters':
        decorator_parameters = ast.literal_eval(astunparse.unparse(kw.value))
    if kw.arg == 'status_codes':
        decorator_statuscodes = ast.literal_eval(astunparse.unparse(kw.value))

print(decorator_name, decorator_path)
print('Parameters:')
for p in decorator_parameters.keys():
    print('  ' + str(p) + ': ' + decorator_parameters[p])    
print('Status Codes:')
for sc in decorator_statuscodes.keys():
    print('  ' + str(sc) + ': ' + decorator_statuscodes[sc])


Route.get /projects/{project_id}/snapshots
Parameters:
  project_id: Project UUID
Status Codes:
  200: Snasphot list returned
  404: The project doesn't exist
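ast.literal_eval() accepts an AST node as well as a string, so the detour through astunparse.unparse() is not strictly needed for the parameters and status_codes dictionaries. A sketch using a cut-down version of the decorator above; the keywords dict is just a convenience introduced here:

```python
import ast

source = '''
@Route.get(
    r"/projects/{project_id}/snapshots",
    status_codes={
        200: "Snapshot list returned",
        404: "The project doesn't exist"
    })
def list(request, response):
    pass
'''

call = ast.parse(source).body[0].decorator_list[0]

# Map keyword name -> value node, e.g. {'status_codes': <Dict node>}
keywords = {kw.arg: kw.value for kw in call.keywords}

# literal_eval() evaluates literal nodes directly; no unparse step needed
path = ast.literal_eval(call.args[0])
status_codes = ast.literal_eval(keywords['status_codes'])
print(path)           # /projects/{project_id}/snapshots
print(status_codes)
```

On Python 3.8+, string and number literals parse to ast.Constant, whose payload lives in .value; the .s and .n attributes used in the cells above are deprecated aliases, so literal_eval() is also the more future-proof way to read them.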

Time to bring it all together and write a function that takes a filename and a decorator name as arguments and returns a list of (named) tuples holding the:

  • File name (str)
  • Function name (str)
  • Route path, the first positional argument of the decorator (str)
  • Description for the given decorator (str)
  • Parameters for the decorator (dict)
  • Status codes for the decorator (dict)

for every function in the source file that is decorated with that decorator.


In [19]:
import collections
from pathlib import Path

Route = collections.namedtuple('Route', 'filename function_name path description parameters status_codes')

def extract_routes(filename, decorator_name):
    """Return a Route for every function in filename decorated with a call to decorator_name."""
    routes = []
    with open(filename, encoding='utf-8') as f:
        try:
            tree = ast.parse(f.read(), filename=filename)
        except SyntaxError:
            # Not valid Python (for this interpreter): skip the file
            return routes

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            funcname = node.name
            for d in node.decorator_list:
                # Only called decorators, e.g. @Route.post("/path", ...)
                if isinstance(d, ast.Call) and flatten_attr(d.func) == decorator_name:
                    # Guard against decorators called with keyword arguments only
                    route_path = d.args[0].s if d.args else None
                    description = None
                    parameters = None
                    statuscodes = None
                    for kw in d.keywords:
                        if kw.arg == 'description':
                            description = kw.value.s
                        elif kw.arg == 'parameters':
                            parameters = ast.literal_eval(astunparse.unparse(kw.value))
                        elif kw.arg == 'status_codes':
                            statuscodes = ast.literal_eval(astunparse.unparse(kw.value))
                    routes.append(Route(filename, funcname, route_path, description, parameters, statuscodes))

    return routes

post_routes = []
for path in Path('./controller').glob('*.py'):
    # glob() yields Path objects; convert to str for the Route tuple
    post_routes += extract_routes(str(path), 'Route.post')

for route in post_routes:
    print(f'{route.filename}  {route.function_name:{20}} {str(route.path):{40}}')


controller/compute_handler.py  create               /computes                               
controller/compute_handler.py  post_forward         /computes/{compute_id}/{emulator}/{action:.+}
controller/drawing_handler.py  create               /projects/{project_id}/drawings         
controller/link_handler.py  create               /projects/{project_id}/links            
controller/link_handler.py  start_capture        /projects/{project_id}/links/{link_id}/start_capture
controller/link_handler.py  stop_capture         /projects/{project_id}/links/{link_id}/stop_capture
controller/node_handler.py  create               /projects/{project_id}/nodes            
controller/node_handler.py  start_all            /projects/{project_id}/nodes/start      
controller/node_handler.py  stop_all             /projects/{project_id}/nodes/stop       
controller/node_handler.py  suspend_all          /projects/{project_id}/nodes/suspend    
controller/node_handler.py  reload_all           /projects/{project_id}/nodes/reload     
controller/node_handler.py  start                /projects/{project_id}/nodes/{node_id}/start
controller/node_handler.py  stop                 /projects/{project_id}/nodes/{node_id}/stop
controller/node_handler.py  suspend              /projects/{project_id}/nodes/{node_id}/suspend
controller/node_handler.py  reload               /projects/{project_id}/nodes/{node_id}/reload
controller/node_handler.py  post_file            /projects/{project_id}/nodes/{node_id}/files/{path:.+}
controller/project_handler.py  create_project       /projects                               
controller/project_handler.py  close                /projects/{project_id}/close            
controller/project_handler.py  open                 /projects/{project_id}/open             
controller/project_handler.py  load                 /projects/load                          
controller/project_handler.py  import_project       /projects/{project_id}/import           
controller/project_handler.py  duplicate            /projects/{project_id}/duplicate        
controller/project_handler.py  write_file           /projects/{project_id}/files/{path:.+}  
controller/server_handler.py  shutdown             /shutdown                               
controller/server_handler.py  check_version        /version                                
controller/server_handler.py  write_settings       /settings                               
controller/server_handler.py  debug                /debug                                  
controller/symbol_handler.py  upload               /symbols/{symbol_id:.+}/raw             
