In [134]:
%matplotlib inline

In [510]:
# Scratch check: display the current Unix timestamp (seconds since the epoch),
# the same clock Task.outputs_up_to_date falls back to via time.time().
import time
time.time()


Out[510]:
1493737934.7291334

In [518]:
import inspect
import hashlib
import re
import networkx as nx
import time
import os

class Task:
    '''A named unit of work plus the data objects it consumes and produces.

    Parameters
    ----------
    name : task identifier.
    desc : free-form description.
    provides : iterable of data objects written by the task.
    requires : iterable of data objects read by the task.

    Data objects only need two attributes: ``exists`` and ``mtime``.
    The callable to execute is expected in ``self.fun``; it is not set here
    and must be attached before the task is called.
    '''

    def __init__(self, name, desc, provides, requires):
        self.name, self.desc = name, desc
        self.provides, self.requires = provides, requires
        # Flipped to True by __call__.
        self.has_run = False

    def __repr__(self):
        # Two status columns: 'R' = has run, 'U' = outputs up to date.
        ran = 'R' if self.has_run else ' '
        fresh = 'U' if self.outputs_up_to_date else ' '
        return 'Task: %s (%s)' % (self.name, ran + fresh)

    @property
    def outputs_up_to_date(self):
        '''True when every output exists and the oldest existing output is
        strictly newer than the newest existing input. Missing inputs are
        ignored rather than reported.'''
        input_times = [src.mtime for src in self.requires if src.exists]
        output_times = [dst.mtime for dst in self.provides if dst.exists]

        # Without any existing input the reference time is 0; without any
        # existing output the current time is used instead.
        newest_input = max(input_times) if input_times else 0
        oldest_output = min(output_times) if output_times else time.time()

        for dst in self.provides:
            if not dst.exists:
                return False
        return newest_input < oldest_output

    @property
    def need_run(self):
        '''True unless the task has already run and its outputs are up to date.'''
        return not (self.has_run and self.outputs_up_to_date)

    def __call__(self):
        '''Execute ``self.fun()``, record that the task has run, and return
        whatever the function returned.'''
        result = self.fun()
        self.has_run = True
        return result

class Data:
    '''Handle on a single data file, identified by its path.'''

    def __init__(self, datafile):
        self.datafile = datafile
        # Human-readable label, also used as the repr.
        self.desc = 'Data: %s' % datafile

    def __repr__(self):
        return self.desc

    @property
    def exists(self):
        '''True when the path points at a regular file.'''
        return os.path.isfile(self.datafile)

    @property
    def mtime(self):
        '''Last-modification time of the file (seconds since the epoch).
        Raises OSError when the path cannot be stat-ed.'''
        return os.stat(self.datafile).st_mtime

# Patterns that extract the quoted name from ``.din('name')`` / ``.dout("name")``
# calls found in a task's source text.
# * Raw strings: ``\.``, ``\(`` and friends are regex escapes, not Python string
#   escapes (non-raw, they trigger "invalid escape sequence" warnings).
# * The name is "anything but a quote" rather than ``\w+`` so that file names
#   with an extension or a directory part (``out.csv``, ``data/raw-1.txt``) are
#   picked up instead of being silently ignored.
DIN_RE = re.compile(r'\.din\([\'"]([^\'"]+)[\'"]\)')
DOUT_RE = re.compile(r'\.dout\([\'"]([^\'"]+)[\'"]\)')

class TaskManager:
    '''Registry of tasks linked through the data files they read and write.

    A function becomes a task through :meth:`register` (usable as a
    decorator). Its source text is scanned with ``DIN_RE`` / ``DOUT_RE`` to
    find the data it requires and provides, and the task is linked in
    ``graph`` to the tasks that provide its inputs.

    Attributes
    ----------
    tasks : dict
        Created empty and not filled by this class (kept for compatibility).
    graph : networkx.DiGraph
        One node per task name (node attribute ``'task'`` holds the Task);
        an edge goes from the provider of a data file to each task requiring it.
    str_to_data : dict
        Maps a data file name to its unique ``Data`` object.
    data_to_provider : dict
        Maps a ``Data`` object to the ``Task`` that provides it.
    '''

    def __init__(self):
        self.tasks = dict()
        self.graph = nx.DiGraph()
        self.str_to_data = dict()
        self.data_to_provider = dict()

    def get_or_register_data(self, datafile):
        '''Return the ``Data`` object for ``datafile``, creating it on first use
        so that every file name maps to exactly one object.'''
        try:
            return self.str_to_data[datafile]
        except KeyError:
            data_obj = self.str_to_data[datafile] = Data(datafile)
            return data_obj

    def _task_of(self, name):
        '''Return the Task stored on graph node ``name``.'''
        # networkx < 2.4 exposes node attributes as ``graph.node``; later
        # releases only provide ``graph.nodes``.
        attrs = getattr(self.graph, 'node', None)
        if attrs is None:
            attrs = self.graph.nodes
        return attrs[name]['task']

    def register(self, fun):
        '''Register ``fun`` as a task and return a runner for it.

        The task is named after ``fun.__name__`` and described by its
        docstring. Registering a name again with different source replaces
        the stored function, refreshes the task's inputs/outputs and marks it
        as not run; registering identical source keeps the stored task.
        Links to tasks that depend on this one are preserved either way.

        Returns
        -------
        A callable that runs every ancestor task that needs running, then
        always runs this task, and returns this task's result. Arguments
        passed to the callable are accepted but not forwarded.

        Raises
        ------
        LookupError
            When a required data file has no registered provider. Nothing is
            added to the graph or to the provider table in that case.
        '''
        desc = inspect.getdoc(fun)
        src = inspect.getsource(fun)
        hsh = hashlib.sha256(src.encode('utf-8')).hexdigest()
        fun_name = fun.__name__

        data_in = [self.get_or_register_data(dt) for dt in re.findall(DIN_RE, src)]
        data_out = [self.get_or_register_data(dt) for dt in re.findall(DOUT_RE, src)]

        # Validate before touching the graph or the provider table, so a
        # failed registration leaves no half-registered task behind.
        # Data the task produces itself counts as provided.
        for din in data_in:
            if din not in self.data_to_provider and din not in data_out:
                raise LookupError('No provider found for «%s»' % din)

        def set_task(task):
            task.hash = hsh
            task.src = src
            task.desc = desc
            task.fun = fun
            # Keep the declared inputs/outputs in sync with the source.
            task.requires = data_in
            task.provides = data_out
            for d in data_out:
                self.data_to_provider[d] = task

        if fun_name in self.graph:
            task = self._task_of(fun_name)
            if task.hash != hsh:
                print('W: overriding %s' % fun_name)
                set_task(task)
                # The code changed: whatever ran before is stale.
                task.has_run = False
            else:
                print('W: same redefinition of %s' % fun_name)
            # Only the links to parent tasks are rebuilt below. Removing the
            # whole node would also drop the links to the tasks that depend
            # on this one.
            self.graph.remove_edges_from(list(self.graph.in_edges(fun_name)))
        else:
            task = Task(fun_name, desc, requires=data_in, provides=data_out)
            set_task(task)

        # Adds the node, or just refreshes its attribute when it exists.
        self.graph.add_node(task.name, task=task)

        # Add edges from the tasks providing the inputs.
        for din in data_in:
            provider_task = self.data_to_provider[din]
            self.graph.add_edge(provider_task.name, task.name)

        def wrap(*args, **kwargs):
            ret = None
            # Every task this one depends on, however deep the chain is,
            # visited parents first.
            needed = nx.ancestors(self.graph, task.name) | {task.name}
            node_order = list(nx.topological_sort(self.graph.subgraph(needed)))
            for node in node_order:
                task_obj = self._task_of(node)
                # The requested task always runs; ancestors only when needed.
                if task_obj.need_run or task_obj.name == task.name:
                    print('Calling «%s»' % task_obj)
                    ret = task_obj()

                    # Whatever depends on a task that just ran must run again.
                    for child in self.graph.successors(node):
                        self._task_of(child).has_run = False

            return ret

        return wrap

    def din(self, s):
        '''Marker for a data file read by a task; returns ``s`` unchanged.'''
        return s

    def dout(self, s):
        '''Marker for a data file written by a task; returns ``s`` unchanged.'''
        return s

    def __repr__(self):
        arr = []
        for node in self.graph.nodes():
            task = self._task_of(node)
            arr.append('%s, %s, %s' % (task.name,
                                  'has_run' if task.has_run else 'need to run',
                                  'files up to date' if task.outputs_up_to_date else 'files not up to date'))
        return '\n'.join(arr)

In [535]:
import sh

# Fresh manager for the demo tasks registered below in this cell.
tm = TaskManager()

@tm.register
def A():
    '''Root demo task: touches the files Aout1 and Aout2, declared as its outputs.'''
    sh.touch(tm.dout('Aout1'))
    sh.touch(tm.dout('Aout2'))

@tm.register
def B():
    '''Demo task: checks that Aout1 and Aout2 exist (declared inputs), then touches Bout.'''
    assert os.path.exists(tm.din('Aout1'))
    assert os.path.exists(tm.din('Aout2'))
    sh.touch(tm.dout('Bout'))
    
@tm.register
def C():
    '''Demo task: checks that Aout1 and Bout exist (declared inputs), then touches Cout.'''
    assert os.path.exists(tm.din('Aout1'))
    assert os.path.exists(tm.din('Bout'))
    sh.touch(tm.dout('Cout'))
    
@tm.register
def D():
    '''Final demo task: checks that Cout and Bout exist (declared inputs); declares no output.'''
    assert os.path.exists(tm.din('Cout'))
    assert os.path.exists(tm.din('Bout'))

In [537]:
# Walk through the demo scenario. Each step prints a banner stating what is
# expected, optionally tweaks a file on disk, then requests the final task D.
steps = [
    ('>>>>>>>>>> All should run', None),
    ('>>>>>>>>>> Did nothing, only D should run', None),
    ('>>>>>>>>>> Removing output of A, all should rerun', lambda: sh.rm('Aout1')),
    ('>>>>>>>>>> Did nothing, only D should run', None),
    ('>>>>>>>>>> Touched output of A => B, C, D should rerun', lambda: sh.touch('Aout1')),
    ('>>>>>>>>>> Removed output of C => C and D should rerun', lambda: sh.rm('Cout')),
]
for banner, tweak in steps:
    print(banner)
    if tweak is not None:
        tweak()
    D()


>>>>>>>>>> All should run
Calling «Task: D (RU)»
>>>>>>>>>> Did nothing, only D should run
Calling «Task: D (RU)»
>>>>>>>>>> Removing output of A, all should rerun
Calling «Task: A (R )»
Calling «Task: B (  )»
Calling «Task: C (  )»
Calling «Task: D ( U)»
>>>>>>>>>> Did nothing, only D should run
Calling «Task: D (RU)»
>>>>>>>>>> Touched output of A => B, C, D should rerun
Calling «Task: B (R )»
Calling «Task: C (  )»
Calling «Task: D ( U)»
>>>>>>>>>> Removed output of B => B and D should rerun
Calling «Task: B (R )»
Calling «Task: C (  )»
Calling «Task: D ( U)»

In [542]:
# Draw the full task graph on a circle, with the task names as labels.
# To inspect the neighbourhood of a single task instead, use e.g.
# nx.ego_graph(tm.graph, 'C', radius=100) as the graph.
graph = tm.graph
layout = nx.circular_layout(graph)
nx.draw(graph, pos=layout)
_ = nx.draw_networkx_labels(graph, pos=layout)


/home/ccc/.virtualenvs/astro/lib/python3.6/site-packages/networkx/drawing/nx_pylab.py:126: MatplotlibDeprecationWarning: pyplot.hold is deprecated.
    Future behavior will be consistent with the long-time default:
    plot commands add elements without first clearing the
    Axes and/or Figure.
  b = plt.ishold()
/home/ccc/.virtualenvs/astro/lib/python3.6/site-packages/networkx/drawing/nx_pylab.py:138: MatplotlibDeprecationWarning: pyplot.hold is deprecated.
    Future behavior will be consistent with the long-time default:
    plot commands add elements without first clearing the
    Axes and/or Figure.
  plt.hold(b)
/home/ccc/.virtualenvs/astro/lib/python3.6/site-packages/matplotlib/__init__.py:917: UserWarning: axes.hold is deprecated. Please remove it from your matplotlibrc and/or style files.
  warnings.warn(self.msg_depr_set % key)
/home/ccc/.virtualenvs/astro/lib/python3.6/site-packages/matplotlib/rcsetup.py:152: UserWarning: axes.hold is deprecated, will be removed in 3.0
  warnings.warn("axes.hold is deprecated, will be removed in 3.0")

In [ ]: