Mask and visualize a GeoTiff

This notebook takes the mask-GeoTiff example written in Scala and the plot-GeoTiffs example written in Python and merges them into a single notebook using PixieDust.


In [1]:
import pixiedust


Pixiedust database opened successfully
Pixiedust version 1.0.7

Python

Dependencies


In [2]:
#Add all dependencies to sys.path
import sys
sys.path.append("/usr/lib/spark/python")
sys.path.append("/usr/lib/spark/python/lib/py4j-0.10.4-src.zip")
sys.path.append("/usr/lib/python3/dist-packages")

#Define environment variables
import os
os.environ["HADOOP_CONF_DIR"] = "/etc/hadoop/conf"
os.environ["PYSPARK_PYTHON"] = "python3"
os.environ["PYSPARK_DRIVER_PYTHON"] = "ipython"

#Load PySpark to connect to a Spark cluster
from pyspark import SparkConf, SparkContext

#To read GeoTiffs as a ByteArray
from io import BytesIO
import matplotlib.pyplot as plt
import rasterio
from rasterio import plot
from rasterio.io import MemoryFile

Shared variables


In [3]:
#Input GeoTiff and mask
geo_path = "hdfs:///user/hadoop/modis/MCD12Q1_051/A2001001__Land_Cover_Type_5.tif"
mask_path = "hdfs:///user/pheno/modis/usa_mask.tif"

#Output path for the masked GeoTiff and a local scratch file
#(the masked_geo name is an assumption; adjust it to your layout)
masked_geo = "hdfs:///user/pheno/modis/usa_mask_masked.tif"
tmp_geo = "/tmp/usa_mask.tif"
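
PixieDust's %%scala cells can read simple Python variables (strings, numbers) by name, which is why the Scala code below refers to geo_path, mask_path and masked_geo directly. Before handing the paths over, a quick sanity check can save a failed Spark job; the cell below is a sketch that assumes the hdfs CLI is available on the driver machine.


In [ ]:
#A sketch: fail fast if the input GeoTiff or the mask is missing on HDFS.
import subprocess

for path in (geo_path, mask_path):
    subprocess.run(["hdfs", "dfs", "-test", "-e", path], check=True)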

In [33]:
#spark.jars is a comma-separated list and is only read when the
#SparkContext is created; see the note after this cell.
sc._conf.set('spark.jars', 'file:/data/local/pixiedust/bin/cloudant-spark-v2.0.0-185.jar,/usr/lib/spark/jars/*')


Out[33]:
<pyspark.conf.SparkConf at 0x7fdc620279e8>
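
Note that spark.jars is only honored at SparkContext creation time, so mutating the conf of a live context (as above) changes nothing until the context is rebuilt. A minimal sketch of the order that does work, assuming no context is running yet:


In [ ]:
#A minimal sketch: build the conf first, then create the context.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("mask-geotiff")
        .set("spark.jars", "file:/data/local/pixiedust/bin/cloudant-spark-v2.0.0-185.jar"))
#sc.stop()  #stop any running context before creating a new one
sc = SparkContext(conf=conf)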

Scala

Here is the Scala code to mask a GeoTiff. Note that the first cell mixes the imports with a first attempt at reading the tiles; that attempt fails, and the error is kept below for reference.

Dependencies


In [35]:
%%scala
import sys.process._
import geotrellis.proj4.CRS
import geotrellis.raster.io.geotiff.writer.GeoTiffWriter
import geotrellis.raster.io.geotiff.{SinglebandGeoTiff, _}
import geotrellis.raster.{CellType, DoubleArrayTile}
import geotrellis.spark.io.hadoop._
import geotrellis.vector.{Extent, ProjectedExtent}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

import spire.syntax.cfor._


var geo_num_cols_rows :(Int, Int) = (0, 0)
val geo_tiles_RDD = sc.hadoopGeoTiffRDD(geo_path).values

val geo_extents_withIndex = sc.hadoopMultibandGeoTiffRDD(geo_path).keys.zipWithIndex().map{case (e,v) => (v,e)}
var geo_projected_extent = (geo_extents_withIndex.filter(m => m._1 == 0).values.collect())//(0)
var geo_projected = geo_projected_extent.take(10)


---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-35-960cfcc4db38> in <module>()

Py4JJavaError: An error occurred while calling o1668.invoke.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 8.0 failed 1 times, most recent failure: Lost task 0.0 in stage 8.0 (TID 8, localhost, executor driver): java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
(full Java stack trace truncated)

This ClassCastException on deserialization is typically a classloader mismatch: the classes compiled by PixieDust's embedded Scala bridge are not visible to the executors in the same way as on the driver.

In [ ]:
%%scala
val geo_tiles_withIndex = geo_tiles_RDD.zipWithIndex().map{case (e,v) => (v,e)}
val geo_tile0 = (geo_tiles_withIndex.filter(m => m._1==0).values.collect())(0)
geo_num_cols_rows = (geo_tile0.cols, geo_tile0.rows)
val geo_cellT = geo_tile0.cellType
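
The zipWithIndex/filter pattern above simply fetches the element at a given index of an RDD. For comparison, here is the same pattern in PySpark on a toy RDD (a sketch, not tied to GeoTiffs):


In [ ]:
#Index the elements, swap to (index, element), then filter on the index.
pairs = sc.parallelize(["a", "b", "c"]).zipWithIndex().map(lambda ev: (ev[1], ev[0]))
first = pairs.filter(lambda kv: kv[0] == 0).values().collect()[0]
print(first)  #prints: a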

Read a GeoTiff file


In [8]:
%%scala
var geo_projected_extent = new ProjectedExtent(new Extent(0,0,0,0), CRS.fromName("EPSG:3857"))
var geo_num_cols_rows :(Int, Int) = (0, 0)
val geo_tiles_RDD = sc.hadoopGeoTiffRDD(geo_path).values

val geo_extents_withIndex = sc.hadoopMultibandGeoTiffRDD(geo_path).keys.zipWithIndex().map{case (e,v) => (v,e)}
geo_projected_extent = (geo_extents_withIndex.filter(m => m._1 == 0).values.collect())(0)

val geo_tiles_withIndex = geo_tiles_RDD.zipWithIndex().map{case (e,v) => (v,e)}
val geo_tile0 = (geo_tiles_withIndex.filter(m => m._1==0).values.collect())(0)
geo_num_cols_rows = (geo_tile0.cols, geo_tile0.rows)
val geo_cellT = geo_tile0.cellType


pixiedustRunner.scala:32: error: not found: type ProjectedExtent
        var geo_projected_extent = new ProjectedExtent(new Extent(0,0,0,0), CRS.fromName("EPSG:3857"))
                                       ^
pixiedustRunner.scala:32: error: not found: type Extent
        var geo_projected_extent = new ProjectedExtent(new Extent(0,0,0,0), CRS.fromName("EPSG:3857"))
                                                           ^
pixiedustRunner.scala:32: error: not found: value CRS
        var geo_projected_extent = new ProjectedExtent(new Extent(0,0,0,0), CRS.fromName("EPSG:3857"))
                                                                            ^
pixiedustRunner.scala:34: error: value hadoopGeoTiffRDD is not a member of org.apache.spark.SparkContext
val geo_tiles_RDD = sc.hadoopGeoTiffRDD(geo_path).values
                       ^
pixiedustRunner.scala:36: error: value hadoopMultibandGeoTiffRDD is not a member of org.apache.spark.SparkContext
val geo_extents_withIndex = sc.hadoopMultibandGeoTiffRDD(geo_path).keys.zipWithIndex().map{case (e,v) => (v,e)}
                               ^
5 errors found
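
The "not found" errors mean the GeoTrellis classes are not on the classpath of PixieDust's embedded Scala compiler. One way to make them visible is PixieDust's own package installer; the Maven coordinates below are an assumption, so pick the geotrellis-spark build that matches your Spark and Scala versions and restart the kernel afterwards.


In [ ]:
#A sketch: fetch the GeoTrellis jars through PixieDust itself.
#Coordinates are assumed; match them to your Spark/Scala versions.
import pixiedust
pixiedust.installPackage("org.locationtech.geotrellis:geotrellis-spark_2.11:1.1.1")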

Read Mask


In [10]:
%%scala
val mask_tiles_RDD = sc.hadoopGeoTiffRDD(mask_path).values
val mask_tiles_withIndex = mask_tiles_RDD.zipWithIndex().map{case (e,v) => (v,e)}
val mask_tile0 = (mask_tiles_withIndex.filter(m => m._1==0).values.collect())(0)


pixiedustRunner.scala:32: error: value hadoopGeoTiffRDD is not a member of org.apache.spark.SparkContext
        val mask_tiles_RDD = sc.hadoopGeoTiffRDD(mask_path).values
                                ^
one error found

Mask GeoTiff


In [ ]:
%%scala
val res_tile = geo_tile0.localInverseMask(mask_tile0, 1, 0).toArrayDouble()
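
localInverseMask(mask_tile0, 1, 0) keeps the GeoTiff value wherever the mask cell equals the readMask (1) and writes the writeMask (0) everywhere else. Below is a conceptual numpy sketch of the same operation, for intuition only (this is not GeoTrellis' implementation):


In [ ]:
#Conceptual sketch of inverse masking: keep values where the mask equals
#read_mask, write write_mask elsewhere.
import numpy as np

def inverse_mask(values, mask, read_mask=1, write_mask=0):
    return np.where(mask == read_mask, values, write_mask)

values = np.array([[10., 20.], [30., 40.]])
mask = np.array([[1, 0], [0, 1]])
print(inverse_mask(values, mask))  #[[10.  0.] [ 0. 40.]]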

Save the new GeoTiff file


In [ ]:
%%scala
//Rebuild a tile from the masked cell values
val clone_tile = DoubleArrayTile(res_tile, geo_num_cols_rows._1, geo_num_cols_rows._2)

//Copy it cell by cell into a fresh tile
val cloned = geotrellis.raster.DoubleArrayTile.empty(geo_num_cols_rows._1, geo_num_cols_rows._2)
cfor(0)(_ < geo_num_cols_rows._1, _ + 1) { col =>
    cfor(0)(_ < geo_num_cols_rows._2, _ + 1) { row =>
        val v = clone_tile.getDouble(col, row)
        cloned.setDouble(col, row, v)
    }
}

val geoTif = new SinglebandGeoTiff(cloned, geo_projected_extent.extent, geo_projected_extent.crs, Tags.empty, GeoTiffOptions.DEFAULT)

//Save GeoTiff to /tmp
GeoTiffWriter.write(geoTif, tmp_geo)

//Upload to HDFS (hdfs dfs replaces the deprecated hadoop dfs)
var cmd = "hdfs dfs -copyFromLocal -f " + tmp_geo + " " + masked_geo
Process(cmd)!

//Clean up the local scratch file
cmd = "rm -fr " + tmp_geo
Process(cmd)!
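
Back in Python, an optional check (a sketch, assuming the hdfs CLI) that the upload actually landed before trying to plot it:


In [ ]:
#Optional sketch: confirm the masked GeoTiff is on HDFS before plotting.
import subprocess
subprocess.run(["hdfs", "dfs", "-ls", masked_geo], check=True)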

Python

This section reads the masked GeoTiff created by the Scala cells above and plots it with rasterio and matplotlib.

Plot masked GeoTiff


In [ ]:
#Read the masked GeoTiff from HDFS as a byte array
data = sc.binaryFiles(masked_geo).take(1)
dataByteArray = bytearray(data[0][1])

#Check that the file was read correctly by printing its metadata
with MemoryFile(dataByteArray) as memfile:
    with memfile.open() as dataset:
        print(dataset.profile)

%matplotlib notebook
with MemoryFile(dataByteArray) as memfile:
    with memfile.open() as dataset:
        plot.show((dataset,1))
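
Beyond plotting, the same in-memory dataset can be pulled into numpy for quick checks, for example counting how many cells survived the mask (a follow-up sketch):


In [ ]:
#A follow-up sketch: read band 1 into numpy and count the cells that
#kept a non-zero value after masking.
with MemoryFile(dataByteArray) as memfile:
    with memfile.open() as dataset:
        band = dataset.read(1)
print((band != 0).sum(), "non-zero cells after masking")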