In [ ]:
import xarray as xr

In [ ]:
#xr.open_dataset('/home/mm/bgc-md/prototypes/ModelsAsExpressions/models/minicable/')

In [ ]:
from dask.distributed import Client

# Number of local Dask workers; tune to the machine running the notebook.
N_WORKERS = 4
# Start a local Dask cluster and connect to it. Per the dask.distributed docs a
# live Client registers itself as the default scheduler, so later `.compute()`
# calls should run on it. Call `client.shutdown()` when done to free workers.
client = Client(n_workers=N_WORKERS)

In [ ]:
#client.shutdown()

In [ ]:
from dask import delayed
import dask.array as da

In [ ]:
import numpy as np

# Toy (lat, lon) grid sizes and coordinates.
lat_n = 4
lon_n = 2
lats = np.linspace(0, 1, lat_n)
lons = np.linspace(0, 2, lon_n)

# Toy field on the grid: data[i, j] = lats[i]**2 + lons[j]**2.
# Broadcasting a (lat_n, 1) column against a (1, lon_n) row yields the full
# (lat_n, lon_n) float64 grid in one vectorized step (no Python loops, and no
# loop variables leaking into the notebook namespace).
data = lats[:, np.newaxis] ** 2 + lons[np.newaxis, :] ** 2
data

In [ ]:
# Wrap the NumPy field in a Dask array cut into (2, 2) chunks; e.g. a (4, 2)
# grid becomes two blocks stacked along the latitude axis.
da_data = da.from_array(data, chunks=(2, 2))
# Show the element dtype (read directly from the array's `.dtype` attribute).
da_data.dtype

In [ ]:
def f1(block):
    """Square every element of one Dask block.

    Intended for ``da.map_blocks``: the result has the same shape as
    ``block`` (and, for numeric input, the same dtype), so no ``chunks=``
    argument is needed there.

    Parameters
    ----------
    block : numpy.ndarray
        One chunk of the input array, of any shape.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``block ** 2``; ``block`` is left unchanged.
    """
    # Do NOT write the squares back into ``block``: Dask may pass a view of
    # the source array, so in-place mutation can silently corrupt it and make
    # re-running the computation give different answers. Elementwise ``**``
    # already preserves the block's shape and returns a fresh array, so no
    # manual iteration or output-shape guessing is required.
    return block ** 2

In [ ]:
# Lazily apply f1 to every chunk of da_data. Despite its name, `fut` is a lazy
# Dask array (not a distributed Future): nothing runs until `.compute()`.
# f1 keeps each block's shape, so only the output dtype is supplied here.
fut = da.map_blocks(f1, da_data, dtype=da_data.dtype)

In [ ]:
# Execute the lazy graph built by map_blocks and collect a concrete array.
# NOTE(review): presumably runs on the dask.distributed Client created earlier
# (a live Client acts as dask's default scheduler) — confirm.
result=fut.compute()

In [ ]:
# Display the computed array: every element of `data`, squared by f1.
result

In [ ]:
def f2(block):
    """Square every element of one Dask block (placeholder implementation).

    TODO: f2 is meant to eventually return a much bigger array per block
    (an earlier note about "the statetransition operator cache" was left
    unfinished). For now it performs the same computation as f1 and returns
    a new array ``block ** 2`` with the same shape as ``block``. Once the
    output shape differs from the input, the ``da.map_blocks`` call using it
    must pass ``chunks=`` (and ``new_axis=`` if dimensions are added).

    Parameters
    ----------
    block : numpy.ndarray
        One chunk of the input array, of any shape.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``block ** 2``; ``block`` is left unchanged.
    """
    # Never mutate ``block`` in place: Dask may pass a view of the source
    # array, and writing into it would corrupt that array.
    return block ** 2

In [ ]: