WASARD is a general-purpose transfer model between optical and SAR imagery for water classification.
A trained WASARD model can be executed on SAR imagery to create water classification maps over a region.
This notebook is inspired by the IGARSS publication "Water Across Synthetic Aperture Radar Data (WASARD): SAR Water Body Classification for the Open Data Cube" by Zachary Kreiser, Brian Killough, and Syed R. Rizvi.
WASARD is trained using water classifications of optical imagery as a point of reference. A machine learning model is used to approximate a transfer function between SAR data and optical water classifications.
Details of the WASARD setup used in this notebook:
Transfer Model:
Linear SVM
Optical water classifier:
WOFS
Optical Source:
Landsat 8
SAR Target:
Sentinel-1A
Sentinel-1 and Landsat 8 Imagery
Sentinel-1 imagery is loaded. Landsat 8 imagery is loaded, reprojected to the Sentinel-1 CRS, and upsampled to match the resolution of Sentinel-1.
In [ ]:
import datacube
# Single Data Cube handle used by every dc.load / dc.list_products call in this notebook.
dc = datacube.Datacube(app="[notebook][wasard][samoa]")
In [ ]:
# Products to pair: the SAR target (Sentinel-1) and the optical reference (Landsat 8).
sar_product_name = "s1g_gamma0"
optical_product_name = "ls8_usgs_sr_scene"

# Alternative area of interest: Sa'anapu (uncomment to use instead of Apia).
# longitude_extent = (-171.904492, -171.790327)
# latitude_extent = (-14.0, -13.962341)

# Area of interest: Apia.
latitude_extent = (-13.853425, -13.815715)
longitude_extent = (-171.787842, -171.681356)

# Acquisition window shared by both products.
date_range = ("2016-8-1", "2018-3-1")
In [ ]:
from utils.data_cube_utilities.dc_display_map import display_map
# Preview the selected bounding box on a map.
display_map(latitude=latitude_extent, longitude=longitude_extent)
In [ ]:
# Load each product with an empty measurement list: only the coordinate
# grids are used afterwards, to compare how the two products are gridded.
sentinel_coordinates = dc.load(
    product=sar_product_name,
    latitude=latitude_extent,
    longitude=longitude_extent,
    time=date_range,
    measurements=[],
)
landsat_coordinates = dc.load(
    product=optical_product_name,
    latitude=latitude_extent,
    longitude=longitude_extent,
    time=date_range,
    measurements=[],
)
In [ ]:
import utils.data_cube_utilities.xarray_bokeh_plotting as xr_bokeh
# Enable inline Bokeh output for the alignment plots that follow.
xr_bokeh.init_notebook()
In [ ]:
# Compare the native pixel grids of both products over a 100 x 100 corner
# of the area (via xr_bokeh.dim_alignement).
xr_bokeh.dim_alignement(
    sentinel_coordinates.isel(latitude=slice(0, 100), longitude=slice(0, 100)),
    " Sentinel 1",
    landsat_coordinates.isel(latitude=slice(0, 100), longitude=slice(0, 100)),
    "Landsat 8",
)
Sentinel-1 CRS
In [ ]:
# Fetch the product table once, then select the SAR product by exact name.
# Exact equality (rather than `str.contains`, which is a regex substring
# match) ensures no similarly named product is selected. This matters
# because only the first matching row is used to read the CRS/resolution.
all_products = dc.list_products()
sentinel_details = all_products[all_products["name"] == sar_product_name]
sentinel_details
In [ ]:
# CRS and resolution of the SAR product; the Landsat load reuses them as its
# output grid so both datasets share one grid.
sentinel_crs = str(sentinel_details["crs"].iloc[0])
sentinel_resolution = tuple(sentinel_details["resolution"].iloc[0])
In [ ]:
print(sentinel_crs)  # show the CRS that Landsat will be reprojected to
In [ ]:
# Load all SAR measurements over the area and period on the product's native grid.
sentinel_dataset = dc.load(
    product=sar_product_name,
    latitude=latitude_extent,
    longitude=longitude_extent,
    time=date_range,
)
In [ ]:
# Load Landsat 8 onto the SAR product's CRS and resolution so both datasets
# share the same pixel grid.
landsat_dataset = dc.load(
    product=optical_product_name,
    latitude=latitude_extent,
    longitude=longitude_extent,
    time=date_range,
    output_crs=sentinel_crs,
    resolution=sentinel_resolution,
)
In [ ]:
# Coordinates of a 100 x 100 corner of the first Landsat acquisition.
subset_of_landsat_coords = landsat_dataset.isel(
    latitude=slice(0, 100), longitude=slice(0, 100), time=0
).coords
In [ ]:
# Coordinates of a 100 x 100 corner of the first Sentinel acquisition.
subset_of_sentinel_coords = sentinel_dataset.isel(
    latitude=slice(0, 100), longitude=slice(0, 100), time=0
).coords
In [ ]:
# Re-check the grids now that Landsat was loaded on the Sentinel CRS/resolution.
xr_bokeh.dim_alignement(
    subset_of_sentinel_coords, "S1",
    subset_of_landsat_coords, "LS",
)
In [ ]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
def remove_all_zero(dataset):
    """Return ``dataset`` without the time steps whose ``vv`` band is entirely zero.

    A time step is kept when at least one ``vv`` value is non-zero. NaN
    compares unequal to 0, so (exactly as with ``np.count_nonzero``) NaN
    pixels count as non-zero and an all-NaN step is kept.

    Steps are selected by position, so duplicated timestamps are judged
    individually rather than being dropped together by label.

    Args:
        dataset: Dataset with a ``time`` dimension and a ``vv`` variable.

    Returns:
        The dataset restricted to the time steps that carry some signal.
    """
    # Collapse every dimension of vv except time, leaving one flag per step.
    other_dims = [dim for dim in dataset.vv.dims if dim != "time"]
    has_signal = (dataset.vv != 0).any(dim=other_dims)
    # One vectorised test replaces a .sel() per time step, and positional
    # isel avoids the deprecated Dataset.drop(labels, dim=...) form.
    return dataset.isel(time=np.flatnonzero(has_signal.values))
In [ ]:
from typing import List
import itertools
def has_time_dimension(x):
    """Return True if ``x`` has a dimension named ``"time"``.

    ``x.dims`` is a name->size mapping on a Dataset and a tuple of names on a
    DataArray; membership testing works for both. (Building ``dict(x.dims)``
    fails on the tuple form.)
    """
    return "time" in x.dims


def get_first(x):
    """Return the first element of the sequence ``x``."""
    return x[0]


def get_last(x):
    """Return the last element of the sequence ``x``."""
    return x[-1]
def group_dates_by_day(dates: List[np.datetime64]) -> List[List[np.datetime64]]:
    """Group consecutive timestamps that fall on the same calendar day.

    ``itertools.groupby`` only merges *adjacent* items, so ``dates`` is
    expected to be in chronological order (presumably true for
    ``ds.time.values`` from a Data Cube load -- confirm for other callers).

    Args:
        dates: Iterable of ``np.datetime64`` timestamps.

    Returns:
        One list per run of same-day timestamps, preserving input order.
    """
    # Casting to day precision floors toward the past, so pre-1970 timestamps
    # land on the right day (float day counts truncated with astype(int)
    # round toward zero instead). It also avoids parsing a timezone-suffixed
    # epoch literal, which NumPy deprecates.
    def day_of(timestamp):
        return timestamp.astype("datetime64[D]")

    return [list(group) for _, group in itertools.groupby(dates, key=day_of)]
def reduce_on_day(ds: xr.Dataset,
                  reduction_func: np.ufunc = np.nanmean) -> xr.Dataset:
    """Merge all time steps that fall on the same calendar day into one.

    Args:
        ds: Dataset with a ``time`` dimension. Same-day steps must be adjacent
            (i.e. ``ds`` sorted in time), because grouping is delegated to
            ``group_dates_by_day``.
        reduction_func: Function applied along ``time`` to each day's steps
            through ``Dataset.reduce``. Defaults to ``np.nanmean``, which is a
            plain function rather than a true ``np.ufunc`` despite the
            annotation; any callable accepted by ``reduce`` works.

    Returns:
        Dataset with one time step per day, labelled with the first timestamp
        of that day's group. A dataset with no time steps is returned as-is.
    """
    day_groups = group_dates_by_day(ds.time.values)
    if not day_groups:
        # xr.concat cannot concatenate an empty sequence.
        return ds
    # Selecting with a list keeps the time dimension even for single-step
    # days, so every group yields exactly one reduced slice. The slices
    # therefore line up one-to-one with the labels assigned below.
    daily_slices = [
        ds.sel(time=day).reduce(reduction_func, dim="time") for day in day_groups
    ]
    new_dataset = xr.concat(daily_slices, dim="time")
    # The reduced slices carry no time coordinate; attach one explicitly,
    # using the first timestamp of each day group.
    new_times = np.array(list(map(get_first, day_groups)))
    return new_dataset.assign_coords(time=new_times)
In [ ]:
# Drop SAR acquisitions with an all-zero vv band, then merge acquisitions
# that share a calendar day into a single time step.
sentinel_dataset = reduce_on_day(remove_all_zero(sentinel_dataset))
In [ ]:
# Compare the time axes of both datasets after the SAR cleanup, using only
# the first pixel (latitude index 0, longitude index 0) of each.
subset_of_landsat_coords = landsat_dataset.isel(latitude=0, longitude=0).coords
subset_of_sentinel_coords = sentinel_dataset.isel(latitude=0, longitude=0).coords
xr_bokeh.dim_alignement(
    subset_of_sentinel_coords, "S1",
    subset_of_landsat_coords, "LS",
)
In [ ]:
from utils.data_cube_utilities.dc_mosaic import ls8_unpack_qa
In [ ]:
from utils.data_cube_utilities.dc_mosaic import ls8_unpack_qa
from utils.data_cube_utilities.dc_mosaic import create_median_mosaic
import xarray as xr
import numpy as np
def clean_mask_ls8(ds: xr.Dataset) -> "xr.DataArray | np.ndarray":
    """Return a boolean mask that is True where a Landsat 8 pixel is usable.

    A pixel counts as usable when ``ls8_unpack_qa`` flags its ``pixel_qa``
    value as either "clear" or "water".

    Args:
        ds: Landsat 8 dataset containing a ``pixel_qa`` variable.

    Returns:
        ``np.logical_or`` of the two flag masks. Its concrete type follows
        whatever ``ls8_unpack_qa`` returns (an xarray object or a NumPy
        array -- not verifiable from here).
    """
    is_clear = ls8_unpack_qa(ds.pixel_qa, "clear")
    is_water = ls8_unpack_qa(ds.pixel_qa, "water")
    # Water pixels are OR-ed in so they are kept as usable; presumably
    # pixel_qa does not also flag them "clear" -- verify against the QA bits.
    return np.logical_or(is_clear, is_water)
def median_mosaic_ls8(dataset):
    """Build a median mosaic of a Landsat 8 dataset with ``create_median_mosaic``.

    Pixels are filtered with ``clean_mask_ls8``, which is derived from the
    ``pixel_qa`` band bundled with most Landsat products.
    """
    usable_pixels = clean_mask_ls8(dataset)
    return create_median_mosaic(dataset, clean_mask=usable_pixels)
In [ ]:
# Median mosaic over the first 20 Landsat acquisitions only.
mosaic = median_mosaic_ls8(landsat_dataset.isel(time=slice(0, 20)))
In [ ]:
from utils.data_cube_utilities.dc_rgb import rgb
# True-colour rendering of the mosaic.
rgb(mosaic, bands=["red", "green", "blue"], width=15)
In [ ]:
from utils.data_cube_utilities.dc_water_classifier import wofs_classify
In [ ]:
# Optical water classification (WOfS) on the Landsat data, masked to usable
# pixels; masked-out pixels are filled with NaN.
water_classifications = wofs_classify(
    landsat_dataset,
    clean_mask=clean_mask_ls8(landsat_dataset),
    no_data=np.nan,
)
In [ ]:
%matplotlib inline
def aspect_ratio_helper(x, y, fixed_width=20):
    """Return a ``(width, height)`` figure size preserving the ``x`` : ``y`` ratio.

    The width is pinned to ``fixed_width`` and the height is scaled by the
    same factor that maps ``x`` onto that width.
    """
    scale = fixed_width / x
    return fixed_width, y * scale
In [ ]:
import matplotlib.pyplot as plt
# Size the figure from the spatial part of the wofs shape (everything after
# the leading time axis), passed to aspect_ratio_helper in reverse order.
plt.figure(figsize=aspect_ratio_helper(*water_classifications.wofs.shape[1:][::-1]))
water_classifications.mean(dim="time").wofs.plot(cmap="jet_r")
In [ ]:
from utils.data_cube_utilities import wasard
In [ ]:
# Train WASARD: SAR inputs, with the Landsat data as the optical reference.
samoa_classifier = wasard.wasard_classifier(
    sar_dataset=sentinel_dataset,
    landsat_dataset=landsat_dataset,
)
In [ ]:
# Apply the trained model to the SAR data.
samoa_classified = samoa_classifier.wasard_classify(sentinel_dataset)
In [ ]:
# Size the figure from the map actually being plotted (the time-averaged
# WASARD output) rather than from the WOfS grid, so the aspect ratio stays
# right even if the two grids differ.
# NOTE(review): assumes the mean is 2-D with shape (rows, cols) -- confirm.
wasard_mean = samoa_classified.wasard.mean(dim="time")
plt.figure(figsize=aspect_ratio_helper(*reversed(wasard_mean.shape)))
wasard_mean.plot(cmap="jet_r")