SRTM vegetation removal and filtering with a wavelet filter

In this notebook, we will attempt to build a generic filter for SRTM data, in order to make it more useful for hydraulic modelling in vegetation-rich areas. Particularly interesting areas are forested floodplains and mangrove-rich coastlines.

For vegetation removal, we use the approach suggested by Baugh et al. (2013) Wat. Resour. Res. They use the vegetation height map generated by Simard et al. (2011) J. Geophys. Res. to correct the SRTM elevations, and have established that the optimal fraction of the total vegetation height to subtract is 0.6.

The second part considers the reduction of random errors. Baugh et al. use a low-pass filter. Here, we use a wavelet filter as suggested by Sanders (2007). Wavelet filters operate in the spectral domain and therefore remove small-scale variability rather than large-scale features.

In this notebook, we enable the user to change the canopy height fraction as well as the parameter of the filter (the threshold that determines how many wavelet coefficients are removed).

First import packages

In [1]:
from osgeo import gdal, osr
import os
import numpy as np
import pywt
from gdal_readmap import gdal_readmap

from IPython.html.widgets import interact, interactive, fixed
from IPython.html import widgets
from IPython.display import clear_output, display, HTML

Make a function that reprojects data to the same grid format with GDAL


In [2]:
def reproject_dataset(dataset, xul, yul, xlr, ylr, pixel_spacing=1./1200,
                      epsg_from=4326, epsg_to=4326, driver='MEM', file_name=''):
    """
    Reproject and resample a GDAL raster onto a user-defined target grid.

    The target grid is fully defined by the caller: its upper-left corner,
    its lower-right corner and a (square) pixel size. The source raster is
    warped onto that grid with bilinear interpolation and stored as a
    single Float32 band.

    Parameters
    ----------
    dataset : str
        Path of the source raster, passed to ``gdal.Open``.
    xul, yul : float
        Upper-left corner of the target grid, in coordinates of ``epsg_to``.
    xlr, ylr : float
        Lower-right corner of the target grid, in coordinates of ``epsg_to``.
    pixel_spacing : float
        Pixel size of the target grid, in units of ``epsg_to``.
    epsg_from, epsg_to : int
        EPSG codes of the source and the target spatial reference system.
    driver : str
        Short name of the GDAL driver used to create the target raster.
    file_name : str
        Name handed to the driver's ``Create`` (output path for file-based
        drivers, may stay empty for ``'MEM'``).

    Returns
    -------
    The in-memory ``gdal.Dataset`` when ``driver`` is ``'MEM'``. For every
    other driver the dataset is flushed and closed, so that the file on
    disk is complete, and ``None`` is returned.

    Raises
    ------
    IOError
        If the source cannot be opened or the target cannot be created.
    ValueError
        If ``driver`` is not a known GDAL driver.
    RuntimeError
        If GDAL reports a failure while reprojecting.
    """
    # Target and source spatial reference systems
    srs_to = osr.SpatialReference()
    srs_to.ImportFromEPSG(epsg_to)
    srs_from = osr.SpatialReference()
    srs_from.ImportFromEPSG(epsg_from)

    # gdal.Open returns None (instead of raising) when exceptions are disabled
    src = gdal.Open(dataset)
    if src is None:
        raise IOError('Could not open dataset %r' % (dataset,))
    geo_t = src.GetGeoTransform()

    drv = gdal.GetDriverByName(driver)
    if drv is None:
        raise ValueError('Unknown GDAL driver %r' % (driver,))

    # Raster size follows from the requested bounding box and pixel spacing
    n_cols = int(np.round((xlr - xul) / pixel_spacing))
    n_rows = int(np.round((yul - ylr) / pixel_spacing))
    dest = drv.Create(file_name, n_cols, n_rows, 1, gdal.GDT_Float32)
    if dest is None:
        raise IOError('Could not create target raster %r with driver %r'
                      % (file_name, driver))

    # North-up grid anchored at the requested upper-left corner; the rotation
    # terms are copied from the source geotransform, as before.
    dest.SetGeoTransform((xul, pixel_spacing, geo_t[2],
                          yul, geo_t[4], -pixel_spacing))
    dest.SetProjection(srs_to.ExportToWkt())

    # Perform the projection/resampling
    res = gdal.ReprojectImage(src, dest,
                              srs_from.ExportToWkt(), srs_to.ExportToWkt(),
                              gdal.GRA_Bilinear)
    if res:
        raise RuntimeError('gdal.ReprojectImage failed with error code %s'
                           % (res,))
    dest.FlushCache()

    if driver.upper() == 'MEM':
        # An in-memory raster only lives as long as the object: hand it back
        return dest
    # File-based output: release the handle so GDAL finalises the file
    dest = None
    return None

def writeMap(fileName, fileFormat, x, y, data, FillVal, verbose=False):
    """
    Write a 2-D array with its georeference to a raster file.

    The array is first written to a temporary GeoTIFF (``fileName + '.tif'``)
    and then copied to ``fileName`` with the driver for ``fileFormat``
    (``CreateCopy``). The temporary file is always removed afterwards, also
    when the copy fails.

    Parameters
    ----------
    fileName : str
        Path of the output file.
    fileFormat : str
        Short name of the GDAL driver used for the output file.
    x, y : sequence
        Column and row coordinates. ``x[0]`` / ``y[0]`` are treated as the
        centre of the first cell and ``x[1] - x[0]`` / ``y[1] - y[0]`` as the
        cell size, so both need at least two entries.
    data : array
        2-D array (rows, columns); stored as Float32.
    FillVal : float
        Value registered as no-data value of the band.
    verbose : bool
        Print progress messages when True.

    Raises
    ------
    ValueError
        If the GTiff driver or the ``fileFormat`` driver is not available.
    IOError
        If the temporary or the final dataset cannot be created.
    """
    gdal.AllRegister()
    tiff_driver = gdal.GetDriverByName('GTiff')
    out_driver = gdal.GetDriverByName(fileFormat)
    if tiff_driver is None:
        raise ValueError('GDAL driver GTiff is not available')
    if out_driver is None:
        raise ValueError('Unknown GDAL driver %r' % (fileFormat,))

    tmp_name = fileName + '.tif'
    # Cell size and upper-left corner (coordinates are cell centres, so the
    # corner lies half a cell outwards). Dividing by 2. avoids integer
    # division for integer coordinates under Python 2.
    dx = x[1] - x[0]
    dy = y[1] - y[0]
    xul = x[0] - dx / 2.
    yul = y[0] - dy / 2.

    if verbose:
        print('Writing to temporary file ' + tmp_name)
    tmp_ds = tiff_driver.Create(tmp_name, data.shape[1], data.shape[0], 1,
                                gdal.GDT_Float32)
    if tmp_ds is None:
        raise IOError('Could not create temporary file ' + tmp_name)

    band = None
    try:
        tmp_ds.SetGeoTransform([xul, dx, 0, yul, 0, dy])
        band = tmp_ds.GetRasterBand(1)
        band.SetNoDataValue(FillVal)
        band.WriteArray(data, 0, 0)
        band.FlushCache()
        # Copy to the requested format (only CreateCopy is needed for this)
        if verbose:
            print('Writing to ' + fileName)
        out_ds = out_driver.CreateCopy(fileName, tmp_ds, 0)
        if out_ds is None:
            raise IOError('Could not write ' + fileName)
        out_ds = None  # close, so the output is flushed to disk
    finally:
        # Release the handles before deleting the temporary file
        band = None
        tmp_ds = None
        if verbose:
            print('Removing temporary file ' + tmp_name)
        if os.path.exists(tmp_name):
            os.remove(tmp_name)

    if verbose:
        print('Writing to ' + fileName + ' is done!')

Define source and target locations


In [17]:
# Inputs are read from and results are written to the same project folder
src_folder = dst_folder = r'd:\projects\Servia'

# SRTM elevation model
dem_file = 'Servia.tif'
# Vegetation height map: original file and its version resampled to the DEM grid
veg_file = 'Servia_forest_latlon.tif'
veg_file_trg = 'Servia_forest_interp.tif'

Open the data and project vegetation to DEM with bilinear interpolation


In [19]:
# Full paths of the inputs and of the resampled vegetation map
veg_path = os.path.join(src_folder, veg_file)
dem_path = os.path.join(src_folder, dem_file)
veg_proj = os.path.join(dst_folder, veg_file_trg)

# Resample the vegetation map to a 1/1200 degree lat-lon grid covering
# 19-21 (x) by 44-45 (y) and store it as GeoTIFF
reproject_dataset(
    veg_path,
    xul=19., yul=45.,
    xlr=21, ylr=44.,
    pixel_spacing=1./1200,
    epsg_from=4326,
    epsg_to=4326,
    driver='GTiff',
    file_name=veg_proj)

In [20]:
# Echo the resolved vegetation file path as a quick check
veg_path


Out[20]:
'd:\\projects\\Servia\\Servia_forest_latlon.tif'

Now load all data back into memory


In [21]:
# Read the DEM; gdal_readmap returns the x and y axes, the data and the fill value
x, y, dem, fill_value = gdal_readmap(dem_path, 'GTiff')
# Cast to float32 so that missing cells can be represented as NaN
dem = np.float32(dem)
dem[dem==fill_value] = np.nan
# Read the resampled vegetation map. NOTE: this rebinds x, y and fill_value;
# presumably both rasters share the same grid -- verify when changing inputs
x, y, veg, fill_value = gdal_readmap(veg_proj, 'GTiff')
# Cells without vegetation data are treated as "no vegetation"
veg[veg==fill_value] = 0

make plot


In [22]:
# DEM (left) and vegetation height (right), side by side
f = plt.figure(figsize=(15, 7))
p1 = plt.subplot(1, 2, 1)
plt.imshow(dem)
p2 = plt.subplot(1, 2, 2)
plt.imshow(veg)


Out[22]:
<matplotlib.image.AxesImage at 0xe755f98>

Now make a filter function that removes vegetation and noise, and plot again


In [35]:
def _soft_threshold(values, threshold):
    """Soft-threshold an array of wavelet coefficients.

    Uses ``pywt.threshold`` when this PyWavelets version provides it and
    falls back to the older ``pywt.thresholding.soft`` otherwise.
    """
    if hasattr(pywt, 'threshold'):
        return pywt.threshold(values, threshold, 'soft')
    return pywt.thresholding.soft(values, threshold)


def _wavelet_filter_block(block, filter_name, filter_weight, level=5):
    """Decompose ``block``, soft-threshold all coefficients and reconstruct."""
    coeffs = pywt.wavedec2(block, filter_name, level=level)
    # coeffs[0] is the approximation array, every further entry is a tuple of
    # detail arrays. NOTE(review): the approximation is thresholded as well,
    # exactly as before -- confirm that this is intended.
    shrunk = [_soft_threshold(coeffs[0], filter_weight)]
    for details in coeffs[1:]:
        shrunk.append(tuple(_soft_threshold(band, filter_weight)
                            for band in details))
    restored = pywt.waverec2(shrunk, filter_name)
    # The reconstruction can be larger than the input (e.g. odd sizes):
    # crop it back to the block shape
    return restored[:block.shape[0], :block.shape[1]]


def plot_dem_filter(dem=[], forest=[], forest_weight=0.6, filter_name='bior6.8', filter_weight=10, extent=None, transect=None):
    """
    Remove vegetation and noise from a DEM, plot the result and write it.

    A fraction of the vegetation height is subtracted from the DEM (clipped
    at zero), after which the DEM is filtered per block of at most 3000 x
    3000 cells with a level-5 wavelet decomposition and soft thresholding.
    The original and the filtered DEM are plotted and the filtered DEM is
    written to 'Servia_filter.tif' with -9999 as missing value.

    The input arrays are not modified. The function relies on the notebook
    globals ``plt``, ``pywt``, ``x``, ``y`` and ``writeMap``.

    Parameters
    ----------
    dem : 2-D array
        Elevation; NaN marks missing values.
    forest : 2-D array
        Vegetation height on the same grid as ``dem``.
    forest_weight : float
        Fraction of the vegetation height that is subtracted.
    filter_name : str
        Wavelet name passed to PyWavelets.
    filter_weight : float
        Soft threshold applied to the wavelet coefficients.
    extent : tuple or None
        Optional window (col_start, col_end, row_start, row_end), given as
        array indices, to which the processing is restricted.
    transect : tuple or None
        Optional (row_start, col_start, row_end, col_end), given as array
        indices in the (cropped) DEM, along which profiles of the original,
        vegetation-corrected and filtered DEM are plotted.

    Raises
    ------
    ValueError
        If ``dem`` is not a non-empty 2-D array or ``forest`` has another
        shape.
    """
    block_rows = 3000
    block_cols = 3000

    dem = np.asarray(dem)
    forest = np.asarray(forest)
    if dem.ndim != 2 or forest.shape != dem.shape:
        raise ValueError('dem and forest must be 2-D arrays of equal shape')

    # Coordinates belonging to the written raster (notebook globals x and y)
    x_out, y_out = x, y
    if extent:
        col0, col1, row0, row1 = extent
        dem = dem[row0:row1, col0:col1]
        forest = forest[row0:row1, col0:col1]
        # keep the georeference consistent with the cropped window
        x_out = x[col0:col1]
        y_out = y[row0:row1]
    if dem.size == 0:
        raise ValueError('dem (after applying extent) contains no data')

    # Work on a copy so that the caller keeps its NaN values; the mask is
    # taken after cropping so that it matches the filtered array.
    dem = np.array(dem)
    missing = np.isnan(dem)
    dem[missing] = 0.

    # first correct for forest
    dem_forest = np.maximum(dem - forest*forest_weight, 0)

    # if the DEM is bigger than a certain area, chop the procedure in pieces
    filter_weight = np.float64(filter_weight)
    dem_new = np.zeros(dem.shape)
    rows, cols = dem.shape
    for i in range(0, rows, block_rows):
        i2 = min(i + block_rows, rows)
        for j in range(0, cols, block_cols):
            j2 = min(j + block_cols, cols)
            print('Filtering data-block y: %g -- %g; x: %g -- %g' % (i, i2, j, j2))
            # level taken from paper
            dem_new[i:i2, j:j2] = _wavelet_filter_block(
                dem_forest[i:i2, j:j2], filter_name, filter_weight, level=5)

    # now bring back missing values
    dem_new[missing] = np.nan
    # NaN-aware limits, so that both panels share one valid colour scale
    max_val = np.maximum(np.nanmax(dem), np.nanmax(dem_new))
    min_val = np.minimum(np.nanmin(dem), np.nanmin(dem_new))

    # and plot: original (top left) and filtered DEM (top right)
    plt.figure(figsize=(13, 7))
    for position, panel in ((221, dem), (222, dem_new)):
        plt.subplot(position)
        plt.imshow(panel, vmin=min_val, vmax=max_val)
        if transect:
            plt.plot([transect[1], transect[3]], [transect[0], transect[2]], color='k')
            plt.xlim([0, dem.shape[1]])
            plt.ylim([dem.shape[0], 0])
    if transect:
        # profiles along the transect, sampled at 100 points
        plt.subplot(223)
        xc = np.round(np.linspace(transect[1], transect[3], 100)).astype(int)
        yc = np.round(np.linspace(transect[0], transect[2], 100)).astype(int)
        plt.plot(dem[yc, xc], color='b')
        plt.plot(dem_forest[yc, xc], color='g')
        plt.plot(dem_new[yc, xc], color='r')
        plt.xlabel('Transect [-]')
        plt.ylabel('elevation [m]')

    # Write a separate array, so that the plotted data keep their NaNs
    dem_out = np.where(missing, -9999., dem_new)
    writeMap('Servia_filter.tif', 'GTiff', x_out, y_out, dem_out, -9999.)

Plot over whole area


In [36]:
# Filter the complete DEM with the default canopy fraction and a threshold of 4
plot_dem_filter(dem=dem, forest=veg,
                filter_name='bior6.8',
                filter_weight=4)


Filtering data-block y: 0 -- 1200; x: 0 -- 2400

In [32]:
# Interactive version on a sub-window: sliders for the canopy fraction and
# the filter threshold, all other arguments are fixed. The keyword has to
# match the function's parameter name (forest_weight) for the slider range
# to take effect.
interactive(plot_dem_filter,
            dem=fixed(dem),
            forest=fixed(veg),
            forest_weight=(0.1, 0.9),
            filter_name=fixed('bior6.8'),
            filter_weight=(0, 30),
            extent=fixed((1000, 1700, 300, 800)),
            transect=fixed((400, 200, 200, 400)))


Filtering data-block y: 0 -- 500; x: 0 -- 700

The right figure shows much better behaviour than the left (original SRTM) figure. However, some uncorrected areas can still be found along the banks. This is because the vegetation map has a resolution of only 1 km and is resampled to match the 90 metre resolution of SRTM. Even with bilinear interpolation, some pixels with too high an elevation remain along the river banks. Nevertheless, the correction is very pragmatic and important to apply. Further study of coastal areas would be interesting, to see whether mangrove-rich areas can be corrected as well; this is a particular problem for coastal storm surge modelling.


In [9]: