In [9]:
import rasterio
import numpy as np
import os

In [12]:
vrtfile = '/Users/asger/Data/DHM/DSM_2014/DSM_605_68.vrt'

In [32]:
import xml.etree.ElementTree as ET
from affine import Affine

def rect_to_win(rect):
    rows = (int(rect['yOff']), int(rect['yOff']) + int(rect['ySize']))
    cols = (int(rect['xOff']), int(rect['xOff']) + int(rect['xSize']))
    return (rows, cols)

def parse_vrt( filename ):
    vrt = {"shape": None, "bands":[], "srs": "", "transformation": None, "rootpath": "", "filename": ""}
    vrt["filename"] = filename
    vrt["rootpath"] = os.path.dirname(filename)
    
    xmltree = ET.parse( filename )
    xmlroot = xmltree.getroot()
    vrt["shape"] = (float(xmlroot.attrib['rasterYSize']), float(xmlroot.attrib['rasterXSize']))
    
    # Iterate over children of root
    for rootchild in xmlroot:
        if rootchild.tag == 'SRS':
            vrt["srs"] = rootchild.text
        elif rootchild.tag == 'GeoTransform':
            params = [float(f) for f in rootchild.text.split(',')]
            vrt["transformation"] = Affine( params[1], params[2], params[0], params[4], params[5], params[3] )
        elif rootchild.tag == 'VRTRasterBand':
            from rasterio import dtypes
            if len(vrt["bands"]) > 0:
                raise Exception('Only single band supported for now')
            band = { 'sources' : [] }
            vrt["bands"].append(band)
            typename = rootchild.attrib['dataType']
            band['dtype'] = dtypes.dtype_fwd[ dtypes.dtype_rev[typename.lower()] ]
            for bandchild in rootchild:
                if bandchild.tag in ['ColorInterp', 'Histograms']:
                    continue
                elif bandchild.tag == 'NoDataValue':
                    band['nodata'] = float( bandchild.text )
                elif bandchild.tag == 'ComplexSource':
                    complexSource = {'sourcetype': 'ComplexSource'}
                    band['sources'].append( complexSource )
                    for sourcechild in bandchild:
                        if sourcechild.tag == 'SourceFilename':
                            complexSource['sourcefile'] = {
                                'filename': sourcechild.text, 
                                'relative': sourcechild.get('relativeToVRT') == '1'}
                        elif sourcechild.tag == 'SourceBand' :
                            complexSource['sourceband']= int( sourcechild.text )
                        elif sourcechild.tag == 'SrcRect':
                            complexSource['srcwin'] = rect_to_win(sourcechild.attrib) 
                        elif sourcechild.tag == 'DstRect':
                            complexSource['dstwin'] = rect_to_win(sourcechild.attrib)
                        elif sourcechild.tag == 'NODATA':                        
                            complexSource['nodata'] = float( sourcechild.text )
    return vrt

In [79]:
class TiledRaster():
    def __init__(self,vrtfile):
        self.vrt = parse_vrt(vrtfile)
        self.vrtfile = self.vrt['filename']
        self.shape = self.vrt['shape']
        self.basepath = self.vrt['rootpath']
        
    def tiles(self, bandnumber = 1):
        for t in self.vrt['bands'][bandnumber - 1]['sources']:
            yield t
    
    def get_tile_file(self, tile):
        src = tile['sourcefile'] 
        filename = src['filename']
        if src['relative']:
            filename = os.path.join(self.basepath, filename)
        return filename
    
    def read_tile_buffered(self, tile, bufferpixels, masked=None):
        """Returns ( numpyarray, tiledata_window)
        tiledatawindow is a window indicating which part of the numpy array stems from the tile"""
        bufferinfo = self._adjust_buffer(tile, bufferpixels)
        win = self._expand_window(tile, bufferinfo)
        tile_data_window = self._unbuffered_window(tile, bufferinfo)
        with rasterio.open(self.vrtfile, mode='r') as f:
            data = f.read_band(1, window=win, masked=masked)
        return data, tile_data_window
        
    def _adjust_buffer(self, tile, bufferpixels):
        """Adjust left/top/right/bottom buffer to stay within vrt bounds"""
        left = top = right = bottom = bufferpixels
        dstwin = tile['dstwin']
        if dstwin[1][0] - bufferpixels < 0:
            left = dstwin[1][0]
        if dstwin[0][0] - bufferpixels < 0:
            top = dstwin[0][0]
        if dstwin[1][1] + bufferpixels > self.shape[1]:
            right = self.shape[1] - dstwin[1][1]
        if dstwin[0][1] + bufferpixels > self.shape[0]:
            bottom = self.shape[0] - dstwin[0][1]
        return (left, top, right, bottom)
    
    def _expand_window(self, tile, bufferinfo):
        """Expands a window by bufferinfo"""
        window = tile['dstwin']
        rows = ( window[0][0] - bufferinfo[1], window[0][1] + bufferinfo[3] )
        cols = ( window[1][0] - bufferinfo[0], window[1][1] + bufferinfo[2] )
        return (rows, cols)
    
    def _unbuffered_window(self, tile, bufferinfo):
        """ Takes a buffered window and bufferinfo and calculates indices for the unbuffered data"""
        window = tile['dstwin']
        winshape = [ end - start for start, end in window ]
        return ((bufferinfo[1], winshape[0] - bufferinfo[3]),(bufferinfo[0], winshape[1] - bufferinfo[2]))

In [81]:
tr = TiledRaster(vrtfile)
tile = [ t for t in tr.tiles() if 'dsm_1km_6055_689.tif' in t['sourcefile']['filename'] ][0]
print tile
bufferinfo = tr._adjust_buffer(tile, 10)
print bufferinfo
expanded_window = tr._expand_window(tile, bufferinfo)
print expanded_window
unbuf = tr._unbuffered_window(tile, bufferinfo)
print unbuf


{'sourcetype': 'ComplexSource', 'srcwin': ((0, 2500), (0, 2500)), 'sourcefile': {'relative': True, 'filename': 'DSM_605_68_TIF_UTM32-ETRS89/dsm_1km_6055_689.tif'}, 'dstwin': ((10000, 12500), (22500, 25000)), 'sourceband': 1, 'nodata': -9999.0}
(10, 10, 0.0, 10)
((9990, 12510), (22490, 25000.0))
((10, 2490), (10, 2500.0))

In [82]:
data, win = tr.read_tile_buffered(tile, 10, masked=True)

In [83]:
data


Out[83]:
masked_array(data =
 [[ 2.35071087  2.32748556  2.31290698 ...,  4.36390114  5.6388669
   5.25279999]
 [ 2.34498763  2.33384609  2.33390045 ...,  4.67028189  4.01992226
   4.22879982]
 [ 2.36370325  2.36135387  2.34420419 ...,  4.34113312  3.45245767
   3.18720365]
 ..., 
 [ 4.83269167  4.82080984  4.82189465 ..., -0.1608575  -0.15683722
  -0.15000001]
 [ 4.84731293  4.82314205  4.81542635 ..., -0.16       -0.16       -0.1495966 ]
 [ 4.84875441  4.83501101  4.81044436 ..., -0.15561403 -0.15577827
  -0.15669918]],
             mask =
 False,
       fill_value = -9999.0)

In [84]:
win


Out[84]:
((10, 2490), (10, 2500.0))

In [ ]: