In [ ]:
import rasterio
import rasterio.features
import rasterio.warp
import geopyspark as gps
import numpy as np
import csv
import matplotlib.pyplot as plt

from datetime import datetime
from pyspark import SparkContext
from osgeo import osr

%matplotlib inline

In [ ]:
# Reuse the running context if this cell is executed again: constructing a
# second SparkContext in the same kernel raises an error.
sc = SparkContext.getOrCreate(
    conf=gps.geopyspark_conf(appName="Landsat").set("spark.ui.enabled", "true"))

In [ ]:
# One record per band file of a single Landsat 8 scene in the public
# landsat-pds S3 bucket (the URI path segments 107/035 are presumably the
# WRS path/row). 'date' is year + day-of-year, parsed later with '%Y%j'.
csv_data = [
    {
        'uri': 's3://landsat-pds/L8/107/035/{0}/{0}_B{1}.TIF'.format(scene_id, band),
        'scene_id': scene_id,
        'date': '2015218',
        'band': band,
    }
    for scene_id in ['LC81070352015218LGN00']
    for band in ['2', '3', '4', '5']
]

In [ ]:
# Distribute the per-band-file records (dicts) as an RDD.
rdd0 = sc.parallelize(csv_data)

In [ ]:
def get_metadata(line):
    """Expand one band-file record into one record per internal raster block.

    ``line`` is a dict with at least ``'uri'`` (anything ``rasterio.open``
    accepts), ``'date'`` (year + day-of-year string, e.g. ``'2015218'``) and
    ``'scene_id'`` keys. For every block window of the raster, a copy of
    ``line`` is produced with ``'date'`` and ``'scene_id'`` removed and these
    keys added:

    * ``'window'``: the block window, passed later to ``dataset.read``.
    * ``'projected_extent'``: a ``gps.TemporalProjectedExtent``.

    Returns a list of those records. Files that cannot be opened or inspected
    yield an empty list, so one bad scene does not abort the whole Spark job.
    """
    try:
        with rasterio.open(line['uri']) as dataset:
            bounds = dataset.bounds
            height = dataset.height
            width = dataset.width
            # `dataset.crs` exists across rasterio versions; the
            # `get_crs()` method was removed in rasterio 1.0.
            srs = osr.SpatialReference()
            srs.ImportFromWkt(dataset.crs.wkt)
            proj4 = srs.ExportToProj4()
            ws = [w for (_, w) in dataset.block_windows()]
    except Exception:
        # Best effort as before, but unlike a bare `except:` this does not
        # swallow KeyboardInterrupt / SystemExit.
        return []

    if not ws:
        return []

    # The acquisition date is identical for every block: parse it once.
    instant = datetime.strptime(line['date'], '%Y%j')
    x_span = bounds.right - bounds.left
    # bottom - top: row offsets are measured downward from the top edge.
    y_span = bounds.bottom - bounds.top

    records = []
    for w in ws:
        # rasterio >= 1.0 yields Window objects; older releases yield
        # ((row_start, row_stop), (col_start, col_stop)) tuples directly.
        ranges = w.toranges() if hasattr(w, 'toranges') else w
        ((row_start, row_stop), (col_start, col_stop)) = ranges

        # Linear interpolation of pixel offsets into the dataset bounds.
        left = bounds.left + x_span * (float(col_start) / width)
        right = bounds.left + x_span * (float(col_stop) / width)
        bottom = bounds.top + y_span * (float(row_stop) / height)
        top = bounds.top + y_span * (float(row_start) / height)
        extent = gps.Extent(left, bottom, right, top)

        new_line = line.copy()
        new_line.pop('date')
        new_line.pop('scene_id')
        new_line['window'] = w
        new_line['projected_extent'] = gps.TemporalProjectedExtent(
            extent=extent, instant=instant, proj4=proj4)
        records.append(new_line)

    return records

In [ ]:
# Expand every band file into one record per internal raster block (window).
rdd1 = rdd0.flatMap(get_metadata)
rdd1.first()

In [ ]:
def get_data(line):
    """Read the pixel block a record points at.

    Opens ``line['uri']``, reads band 1 over ``line['window']`` and returns a
    new dict holding every other key of ``line`` plus ``'data'`` (the array
    that was read). The ``'uri'`` and ``'window'`` keys are not carried over.
    ``line`` itself is not modified.
    """
    with rasterio.open(line['uri']) as src:
        block = src.read(1, window=line['window'])

    record = {key: value for key, value in line.items() if key not in ('window', 'uri')}
    record['data'] = block
    return record

In [ ]:
# Read the pixel data of every block window (one band per record).
rdd2 = rdd1.map(get_data)
rdd2.first()

In [ ]:
# Group the per-band block records that share the same spatio-temporal extent.
# NOTE(review): this yields one group per block holding all its bands only if
# every band file uses the same block layout — true for this scene, confirm
# for others.
rdd3 = rdd2.groupBy(lambda line: line['projected_extent'])
rdd3.first()

In [ ]:
def _band_sort_key(record):
    """Order band labels numerically ('2' < '10'); non-numeric labels sort last."""
    band = str(record['band'])
    return (0, int(band), '') if band.isdigit() else (1, 0, band)


def make_tiles(line, no_data_value=0):
    """Stack the per-band records of one extent into a multiband ``gps.Tile``.

    ``line`` is a ``(projected_extent, records)`` pair as produced by
    ``groupBy``. Each record carries a ``'band'`` label and a 2-D ``'data'``
    array. Bands are stacked in numeric band order, so Landsat 8 bands
    ``'10'``/``'11'`` land after ``'9'`` instead of before ``'2'`` as a plain
    string sort would place them.

    ``no_data_value`` is forwarded to ``gps.Tile.from_numpy_array``. The
    default is 0, the value used previously.

    Returns ``(projected_extent, tile)``.
    """
    projected_extent, records = line
    bands = sorted(records, key=_band_sort_key)
    array = np.array([record['data'] for record in bands])
    tile = gps.Tile.from_numpy_array(array, no_data_value=no_data_value)
    return (projected_extent, tile)

def interesting_tile(line):
    """Filter predicate: keep ``(extent, tile)`` pairs whose first band sums to non-zero.

    ``tile[0][0]`` is taken as the first band of the tile's cell array
    (``tile`` is indexed positionally, presumably ``gps.Tile.cells`` first).
    """
    _, tile = line
    first_band = tile[0][0]
    total = np.sum(first_band)
    return total != 0

def square_tile(line, size=512):
    """Filter predicate: keep ``(extent, tile)`` pairs whose first band is ``size`` x ``size``.

    Blocks along the right/bottom edge of a raster can be smaller than its
    internal block size; this drops them. ``size`` defaults to 512, the value
    previously hard-coded (presumably the block size of these GeoTIFFs).
    Use e.g. ``functools.partial(square_tile, size=256)`` for other layouts.
    """
    _, tile = line
    return tile[0][0].shape == (size, size)

In [ ]:
# Stack the bands of each block into a multiband tile, keeping only the
# blocks whose first band is 512x512.
rdd4 = rdd3.map(make_tiles).filter(square_tile)
# Pick one tile whose first band is not all zeros for a quick visual check.
data = rdd4.filter(interesting_tile).first()
data

In [ ]:
# data is (projected_extent, tile); tile[0][0] is the first band of its cells.
fig, ax = plt.subplots()
ax.imshow(data[1][0][0])
ax.set_title('First band of a sample non-empty tile');

In [ ]:
# Wrap the (TemporalProjectedExtent, Tile) pairs in a space-time RasterLayer.
raster_layer = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPACETIME, rdd4)

In [ ]:
# Reproject to Web Mercator (EPSG:3857) and cut into a GlobalLayout tile grid.
tiled_raster_layer = raster_layer.tile_to_layout(layout = gps.GlobalLayout(), target_crs=3857)

In [ ]:
# Build the zoom-level pyramid of the space-time layer.
pyramid = tiled_raster_layer.pyramid()

In [ ]:
# Save every pyramid level to the local catalog under the layer name "landsat",
# with time indexed at day resolution.
for layer in pyramid.levels.values():
    gps.write("file:///tmp/catalog/", "landsat", layer, time_unit=gps.TimeUnit.DAYS)

Display (Optional)


In [ ]:
# Drop the time dimension and build a pyramid of the spatial layer for display.
# NOTE(review): this rebinds `pyramid`, replacing the space-time pyramid.
pyramid = tiled_raster_layer.to_spatial_layer().pyramid()

In [ ]:
# Save every level of the spatial pyramid under the layer name "landsat-spatial".
for layer in pyramid.levels.values():
    gps.write("file:///tmp/catalog/", "landsat-spatial", layer)

In [ ]:
from PIL import Image

def render_tile(tile):
    """Render a multiband tile as an RGBA PIL image for the TMS server.

    ``tile[0]`` is presumably the ``(bands, rows, cols)`` cell array. All
    bands are scaled together by the tile's maximum value into 0-255.

    * Stacked bands 2, 1 and 0 become the R, G and B channels. With the bands
      loaded in this notebook (B2-B5 in band order) that is presumably a
      B4/B3/B2 true-colour composite; confirm if the band list changes.
    * Alpha is opaque wherever the raw first band is non-zero, matching the
      ``no_data_value=0`` used when the tiles were built.

    An all-zero tile renders as fully transparent instead of dividing by zero.
    """
    cells = tile[0]
    peak = cells.max()
    if peak > 0:
        norm = np.uint8(cells / peak * 255)
    else:
        # Avoid 0/0 -> NaN, whose uint8 cast is undefined.
        norm = np.zeros(cells.shape, dtype=np.uint8)
    # Mask on raw values: normalisation can round small but valid values to 0,
    # which previously made those pixels transparent.
    mask = np.uint8((cells[0] != 0) * 255)
    return Image.fromarray(np.dstack([norm[2], norm[1], norm[0], mask]), mode='RGBA')

In [ ]:
# Serve tiles from the "landsat-spatial" catalog layer, rendered by render_tile.
# NOTE(review): binding to 0.0.0.0 exposes the server on every network
# interface; prefer '127.0.0.1' unless the browser runs on another host.
tms_server = gps.TMS.build(("file:///tmp/catalog/", "landsat-spatial"), display=render_tile)
tms_server.bind('0.0.0.0')

In [ ]:
import folium

# Stamen's public tile servers were retired (tiles moved to Stadia Maps, which
# requires an API key), so use OpenStreetMap as the base layer instead.
m = folium.Map(tiles='OpenStreetMap', location=[35.6, 140.1], zoom_start=5)
# Overlay the GeoPySpark TMS layer served above.
folium.TileLayer(tiles=tms_server.url_pattern, attr='GeoPySpark').add_to(m)
m

In [ ]: