In [9]:
import rasterio
import numpy as np
import os
In [12]:
# Path to the GDAL VRT mosaic used by the demo cells below.
# Machine-specific absolute path: adjust to your local data location.
vrtfile = '/Users/asger/Data/DHM/DSM_2014/DSM_605_68.vrt'
In [32]:
import xml.etree.ElementTree as ET
from affine import Affine
def rect_to_win(rect):
    """Convert a VRT SrcRect/DstRect attribute mapping into a pixel window.

    ``rect`` maps 'xOff', 'yOff', 'xSize' and 'ySize' to numeric strings
    (typically an ElementTree ``attrib`` dict).

    Returns ((row_start, row_stop), (col_start, col_stop)) as ints.

    GDAL may serialise these attributes as non-integer numbers (e.g. "0.5"),
    which plain int() rejects with ValueError, so each value is parsed as a
    float and rounded half-up to the nearest whole pixel.  Integer strings
    give exactly the same result as int().
    """
    import math

    def _pixels(key):
        # floor(x + 0.5): half-up rounding that behaves the same on Py2/Py3
        return int(math.floor(float(rect[key]) + 0.5))

    x_off = _pixels('xOff')
    y_off = _pixels('yOff')
    return ((y_off, y_off + _pixels('ySize')),
            (x_off, x_off + _pixels('xSize')))
def _parse_source(elem, sourcetype):
    """Parse one SimpleSource/ComplexSource element into a plain dict.

    Always contains 'sourcetype'.  The other keys are present only when the
    matching child element exists: 'sourcefile' ({'filename', 'relative'}),
    'sourceband' (int), 'srcwin' / 'dstwin' (windows from rect_to_win) and
    'nodata' (float, from <NODATA>).
    """
    source = {'sourcetype': sourcetype}
    for child in elem:
        if child.tag == 'SourceFilename':
            source['sourcefile'] = {
                'filename': child.text,
                # relativeToVRT="1" means the path is relative to the VRT's directory
                'relative': child.get('relativeToVRT') == '1'}
        elif child.tag == 'SourceBand':
            source['sourceband'] = int(child.text)
        elif child.tag == 'SrcRect':
            source['srcwin'] = rect_to_win(child.attrib)
        elif child.tag == 'DstRect':
            source['dstwin'] = rect_to_win(child.attrib)
        elif child.tag == 'NODATA':
            source['nodata'] = float(child.text)
    return source


def _parse_band(elem):
    """Parse a VRTRasterBand element into {'dtype', 'sources'[, 'nodata']}."""
    from rasterio import dtypes
    band = {'sources': []}
    # NOTE(review): relies on rasterio.dtypes lookup tables keyed by the
    # lower-cased VRT dataType name -- confirm against the installed rasterio.
    typename = elem.attrib['dataType']
    band['dtype'] = dtypes.dtype_fwd[dtypes.dtype_rev[typename.lower()]]
    for child in elem:
        if child.tag == 'NoDataValue':
            band['nodata'] = float(child.text)
        elif child.tag in ('SimpleSource', 'ComplexSource'):
            # Both source kinds share the children parsed here; previously
            # SimpleSource tiles were silently dropped.
            band['sources'].append(_parse_source(child, child.tag))
        # Other children (ColorInterp, Histograms, ...) are ignored.
    return band


def parse_vrt(filename):
    """Parse a GDAL VRT XML file into a plain dict.

    Returned keys:
      'filename'       -- the path passed in
      'rootpath'       -- os.path.dirname(filename), used to resolve
                          relativeToVRT source paths
      'shape'          -- (rows, cols) as ints, from rasterYSize/rasterXSize
      'srs'            -- text of the <SRS> element ('' if absent)
      'transformation' -- Affine built from <GeoTransform> (None if absent)
      'bands'          -- list with at most one band dict
                          ('dtype', optional 'nodata', 'sources')

    Raises NotImplementedError if the VRT has more than one VRTRasterBand.
    """
    vrt = {"shape": None, "bands": [], "srs": "", "transformation": None,
           "rootpath": os.path.dirname(filename), "filename": filename}
    xmlroot = ET.parse(filename).getroot()
    # Store ints: shape feeds directly into pixel-window arithmetic.
    vrt["shape"] = (int(float(xmlroot.attrib['rasterYSize'])),
                    int(float(xmlroot.attrib['rasterXSize'])))
    for rootchild in xmlroot:
        if rootchild.tag == 'SRS':
            vrt["srs"] = rootchild.text
        elif rootchild.tag == 'GeoTransform':
            # GDAL order (c, a, b, f, d, e) -> Affine(a, b, c, d, e, f)
            params = [float(f) for f in rootchild.text.split(',')]
            vrt["transformation"] = Affine(params[1], params[2], params[0],
                                           params[4], params[5], params[3])
        elif rootchild.tag == 'VRTRasterBand':
            if vrt["bands"]:
                raise NotImplementedError('Only single band supported for now')
            vrt["bands"].append(_parse_band(rootchild))
    return vrt
In [79]:
class TiledRaster(object):
    """Tile-wise access to a mosaic described by a single-band GDAL VRT.

    A "tile" is one source dict produced by parse_vrt (with keys such as
    'sourcefile' and 'dstwin').  Windows are
    ((row_start, row_stop), (col_start, col_stop)) in pixel coordinates.
    """

    def __init__(self, vrtfile):
        self.vrt = parse_vrt(vrtfile)
        self.vrtfile = self.vrt['filename']
        # Coerce to int so buffer/window arithmetic yields integer indices.
        self.shape = tuple(int(s) for s in self.vrt['shape'])
        self.basepath = self.vrt['rootpath']

    def tiles(self, bandnumber=1):
        """Yield the source (tile) dicts of the given 1-based band."""
        for t in self.vrt['bands'][bandnumber - 1]['sources']:
            yield t

    def get_tile_file(self, tile):
        """Return the tile's source path, joined onto the VRT directory if relativeToVRT."""
        src = tile['sourcefile']
        filename = src['filename']
        if src['relative']:
            filename = os.path.join(self.basepath, filename)
        return filename

    def read_tile_buffered(self, tile, bufferpixels, masked=None, bandnumber=1):
        """Read a tile from the VRT plus up to ``bufferpixels`` of surrounding pixels.

        The buffer is clamped so the read stays inside the VRT extent.

        Returns (array, tile_data_window): tile_data_window is the window,
        relative to ``array``, that holds the tile's own (unbuffered) pixels.
        ``masked`` is passed through to rasterio's read.
        """
        bufferinfo = self._adjust_buffer(tile, bufferpixels)
        win = self._expand_window(tile, bufferinfo)
        tile_data_window = self._unbuffered_window(tile, bufferinfo)
        with rasterio.open(self.vrtfile, mode='r') as f:
            # read() replaces read_band(), which was removed in rasterio 1.0
            data = f.read(bandnumber, window=win, masked=masked)
        return data, tile_data_window

    def _adjust_buffer(self, tile, bufferpixels):
        """Clamp the buffer on each side so the buffered window stays within the VRT.

        Returns (left, top, right, bottom) in pixels.
        """
        (row_start, row_stop), (col_start, col_stop) = tile['dstwin']
        nrows, ncols = self.shape
        left = min(bufferpixels, col_start)
        top = min(bufferpixels, row_start)
        right = min(bufferpixels, ncols - col_stop)
        bottom = min(bufferpixels, nrows - row_stop)
        return (left, top, right, bottom)

    def _expand_window(self, tile, bufferinfo):
        """Grow the tile's dstwin by bufferinfo = (left, top, right, bottom)."""
        left, top, right, bottom = bufferinfo
        (row_start, row_stop), (col_start, col_stop) = tile['dstwin']
        return ((row_start - top, row_stop + bottom),
                (col_start - left, col_stop + right))

    def _unbuffered_window(self, tile, bufferinfo):
        """Locate the tile's own pixels inside the buffered array.

        The buffered array starts ``top`` rows and ``left`` columns before
        the tile, so the tile occupies [top, top + rows) x [left, left + cols).
        (Previously the stop indices were computed as size - trailing buffer,
        which cut the window short and misaligned it.)
        """
        left, top = bufferinfo[0], bufferinfo[1]
        (row_start, row_stop), (col_start, col_stop) = tile['dstwin']
        return ((top, top + (row_stop - row_start)),
                (left, left + (col_stop - col_start)))
In [81]:
# Demo: locate one tile by its source filename and inspect the buffer
# bookkeeping for a 10-pixel buffer.
tr = TiledRaster(vrtfile)
wanted = 'dsm_1km_6055_689.tif'
matches = [t for t in tr.tiles() if wanted in t['sourcefile']['filename']]
if not matches:
    # Clearer than the IndexError a bare [...][0] would raise.
    raise LookupError('no tile in %s has a source filename containing %r'
                      % (vrtfile, wanted))
tile = matches[0]
# print(x) with a single argument prints the same on Python 2 and 3.
print(tile)
bufferinfo = tr._adjust_buffer(tile, 10)
print(bufferinfo)
expanded_window = tr._expand_window(tile, bufferinfo)
print(expanded_window)
unbuf = tr._unbuffered_window(tile, bufferinfo)
print(unbuf)
In [82]:
# Read the selected tile with up to 10 pixels of buffer on each side (clamped
# at the mosaic edges); masked=True is passed through to rasterio's read.
# `win` is the window inside `data` that holds the tile's own pixels.
data, win = tr.read_tile_buffered(tile, 10, masked=True)
In [83]:
# Notebook display: the buffered array returned above.
data
Out[83]:
In [84]:
# Notebook display: window of the unbuffered tile data within `data`.
win
Out[84]:
In [ ]: