Original notebook by Stephan Hoyer, Rossbypalooza, 2016.
Modified by Edward Byers, Matthew Gidden and Fabien Maussion for EGU General Assembly 2017, Vienna, Austria
Thursday, 27th April, 15:30–17:00 / Room -2.91
xarray is an open source project and Python package designed to perform labelled data analysis on multi-dimensional arrays. An xarray.Dataset is an in-memory representation of a netCDF file. xarray is built on top of the data-processing library Pandas (the best way to work with tabular data, e.g. CSV files, in Python).
For analyzing GCM output:
Other tools:
Resources for teaching and learning xarray in geosciences:
In [1]:
    
# standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import warnings
%matplotlib inline
np.set_printoptions(precision=3, linewidth=80, edgeitems=1)  # make numpy less verbose
xr.set_options(display_width=70)
warnings.simplefilter('ignore') # filter some warning messages
    
In [2]:
    
import numpy as np
a = np.array([[1, 3, 9], [2, 8, 4]])
a
    
    Out[2]:
In [3]:
    
a[1, 2]
    
    Out[3]:
In [4]:
    
a.mean(axis=0)
    
    Out[4]:
numpy is a powerful but "low-level" array manipulation tool. Axes only have numbers, not names (it is easy to forget which axis is which, a common source of trivial bugs), arrays can't carry metadata (e.g. units), and the data is unstructured (i.e. the coordinates and/or other related arrays have to be handled separately: another source of bugs).
This is where xarray comes in!
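As a minimal sketch, the same small array can be wrapped in an xarray.DataArray with named dimensions, coordinates, and metadata (the names 'x' and 'y' and the coordinate values here are illustrative choices, not part of the tutorial dataset):

```python
import numpy as np
import xarray as xr

# the same 2x3 array as above, now with named dimensions and coordinates
a = xr.DataArray([[1, 3, 9], [2, 8, 4]],
                 dims=('x', 'y'),
                 coords={'x': [10, 20], 'y': ['a', 'b', 'c']},
                 attrs={'units': 'dimensionless'})

a.sel(x=20, y='c')   # label-based lookup instead of a[1, 2]
a.mean(dim='x')      # a named dimension instead of a.mean(axis=0)
```

Selections are now self-documenting, and metadata such as units travels with the array.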
We'll start with the "air_temperature" tutorial dataset, which ships with the xarray package. Other example datasets are available as well.
In [5]:
    
ds = xr.tutorial.load_dataset('air_temperature')
    
In [6]:
    
ds
    
    Out[6]:
In [7]:
    
ds.air
    
    Out[7]:
In [8]:
    
ds.dims
    
    Out[8]:
In [9]:
    
ds.attrs
    
    Out[9]:
In [10]:
    
ds.air.values
    
    Out[10]:
In [11]:
    
type(ds.air.values)
    
    Out[11]:
In [12]:
    
ds.air.dims
    
    Out[12]:
In [13]:
    
ds.air.attrs
    
    Out[13]:
In [14]:
    
ds.air.attrs['tutorial-date'] = 27042017
    
In [15]:
    
ds.air.attrs
    
    Out[15]:
In [16]:
    
kelvin = ds.air.mean(dim='time')
kelvin.plot();
    
    
In [17]:
    
centigrade = kelvin - 273.15
centigrade.plot();
    
    
Notice that xarray has changed the colormap according to the data: values crossing zero get a diverging colormap (borrowing logic from Seaborn).
In [18]:
    
# ufuncs work too
np.sin(centigrade).plot();
    
    
In [19]:
    
ds
    
    Out[19]:
Let's add those kelvin and centigrade DataArrays to the dataset.
In [20]:
    
ds['centigrade'] = centigrade
ds['kelvin'] = kelvin
ds
    
    Out[20]:
In [21]:
    
ds.kelvin.attrs  # attrs are empty! Let's add some
    
    Out[21]:
In [22]:
    
ds.kelvin.attrs['Description'] = 'Mean air temperature (through time) in kelvin.'
    
In [23]:
    
ds.kelvin
    
    Out[23]:
In [24]:
    
ds.to_netcdf('new file.nc')
    
In [25]:
    
ds.air[:, 1, 2]  # note that the attributes, coordinates are preserved
    
    Out[25]:
In [26]:
    
ds.air[:, 1, 2].plot();
    
    
This selection implies prior knowledge about the structure of the data, and is therefore much less readable than the "xarray methods" presented below.
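The difference can be sketched on a small synthetic array (the dimension names below are illustrative, chosen to mirror the tutorial dataset):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(24).reshape(2, 3, 4),
                  dims=('time', 'lat', 'lon'))

positional = da[:, 1, 2]       # which axis was lat again?
named = da.isel(lat=1, lon=2)  # self-documenting: same selection, by name

assert (positional == named).all()
```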
In [27]:
    
ds.air.isel(time=0).plot();  # like above, but with a dimension name this time
    
    
In [28]:
    
ds.air.sel(lat=72.5, lon=205).plot();
    
    
In [29]:
    
ds.air.sel(time='2013-01-02').plot(); # note that this extracts 4 time steps! 3D data is plotted as a histogram
    
    
In [30]:
    
ds.air.sel(time='2013-01-02T06:00').plot();  # or look at a single timestep
    
    
To select a range of values, the syntax is similar, but you'll need to use a slice:
In [31]:
    
ds.air.sel(lat=slice(60, 50), lon=slice(200, 270), time='2013-01-02T06:00:00').plot();
    
    
In [32]:
    
ds.air.sel(lat=41.8781, lon=360-87.6298, method='nearest', tolerance=5).plot();
    
    
In [33]:
    
a = xr.DataArray(np.arange(3), dims='time', 
                 coords={'time':np.arange(3)})
b = xr.DataArray(np.arange(4), dims='space', 
                 coords={'space':np.arange(4)})
a + b
    
    Out[33]:
In [34]:
    
atime = np.arange(3)
btime = np.arange(5) + 1
atime, btime
    
    Out[34]:
In [35]:
    
a = xr.DataArray(np.arange(3), dims='time', 
                 coords={'time':atime})
b = xr.DataArray(np.arange(5), dims='time', 
                 coords={'time':btime})
a + b
    
    Out[35]:
In [36]:
    
ds.max()
    
    Out[36]:
In [37]:
    
ds.air.median(dim=['lat', 'lon']).plot();
    
    
In [38]:
    
means = ds.air.mean(dim=['time'])
means.where(means > 273.15).plot();
    
    
xarray implements the "split-apply-combine" paradigm with groupby. This works really well for calculating climatologies:
In [39]:
    
ds.air.groupby('time.season').mean()
    
    Out[39]:
In [40]:
    
ds.air.groupby('time.month').mean('time')
    
    Out[40]:
In [41]:
    
clim = ds.air.groupby('time.month').mean('time')
    
You can also do arithmetic with groupby objects, which repeats the arithmetic over each group:
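As a self-contained sketch (with a tiny synthetic two-year monthly series, not the tutorial data): subtracting the monthly climatology from a groupby object removes each month's mean:

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range('2000-01-01', periods=24, freq='MS')  # two years, monthly
da = xr.DataArray(np.arange(24, dtype=float),
                  dims='time', coords={'time': times})

clim = da.groupby('time.month').mean('time')  # 12 monthly means
anom = da.groupby('time.month') - clim        # subtraction repeated per group
# each month's two values now sit symmetrically around zero
```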
In [42]:
    
anomalies = ds.air.groupby('time.month') - clim
    
In [43]:
    
anomalies
    
    Out[43]:
In [44]:
    
anomalies.plot();
    
    
In [45]:
    
anomalies.sel(time='2013-02').plot();  # Find all the anomalous values for February
    
    
Resample adjusts a time series to a new resolution:
In [46]:
    
# resample to daily ('1D') resolution, using the current resample API
tmin = ds.air.resample(time='1D').min()
tmax = ds.air.resample(time='1D').max()
    
In [47]:
    
(tmin.sel(time='2013-02-15') - 273.15).plot();
    
    
In [48]:
    
ds_extremes = xr.Dataset({'tmin': tmin, 'tmax': tmax})
    
In [49]:
    
ds_extremes
    
    Out[49]:
xarray plotting functions rely on matplotlib internally, but they make use of all available metadata to make the plotting operations more intuitive and interpretable.
In [50]:
    
zonal_t_average = ds.air.mean(dim=['lon', 'time']) - 273.15
zonal_t_average.plot();  # 1D arrays are plotted as line plots
    
    
In [51]:
    
t_average = ds.air.mean(dim='time') - 273.15
t_average.plot();  # 2D arrays are plotted with pcolormesh
    
    
In [52]:
    
t_average.plot.contourf();  # but you can use contour(), contourf() or imshow() if you wish
    
    
In [53]:
    
t_average.plot.contourf(cmap='BrBG_r', vmin=-15, vmax=15);
    
    
In [54]:
    
t_average.plot.contourf(cmap='BrBG_r', levels=22, center=False);
    
    
In [55]:
    
air_outliers = ds.air.isel(time=0).copy()
air_outliers[0, 0] = 100
air_outliers[-1, -1] = 400
air_outliers.plot();  # outliers mess with the datarange and colorscale!
    
    
In [56]:
    
# Using `robust=True` uses the 2nd and 98th percentiles of the data to compute the color limits.
air_outliers.plot(robust=True);
    
    
In [57]:
    
t_season = ds.air.groupby('time.season').mean(dim='time') - 273.15
    
In [58]:
    
# facet plots allow multiple panels that share the same color mapping
t_season.plot.contourf(x='lon', y='lat', col='season', col_wrap=2, levels=22);
    
    
For plotting on maps, we rely on the excellent cartopy library.
In [59]:
    
import cartopy.crs as ccrs
    
In [60]:
    
f = plt.figure(figsize=(8, 4))
# Define the map projection *on which* you want to plot
ax = plt.axes(projection=ccrs.Orthographic(-80, 35))
# ax is an empty plot. We now plot the variable t_average onto ax
# the keyword "transform" tells the function in which projection the air temp data is stored 
t_average.plot(ax=ax, transform=ccrs.PlateCarree())
# Add gridlines and coastlines to the plot
ax.coastlines(); ax.gridlines();
    
    
In [61]:
    
# this time we need to retrieve the plots to do things with the axes later on
p = t_season.plot(x='lon', y='lat', col='season', transform=ccrs.PlateCarree(),
                  subplot_kws={'projection': ccrs.Orthographic(-80, 35)})
for ax in p.axes.flat:
    ax.coastlines()
    
    
Statistical visualization with Seaborn:
In [62]:
    
import seaborn as sns
# sel_points was removed from xarray; pointwise selection now uses
# .sel with DataArray indexers that share a common dimension
locations = xr.DataArray(['Chicago', 'San Francisco'],
                         dims='location', name='location')
data = (ds_extremes
        .sel(lat=xr.DataArray([41.8781, 37.7749], dims='location'),
             lon=xr.DataArray([360-87.6298, 360-122.4194], dims='location'),
             method='nearest', tolerance=3)
        .assign_coords(location=locations)
        .to_dataframe()
        .reset_index()
        .assign(month=lambda x: x.time.dt.month))
plt.figure(figsize=(10, 5))
sns.violinplot(x='month', y='tmax', hue='location', data=data,
               split=True, inner=None);
    
    
Here's a quick demo of how xarray can leverage dask to work with data that doesn't fit in memory. This lets xarray substitute for tools like cdo and nco.
xarray can open multiple files at once using string pattern matching.
In this case we open all the files that match our filestr, i.e. all the files for the 2080s.
Each of these files (compressed) is approximately 80 MB.
In [63]:
    
from glob import glob
files = glob('data/*dis*.nc')
runoff = xr.open_mfdataset(files)
    
In [64]:
    
runoff
    
    Out[64]:
xarray even puts them in the right order for you.
In [65]:
    
runoff.time
    
    Out[65]:
How big is all this data uncompressed? Will it fit into memory?
In [66]:
    
runoff.nbytes / 1e9  # convert to gigabytes
    
    Out[66]:
xarray computes data 'lazily': values are only loaded into memory when they are actually required. This also allows us to inspect datasets without loading all the data into memory.
To support streaming computation on datasets that don't fit into memory, xarray integrates with dask, which splits arrays into chunks that can be processed one at a time. Chunking an existing dataset in xarray is very easy.
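The whole lazy workflow can be sketched with synthetic data (this assumes dask is installed; the array sizes and dimension names are arbitrary choices):

```python
import numpy as np
import xarray as xr

# a small synthetic stand-in for a large dataset
da = xr.DataArray(np.ones((4, 120, 100)), dims=('time', 'lat', 'lon'))

lazy = da.chunk({'lat': 60})   # backed by a dask array: two chunks along lat
graph = lazy.mean('time')      # only builds a task graph; nothing is computed
result = graph.compute()       # triggers the actual computation
```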
In [67]:
    
runoff = runoff.chunk({'lat': 60})
    
In [68]:
    
runoff.chunks
    
    Out[68]:
In [69]:
    
%time ro_seasonal = runoff.groupby('time.season').mean('time')
    
    
In [70]:
    
import dask
# the old dask.set_options(pool=...) API is gone; configure the
# threaded scheduler with a single worker instead
dask.config.set(scheduler='threads', num_workers=1)
    
    Out[70]:
In [71]:
    
%time ro_seasonal.compute()
    
    
    Out[71]:
In [72]:
    
dask.config.set(scheduler='threads', num_workers=4)
    
    Out[72]:
In [73]:
    
%time ro_seasonal = runoff.groupby('time.season').mean('time')
    
    
In [74]:
    
%time result = ro_seasonal.compute()
    
    
In [75]:
    
brazil = dict(lat=slice(10.75, -40.75), lon=slice(-100.25, -25.25))
result.dis.sel(**brazil).plot(col='season', size=4, cmap='Spectral_r')
    
    Out[75]:
    
For more details, read this blog post: http://continuum.io/blog/xray-dask