In [ ]:
    
import rasterio
import rasterio.features
import rasterio.warp
import geopyspark as gps
import numpy as np
import csv
import matplotlib.pyplot as plt
from datetime import datetime
from pyspark import SparkContext
from osgeo import osr
%matplotlib inline
    
In [ ]:
    
# Start the Spark context from GeoPySpark's configuration (application name
# "Landsat"), with the "spark.ui.enabled" setting switched on.
sc = SparkContext(conf=gps.geopyspark_conf(appName="Landsat").set("spark.ui.enabled",True))
    
In [ ]:
    
# One record per band file (bands 2-5) of scene LC81070352015218LGN00.
# 'date' is a year + day-of-year string; 'band' is kept as a string.
csv_data = [
    {
        'uri': 's3://landsat-pds/L8/107/035/LC81070352015218LGN00/LC81070352015218LGN00_B2.TIF',
        'scene_id': 'LC81070352015218LGN00',
        'date': '2015218',
        'band': '2',
    },
    {
        'uri': 's3://landsat-pds/L8/107/035/LC81070352015218LGN00/LC81070352015218LGN00_B3.TIF',
        'scene_id': 'LC81070352015218LGN00',
        'date': '2015218',
        'band': '3',
    },
    {
        'uri': 's3://landsat-pds/L8/107/035/LC81070352015218LGN00/LC81070352015218LGN00_B4.TIF',
        'scene_id': 'LC81070352015218LGN00',
        'date': '2015218',
        'band': '4',
    },
    {
        'uri': 's3://landsat-pds/L8/107/035/LC81070352015218LGN00/LC81070352015218LGN00_B5.TIF',
        'scene_id': 'LC81070352015218LGN00',
        'date': '2015218',
        'band': '5',
    },
]
    
In [ ]:
    
# Turn the list of band records into an RDD so the following steps run on Spark.
rdd0 = sc.parallelize(csv_data)
    
In [ ]:
    
def get_metadata(line):
    """Expand one band record into one record per raster block window.

    Opens the raster at ``line['uri']`` with rasterio, reads its bounds, size,
    CRS and block windows, and emits a copy of ``line`` for every window.

    Args:
        line: dict with at least the keys ``'uri'``, ``'date'`` (parsed with
            the ``'%Y%j'`` format, i.e. year followed by day-of-year) and
            ``'scene_id'``. Other keys (e.g. ``'band'``) are carried through.

    Returns:
        A list of dicts. Each one is a copy of ``line`` without ``'date'`` and
        ``'scene_id'`` and with two extra keys: ``'window'`` (the window as
        yielded by ``dataset.block_windows()``) and ``'projected_extent'``
        (a ``gps.TemporalProjectedExtent`` covering that window). The list is
        empty when the raster metadata could not be read; in that case a
        ``RuntimeWarning`` is issued instead of failing the whole job.
    """
    try:
        with rasterio.open(line['uri']) as dataset:
            bounds = dataset.bounds
            height = dataset.height
            width = dataset.width
            # Convert the dataset CRS to a proj4 string via its WKT form.
            srs = osr.SpatialReference()
            srs.ImportFromWkt(dataset.get_crs().wkt)
            proj4 = srs.ExportToProj4()
            block_windows = [w for (_, w) in dataset.block_windows()]
    except Exception as exc:
        # Best effort: an unreadable raster contributes no records, but the
        # reason is reported rather than silently discarded.
        import warnings
        warnings.warn(
            "get_metadata: skipping record %r: %s" % (line, exc),
            RuntimeWarning,
        )
        return []

    if not block_windows:
        return []

    # Identical for every window of this raster, so computed once.
    instant = datetime.strptime(line['date'], '%Y%j')
    x_span = bounds.right - bounds.left
    y_span = bounds.bottom - bounds.top  # measured downwards from the top edge

    records = []
    for window in block_windows:
        (row_start, row_stop), (col_start, col_stop) = window
        # Map pixel offsets to map coordinates by linear interpolation
        # across the dataset bounds.
        left = bounds.left + x_span * (float(col_start) / width)
        right = bounds.left + x_span * (float(col_stop) / width)
        bottom = bounds.top + y_span * (float(row_stop) / height)
        top = bounds.top + y_span * (float(row_start) / height)
        extent = gps.Extent(left, bottom, right, top)

        new_line = line.copy()
        new_line.pop('date')
        new_line.pop('scene_id')
        new_line['window'] = window
        new_line['projected_extent'] = gps.TemporalProjectedExtent(
            extent=extent, instant=instant, proj4=proj4)
        records.append(new_line)

    return records
    
In [ ]:
    
# Expand every band record into one record per block window, then fetch the
# first result to inspect it.
rdd1 = rdd0.flatMap(get_metadata)
rdd1.first()
    
In [ ]:
    
def get_data(line):
    """Read the pixels for one window record.

    Opens ``line['uri']`` with rasterio and reads band index 1 restricted to
    ``line['window']``.

    Args:
        line: dict with at least the keys ``'uri'`` and ``'window'``.

    Returns:
        A new dict holding every entry of ``line`` except ``'uri'`` and
        ``'window'``, plus ``'data'`` with the array that was read. ``line``
        itself is not modified.
    """
    with rasterio.open(line['uri']) as src:
        pixels = src.read(1, window=line['window'])

    record = {key: value for key, value in line.items()
              if key not in ('window', 'uri')}
    record['data'] = pixels
    return record
    
In [ ]:
    
# Read the pixel data for every window record, then fetch the first result
# to inspect it.
rdd2 = rdd1.map(get_data)
rdd2.first()
    
In [ ]:
    
# Group the records by their 'projected_extent' so that all records sharing
# one extent/instant (one per band) end up under the same key.
rdd3 = rdd2.groupBy(lambda line: line['projected_extent'])
rdd3.first()
    
In [ ]:
    
def make_tiles(line):
    """Stack the per-band records of one extent into a multiband tile.

    Args:
        line: ``(projected_extent, records)`` pair, where ``records`` is an
            iterable of dicts that each carry a ``'band'`` label and a
            ``'data'`` array.

    Returns:
        ``(projected_extent, tile)`` where ``tile`` is built with
        ``gps.Tile.from_numpy_array`` from the band arrays stacked in band
        order, using 0 as the no-data value.
    """
    projected_extent = line[0]

    def band_order(record):
        # Band labels are strings; sorting them as text would put '10'
        # before '2'. Numeric labels sort by value, anything else goes last
        # in text order.
        label = record['band']
        try:
            return (0, int(label), '')
        except (TypeError, ValueError):
            return (1, 0, str(label))

    bands = sorted(line[1], key=band_order)
    array = np.array([record['data'] for record in bands])
    tile = gps.Tile.from_numpy_array(array, no_data_value=0)
    return (projected_extent, tile)
def interesting_tile(line):
    """Tell whether the first band of a tile holds any non-zero cell.

    Args:
        line: ``(key, tile)`` pair; ``tile[0][0]`` is taken as the first band.

    Returns:
        ``True`` if at least one cell of the first band is non-zero.
    """
    _, tile = line
    first_band = tile[0][0]
    # Test the cells directly rather than their sum: with signed data a sum
    # can cancel to zero even though the band contains values.
    return bool(np.any(first_band != 0))
def square_tile(line, size=512):
    """Tell whether the first band of a tile is exactly ``size`` x ``size``.

    Args:
        line: ``(key, tile)`` pair; ``tile[0][0]`` is taken as the first band.
        size: required number of rows and columns. Defaults to 512, so the
            function can still be passed directly to ``filter``.

    Returns:
        ``True`` if the first band's shape equals ``(size, size)``.
    """
    _, tile = line
    return tile[0][0].shape == (size, size)
    
In [ ]:
    
# Build one multiband tile per extent and keep only those accepted by
# square_tile.
rdd4 = rdd3.map(make_tiles).filter(square_tile)
# Pick the first tile accepted by interesting_tile as a sample to look at.
data = rdd4.filter(interesting_tile).first()
data
    
In [ ]:
    
# data is the (projected_extent, tile) pair sampled above; data[1][0][0] is
# the same tile[0][0] slice the filter helpers treat as the first band.
plt.imshow(data[1][0][0])
    
In [ ]:
    
# Wrap the (projected_extent, tile) RDD in a GeoPySpark raster layer of the
# SPACETIME layer type.
raster_layer = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPACETIME, rdd4)
    
In [ ]:
    
# Re-tile the layer to a gps.GlobalLayout() layout with target CRS 3857.
tiled_raster_layer = raster_layer.tile_to_layout(layout = gps.GlobalLayout(), target_crs=3857)
    
In [ ]:
    
# Build the pyramid of levels from the space-time tiled layer.
pyramid = tiled_raster_layer.pyramid()
    
In [ ]:
    
# Write every pyramid level to the catalog at file:///tmp/catalog/ under the
# layer name "landsat", using days as the time unit.
for layer in pyramid.levels.values():
    gps.write("file:///tmp/catalog/", "landsat", layer, time_unit=gps.TimeUnit.DAYS)
    
In [ ]:
    
# Build a second pyramid from the spatial-only version of the tiled layer.
# NOTE: this rebinds `pyramid`, so the space-time pyramid is no longer
# reachable under that name after this cell runs.
pyramid = tiled_raster_layer.to_spatial_layer().pyramid()
    
In [ ]:
    
# Write every level of the current pyramid to the same catalog under the
# layer name "landsat-spatial".
for layer in pyramid.levels.values():
    gps.write("file:///tmp/catalog/", "landsat-spatial", layer)
    
In [ ]:
    
from PIL import Image
def render_tile(tile):
    """Render a multiband tile as an RGBA image.

    The cells in ``tile[0]`` are scaled by the tile-wide maximum to the
    0-255 range. Band indices 2, 1 and 0 become the red, green and blue
    channels; the alpha channel is opaque wherever band 0 is non-zero.

    Args:
        tile: object whose first item is an array indexed as
            ``[band][row][col]`` with at least three bands.

    Returns:
        A PIL ``Image`` in ``'RGBA'`` mode.
    """
    cells = np.asarray(tile[0])
    peak = cells.max()
    if peak > 0:
        norm = np.uint8(cells / peak * 255)
    else:
        # Nothing to scale (e.g. an all-zero tile): avoid dividing by zero
        # and render every channel as 0.
        norm = np.zeros(cells.shape, dtype=np.uint8)
    # Build the mask from the raw cells, not the scaled ones: faint but valid
    # pixels truncate to 0 after scaling and would otherwise turn transparent.
    mask = np.uint8((cells[0] != 0) * 255)
    return Image.fromarray(np.dstack([norm[2], norm[1], norm[0], mask]), mode='RGBA')
    
In [ ]:
    
# Build a TMS server over the "landsat-spatial" layer of the catalog, using
# render_tile to turn each tile into an image.
tms_server = gps.TMS.build(("file:///tmp/catalog/", "landsat-spatial"), display=render_tile)
# NOTE(review): '0.0.0.0' conventionally means "all network interfaces" --
# confirm the server is meant to be reachable from other hosts.
tms_server.bind('0.0.0.0')
    
In [ ]:
    
# Show the served tiles on an interactive folium map.
import folium
# Base map with the 'Stamen Terrain' tile set, opened at location
# [35.6, 140.1] with starting zoom 5.
# NOTE(review): confirm the installed folium version still provides the
# 'Stamen Terrain' tiles.
m = folium.Map(tiles='Stamen Terrain', location=[35.6, 140.1], zoom_start=5)
# Overlay the tiles served by tms_server, attributed to 'GeoPySpark'.
folium.TileLayer(tiles=tms_server.url_pattern, attr='GeoPySpark').add_to(m)
m
    
In [ ]: