# In [ ]:
import dask.array as da
import os
import numpy as np
import pylab as plt
import h5py
from dask.dot import dot_graph
#from dask.multiprocessing import get
from dask import get
from functools import partial
from time import sleep, clock
from scipy.integrate import simps
from dask.callbacks import Callback
from distributed import Client
#from multiprocessing.pool import ThreadPool
from ForwardEquation import forwardEquation, forwardEquation_dask
from Gradient import computeAdjoint, computeAdjoint_dask
from TricubicInterpolation import TriCubic
from LineSearch import lineSearch
from InfoCompleteness import precondition
from InitialModel import createInitialModel
from CalcRays import calcRays,calcRays_dask
from RealData import plotDataPack
from PlotTools import animateTCISlices
from Covariance import CovarianceClass
def store_Fdot(resettable,outputfolder,n1,n2,Fdot,gamma,beta,dm,dgamma,v,sigma_m,L_m,covC):
    """Apply one rank-two correction to ``Fdot`` and store (or return) the result.

    With <.,.> denoting ``scalarProduct`` the value computed is

        out = Fdot + (dm*(beta*<gamma,dm> - <dgamma,Fdot>) - v*<gamma,dm>) / <dgamma,dm>

    resettable   -- when truthy results are cached on disk and a filename is
                    returned, otherwise the object itself is returned
    outputfolder -- folder that receives ``F{n1}gamma{n2}.hdf5``
    n1, n2       -- indices used only to build the file name
    Fdot, gamma, dm, dgamma, v
                 -- objects exposing ``.m``, ``.copy()`` and ``.getShapedArray()``
                    (``dm`` additionally supplies ``xvec``/``yvec``/``zvec``)
    beta         -- scalar factor, see ``calcBeta``
    sigma_m, L_m, covC -- forwarded untouched to ``scalarProduct``
    """
    filename = "{}/F{}gamma{}.hdf5".format(outputfolder,n1,n2)
    if os.path.isfile(filename) and resettable:
        # already computed in a previous run
        return filename
    out = Fdot.copy()
    xvec = dm.xvec
    yvec = dm.yvec
    zvec = dm.zvec
    gamma_dm = scalarProduct(gamma.getShapedArray(),dm.getShapedArray(),sigma_m,L_m,xvec,yvec,zvec,covC)
    a = dm.m*(beta*gamma_dm - scalarProduct(dgamma.getShapedArray(),Fdot.getShapedArray(),sigma_m,L_m,xvec,yvec,zvec,covC))
    a -= v.m*gamma_dm
    a /= scalarProduct(dgamma.getShapedArray(),dm.getShapedArray(),sigma_m,L_m,xvec,yvec,zvec,covC)
    # The original statement was ``out.m + a`` whose result was discarded, so
    # the correction was never applied; accumulate it in place instead.
    out.m += a
    print("Beta: {}".format(beta))
    # normalised dot product (cosine) between the updated direction and gamma
    print("Difference: {}".format(np.dot(out.m,gamma.m)/np.linalg.norm(out.m)/np.linalg.norm(gamma.m)))
    if resettable:
        out.save(filename)
        return filename
    return out
def pull_Fdot(resettable,filename):
    """Resolve the output of a ``store_F*`` task.

    In resettable mode ``filename`` is a path and a TriCubic is loaded from
    it; otherwise the argument already is the object and is passed through.
    """
    if not resettable:
        return filename
    return TriCubic(filename=filename)
def pull_gamma(resettable,filename):
    """Resolve the output of a ``store_gamma`` task.

    In resettable mode ``filename`` is a path and a TriCubic is loaded from
    it; otherwise the argument already is the object and is passed through.
    """
    if not resettable:
        return filename
    return TriCubic(filename=filename)
def store_gamma(resettable,outputfolder,n,rays, g, dobs, i0, K_ne, mTCI, mPrior, CdCt, sigma_m, Nkernel, sizeCell):
    """Compute the gradient for iteration ``n`` and store (or return) it.

    The gradient comes from ``computeAdjoint_dask`` and is wrapped in a
    TriCubic that shares the grid vectors of ``mTCI``.  In resettable mode
    the result is written to ``gamma_{n}.hdf5`` inside ``outputfolder`` and
    the path is returned (an existing file short-circuits the computation);
    otherwise the TriCubic itself is returned.
    """
    target = '{}/gamma_{}.hdf5'.format(outputfolder,n)
    cached = os.path.isfile(target) and resettable
    if cached:
        return target
    grad = computeAdjoint_dask(rays, g, dobs, i0, K_ne, mTCI, mPrior.getShapedArray(),
                               CdCt, sigma_m, Nkernel, sizeCell)
    gammaTCI = TriCubic(mTCI.xvec,mTCI.yvec,mTCI.zvec,grad)
    if not resettable:
        return gammaTCI
    gammaTCI.save(target)
    return target
def plot_gamma(outputfolder,n,TCI):
    """Animate slices of ``TCI`` into ``{outputfolder}/gamma_{n}`` and return that path."""
    target = '{}/gamma_{}'.format(outputfolder,n)
    animateTCISlices(TCI, target, numSeconds=20.)
    return target
def store_forwardEq(resettable,outputfolder,n,templateDatapack,antIdx,timeIdx,dirIdx,rays,K_ne,mTCI,i0):
    """Evaluate the forward equation for model ``mTCI`` and store (or return) it.

    The predicted data ``g`` is written into a clone of ``templateDatapack``
    and plotted with the colour range of the template's own dtec values.

    resettable -- when truthy the datapack is saved to ``g_{n}.hdf5`` inside
                  ``outputfolder`` and the path is returned (an existing file
                  short-circuits the computation); otherwise the datapack
                  object is returned.
    Raises AssertionError when the model or the prediction contains NaNs.
    """
    filename = "{}/g_{}.hdf5".format(outputfolder,n)
    if os.path.isfile(filename) and resettable:
        return filename
    assert not np.any(np.isnan(mTCI.m)), "nans in model"
    g = forwardEquation(rays,K_ne,mTCI,i0)
    assert not np.any(np.isnan(g)), "nans in g"
    datapack = templateDatapack.clone()
    datapack.set_dtec(g,antIdx=antIdx,timeIdx=timeIdx,dirIdx = dirIdx)
    dobs = templateDatapack.get_dtec(antIdx=antIdx,timeIdx=timeIdx,dirIdx = dirIdx)
    vmin = np.min(dobs)
    vmax = np.max(dobs)
    # splitext strips only the trailing extension; the previous
    # ``filename.split('.')[0]`` cut at the first dot anywhere in the path
    # (e.g. "./out/g_0.hdf5" gave an empty figure name).
    figname = os.path.splitext(filename)[0]
    plotDataPack(datapack,antIdx=antIdx,timeIdx=timeIdx,dirIdx = dirIdx,
                 figname=figname, vmin = vmin, vmax = vmax)
    if resettable:
        datapack.save(filename)
        return filename
    return datapack
def pull_forwardEq(resettable,filename,antIdx,timeIdx,dirIdx):
    """Return the dtec selection produced by a ``store_forwardEq`` task.

    resettable -- when truthy ``filename`` is a path to a saved datapack that
                  is loaded first; otherwise ``filename`` already is the
                  datapack object.
    antIdx, timeIdx, dirIdx -- selection forwarded to ``get_dtec``.
    """
    if resettable:
        # DataPack is otherwise only bound inside the ``__main__`` guard of
        # this file, so import it here to keep the function usable when the
        # module is imported rather than run as a script.
        from RealData import DataPack
        datapack = DataPack(filename=filename)
    else:
        datapack = filename
    return datapack.get_dtec(antIdx=antIdx,timeIdx=timeIdx,dirIdx = dirIdx)
def calcEpsilon(outputfolder,n,phi,mTCI,rays,K_ne,i0,g,dobs,CdCt):
    """Write diagnostic histograms for iteration ``n`` and run the line search.

    Two figures are saved in ``outputfolder``: overlaid histograms of ``g``
    and ``dobs``, and a histogram of their scaled difference.  Returns the
    ``(ep, S, reduction)`` triple produced by ``lineSearch``.
    """
    nbins = max(10,int(np.ceil(np.sqrt(g.size))))
    # scaling constants used for the residual axis and its label
    rad_scale = 8.44797256e-7/120e6
    tau_scale = 1.34453659e-7/120e6**2
    ratio = tau_scale/rad_scale*1e9#mu sec factor

    # overlaid histograms of prediction and observation
    plt.figure()
    for values, name in ((g,'g'),(dobs,'dobs')):
        plt.hist(values.flatten(),alpha=0.2,label=name,bins=nbins)
    plt.legend(frameon=False)
    plt.savefig("{}/data-hist-{}.png".format(outputfolder,n))

    # histogram of the residual
    plt.clf()
    residual = (g-dobs).flatten()*rad_scale*1e16
    plt.hist(residual,bins=nbins)
    plt.xlabel(r"$d\phi$ [rad] | {:.2f} delay [ns]".format(ratio))
    plt.savefig("{}/datadiff-hist-{}.png".format(outputfolder,n))
    plt.close('all')

    searchFig = "{}/lineSearch{}".format(outputfolder,n)
    return lineSearch(rays,K_ne,mTCI,i0,phi.getShapedArray(),g,dobs,CdCt,figname=searchFig)
def store_m(resettable,outputfolder,n,mTCI0,phi,rays,K_ne,i0,g,dobs,CdCt,stateFile):
    """Step the model along ``phi`` by the line-search length and store (or return) it.

    The step length ``epsilon_n`` (together with ``S`` and ``reduction``) is
    cached in the HDF5 file ``stateFile`` under ``/{n}/...`` so that a
    restarted run does not repeat the line search.

    resettable -- when truthy the new model is saved to ``m_{n}.hdf5`` inside
                  ``outputfolder`` and the path is returned (an existing file
                  is reused); otherwise the model object is returned.
    """
    filename = "{}/m_{}.hdf5".format(outputfolder,n)
    epsKey = '/{}/epsilon_n'.format(n)
    # Open in append mode: 'w' truncates the file on every call, which made
    # the membership test below always true and wiped the cached values of
    # all earlier iterations.
    with h5py.File(stateFile,'a') as state:
        if epsKey not in state:
            epsilon_n,S,reduction = calcEpsilon(outputfolder,n,phi,mTCI0,rays,K_ne,i0,g,dobs,CdCt)
            state[epsKey] = epsilon_n
            state['/{}/S'.format(n)] = S
            state['/{}/reduction'.format(n)] = reduction
            state.flush()
        else:
            # ``[()]`` reads the value into memory; a bare Dataset handle is
            # unusable once the file is closed at the end of this block.
            epsilon_n = state[epsKey][()]
    if os.path.isfile(filename) and resettable:
        return filename
    mTCI = mTCI0.copy()
    mTCI.m -= epsilon_n*phi.m
    if resettable:
        mTCI.save(filename)
        return filename
    return mTCI
def pull_m(resettable,filename):
    """Resolve the output of a ``store_m`` task.

    In resettable mode ``filename`` is a path and a TriCubic is loaded from
    it; otherwise the argument is the model object itself (not a filename)
    and is passed through.
    """
    if not resettable:
        return filename #object not filename
    return TriCubic(filename=filename)
def scalarProduct(a,b,sigma_m,L_m,xvec,yvec,zvec,covC):
    """Integrate ``a * covC.contract(b)`` over the grid with Simpson's rule.

    The product array is integrated along axis 2 with ``zvec``, then axis 1
    with ``yvec``, then axis 0 with ``xvec``.  ``sigma_m`` and ``L_m`` are
    accepted but not used (the normalisation that used them is disabled).
    """
    integrand = covC.contract(b)
    integrand *= a  # in place, on whatever array ``contract`` handed back
    over_z = simps(integrand,zvec,axis=2)
    over_yz = simps(over_z,yvec,axis=1)
    total = simps(over_yz,xvec,axis=0)
    #total /= (np.pi*8.*sigma_m**2 * L_m**3)
    return total
def calcBeta(dgamma, v, dm,sigma_m,L_m,covC):
    """Return ``1 + <dgamma,v> / (<dgamma,dm> + 1e-15)``.

    <.,.> is ``scalarProduct`` evaluated on the grid of ``dgamma``; the small
    constant keeps the denominator away from exactly zero.  Mean absolute
    values of ``dm`` and ``dgamma`` are printed as a diagnostic.
    """
    grid = (dgamma.xvec, dgamma.yvec, dgamma.zvec)
    numerator = scalarProduct(dgamma.getShapedArray(),v.getShapedArray(),sigma_m,L_m,grid[0],grid[1],grid[2],covC)
    denominator = scalarProduct(dgamma.getShapedArray(),dm.getShapedArray(),sigma_m,L_m,grid[0],grid[1],grid[2],covC) + 1e-15
    beta = 1. + numerator/denominator
    mean_dm = np.mean(np.abs(dm.m))
    mean_dgamma = np.mean(np.abs(dgamma.m))
    print("E[|dm|] = {} | E[|dgamma|] = {}".format(mean_dm,mean_dgamma))
    return beta
def diffTCI(TCI1,TCI2):
    """Return a copy of ``TCI1`` whose ``m`` has ``TCI2.m`` subtracted from it.

    Neither input is modified as long as ``copy()`` yields independent data.
    """
    difference = TCI1.copy()
    difference.m -= TCI2.m
    return difference
def store_F0dot(resettable,outputfolder,n,F0,gamma):
    """Multiply ``gamma`` element-wise by ``F0`` and store (or return) the product.

    In resettable mode the result is saved as ``F0gamma{n}.hdf5`` inside
    ``outputfolder`` and the path is returned (an existing file is reused);
    otherwise the product object is returned.  ``gamma`` itself is left
    untouched because the work happens on a copy.
    """
    target = "{}/F0gamma{}.hdf5".format(outputfolder,n)
    already_there = os.path.isfile(target)
    if already_there and resettable:
        return target
    product = gamma.copy()
    product.m *= F0.m
    if not resettable:
        return product
    product.save(target)
    return target
def plot_model(outputfolder,n,mModel,mPrior,K_ne):
    """Animate ``K_ne*exp(mModel.m) - K_ne*exp(mPrior.m)`` into ``{outputfolder}/m_{n}``.

    Returns the folder name handed to ``animateTCISlices``.

    The difference is built on a copy of ``mModel``.  The previous version
    overwrote ``mModel.m`` in place and restored it only after the animation
    returned, so an exception raised while animating left the shared model
    holding the plotted values instead of the log-model.
    """
    diffModel = mModel.copy()
    # same sequence of operations as before, applied to the private copy
    np.exp(diffModel.m,out=diffModel.m)
    diffModel.m *= K_ne
    diffModel.m -= K_ne*np.exp(mPrior.m)
    foldername = '{}/m_{}'.format(outputfolder,n)
    animateTCISlices(diffModel,foldername,numSeconds=20.)
    print("Animation of model - prior in {}".format(foldername))
    return foldername
def createBFGSDask(resettable,outputfolder,N,datapack,L_ne,sizeCell,i0, antIdx=-1, dirIdx=-1, timeIdx = [0]):
    """Build the dask graph (a plain dict) for ``N`` iterations of the inversion.

    resettable   -- forwarded to every store_*/pull_* task; when truthy those
                    tasks exchange file names and cache results on disk
    outputfolder -- created if missing; all products are written below it
    N            -- number of iterations to lay out (converted with ``int``)
    datapack     -- data container providing antennas, directions, times, dtec
    L_ne         -- length scale handed to CovarianceClass and used, together
                    with ``sizeCell``, to derive ``Nkernel``
    sizeCell     -- grid spacing handed to createInitialModel
    i0           -- index of the reference antenna
    antIdx, dirIdx, timeIdx -- selections forwarded to the datapack accessors

    Returns the graph; request a key such as ``'plot_m{k}'`` to run it.
    Raises OSError if ``outputfolder`` cannot be created.
    """
    try:
        os.makedirs(outputfolder)
    except OSError:
        # An already existing folder is fine; any other failure (permissions,
        # a file in the way, ...) used to be swallowed by a bare ``except``.
        if not os.path.isdir(outputfolder):
            raise
    print("Using output folder: {}".format(outputfolder))
    stateFile = "{}/state".format(outputfolder)
    straightLineApprox = True
    tmax = 1000.
    antennas,antennaLabels = datapack.get_antennas(antIdx = antIdx)
    patches, patchNames = datapack.get_directions(dirIdx = dirIdx)
    times,timestamps = datapack.get_times(timeIdx=timeIdx)
    datapack.setReferenceAntenna(antennaLabels[i0])
    #plotDataPack(datapack,antIdx = antIdx, timeIdx = timeIdx, dirIdx = dirIdx,figname='{}/dobs'.format(outputfolder))
    dobs = datapack.get_dtec(antIdx = antIdx, timeIdx = timeIdx, dirIdx = dirIdx)
    Na = len(antennas)
    Nt = len(times)
    Nd = len(patches)
    fixtime = times[Nt>>1]
    phase = datapack.getCenterDirection()
    arrayCenter = datapack.radioArray.getCenter()
    #Average time axis down and center on fixtime
    if Nt == 1:
        # single time slot: variance estimated from the lower quartiles of
        # the positive and negative data values
        var = (0.5*np.percentile(dobs[dobs>0],25) + 0.5*np.percentile(-dobs[dobs<0],25))**2
        Cd = np.ones([Na,1,Nd],dtype=np.double)*var
        Ct = (np.abs(dobs)*0.05)**2
        CdCt = Cd + Ct
    else:
        dt = times[1].gps - times[0].gps
        print("Averaging down window of length {} seconds [{} timestamps]".format(dt*Nt, Nt))
        # collapse axis 1 to its variance / mean, keeping it as a length-1 axis
        Cd = np.stack([np.var(dobs,axis=1)],axis=1)
        dobs = np.stack([np.mean(dobs,axis=1)],axis=1)
        Ct = (np.abs(dobs)*0.05)**2
        CdCt = Cd + Ct
        timeIdx = [Nt>>1]
        times,timestamps = datapack.get_times(timeIdx=timeIdx)
        Nt = len(times)
    print("E[S/N]: {} +/- {}".format(np.mean(np.abs(dobs)/np.sqrt(CdCt+1e-15)),np.std(np.abs(dobs)/np.sqrt(CdCt+1e-15))))
    vmin = np.min(datapack.get_dtec(antIdx = antIdx, timeIdx = timeIdx, dirIdx = dirIdx))
    vmax = np.max(datapack.get_dtec(antIdx = antIdx, timeIdx = timeIdx, dirIdx = dirIdx))
    plotDataPack(datapack,antIdx=antIdx,timeIdx=timeIdx,dirIdx = dirIdx,
                 figname='{}/dobs'.format(outputfolder), vmin = vmin, vmax = vmax)#replace('hdf5','png'))
    neTCI = createInitialModel(datapack,antIdx = antIdx, timeIdx = timeIdx, dirIdx = dirIdx, zmax = tmax,spacing=sizeCell)
    #make uniform
    #neTCI.m[:] = np.mean(neTCI.m)
    neTCI.save("{}/nePriori.hdf5".format(outputfolder))
    rays = calcRays(antennas,patches,times, arrayCenter, fixtime, phase, neTCI, datapack.radioArray.frequency,
                    straightLineApprox, tmax, neTCI.nz)
    # log-model: m = log(ne / K_ne) with K_ne the mean of the initial model
    mTCI = neTCI.copy()
    K_ne = np.mean(mTCI.m)
    mTCI.m /= K_ne
    np.log(mTCI.m,out=mTCI.m)
    Nkernel = max(1,int(float(L_ne)/sizeCell))
    sigma_m = np.log(10.)#ne = K*exp(m+dm) = K*exp(m)*exp(dm), exp(dm) in (0.1,10) -> dm = (log(10) - log(0.1))/2.
    covC = CovarianceClass(mTCI,sigma_m,L_ne,7./2.)
    #uvw = UVW(location = datapack.radioArray.getCenter().earth_location,obstime = fixtime,phase = phase)
    #ants_uvw = antennas.transform_to(uvw).cartesian.xyz.to(au.km).value.transpose()
    #dirs_uvw = patches.transform_to(uvw).cartesian.xyz.value.transpose()
    F0 = precondition(neTCI, datapack,antIdx=antIdx, dirIdx=dirIdx, timeIdx = timeIdx)
    # the preconditioner values are discarded and replaced by ones
    F0.m *= 0.
    F0.m += 1.
    #
    # Graph convention: a tuple whose first item is callable is a task, and
    # string arguments that match a key are replaced by that key's value.
    dsk = {}
    for n in range(int(N)):
        #g_n
        dsk['store_forwardEq{}'.format(n)] = (store_forwardEq,resettable,'outputfolder',n,'templateDatapack','antIdx','timeIdx','dirIdx','rays',
                                              'K_ne','pull_m{}'.format(n),'i0')
        dsk['pull_forwardEq{}'.format(n)] = (pull_forwardEq,resettable,'store_forwardEq{}'.format(n),'antIdx','timeIdx','dirIdx')
        #gradient
        dsk['store_gamma{}'.format(n)] = (store_gamma,resettable,'outputfolder',n,'rays', 'pull_forwardEq{}'.format(n), 'dobs', 'i0', 'K_ne',
                                          'pull_m{}'.format(n),'mprior', 'CdCt', 'sigma_m', 'Nkernel', 'sizeCell')
        dsk['pull_gamma{}'.format(n)] = (pull_gamma,resettable,'store_gamma{}'.format(n))
        #m update
        dsk['store_m{}'.format(n+1)] = (store_m,resettable,'outputfolder',n+1,'pull_m{}'.format(n),'pull_phi{}'.format(n),'rays',
                                        'K_ne','i0','pull_forwardEq{}'.format(n),'dobs','CdCt','stateFile')
        dsk['pull_m{}'.format(n+1)] = (pull_m,resettable,'store_m{}'.format(n+1))
        dsk['plot_m{}'.format(n+1)] = (plot_model,'outputfolder',n+1,'pull_m{}'.format(n+1),'mprior','K_ne')
        dsk['plot_gamma{}'.format(n)] = (plot_gamma,'outputfolder',n,'pull_gamma{}'.format(n))
        #phi
        dsk['pull_phi{}'.format(n)] = (pull_Fdot,resettable,'store_F{}(gamma{})'.format(n,n))
        dsk['store_F{}(gamma{})'.format(n+1,n+1)] = (store_Fdot,resettable,'outputfolder', n+1, n+1 ,
                                                    'pull_F{}(gamma{})'.format(n,n+1),
                                                    'pull_gamma{}'.format(n+1),
                                                    'beta{}'.format(n),
                                                    'dm{}'.format(n),
                                                    'dgamma{}'.format(n),
                                                    'v{}'.format(n),
                                                    'sigma_m','L_m','covC'
                                                    )
        # chain of corrections F_i applied to gamma_{n+1}, each built on F_{i-1}
        for i in range(1,n+1):
            dsk['store_F{}(gamma{})'.format(i,n+1)] = (store_Fdot, resettable,'outputfolder',i, n+1 ,
                                                      'pull_F{}(gamma{})'.format(i-1,n+1),
                                                      'pull_gamma{}'.format(n+1),
                                                      'beta{}'.format(i-1),
                                                      'dm{}'.format(i-1),
                                                      'dgamma{}'.format(i-1),
                                                      'v{}'.format(i-1),
                                                      'sigma_m','L_m','covC'
                                                      )
            dsk['pull_F{}(gamma{})'.format(i,n+1)] = (pull_Fdot,resettable,'store_F{}(gamma{})'.format(i,n+1))
        #should replace for n=0
        dsk['store_F0(gamma{})'.format(n)] = (store_F0dot, resettable,'outputfolder',n, 'pull_F0','pull_gamma{}'.format(n))
        dsk['pull_F0(gamma{})'.format(n)] = (pull_Fdot,resettable,'store_F0(gamma{})'.format(n))
        # #epsilon_n
        # dsk['ep{}'.format(n)] = (calcEpsilon,n,'pull_phi{}'.format(n),'pull_m{}'.format(n),'rays',
        #                          'K_ne','i0','pull_forwardEq{}'.format(n),'dobs','CdCt')
        #
        dsk['beta{}'.format(n)] = (calcBeta,'dgamma{}'.format(n),'v{}'.format(n),'dm{}'.format(n),'sigma_m','L_m','covC')
        dsk['dgamma{}'.format(n)] = (diffTCI,'pull_gamma{}'.format(n+1),'pull_gamma{}'.format(n))
        dsk['dm{}'.format(n)] = (diffTCI,'pull_m{}'.format(n+1),'pull_m{}'.format(n))
        dsk['v{}'.format(n)] = (diffTCI,'pull_F{}(gamma{})'.format(n,n+1),'pull_phi{}'.format(n))
    # constant inputs shared by all tasks
    dsk['pull_F0'] = F0
    dsk['templateDatapack'] = datapack
    dsk['antIdx'] = antIdx
    dsk['timeIdx'] = timeIdx
    dsk['dirIdx'] = dirIdx
    # NOTE(review): the starting model is read from a hard-coded file of
    # another run rather than derived from ``mTCI`` -- confirm this is intended.
    dsk['pull_m0'] = TriCubic(filename='output/test/bfgs_3_1/m_25.hdf5')
    dsk['i0'] = i0
    dsk['K_ne'] = K_ne
    dsk['dobs'] = dobs
    dsk['mprior'] = mTCI
    dsk['CdCt'] = CdCt
    dsk['sigma_m'] = sigma_m
    dsk['Nkernel'] = Nkernel
    dsk['L_m'] = L_ne
    dsk['sizeCell'] = sizeCell
    dsk['covC'] = covC
    #calc rays
    #dsk['rays'] = (calcRays_dask,'antennas','patches','times', 'arrayCenter', 'fixtime', 'phase', 'neTCI', 'frequency', 'straightLineApprox','tmax')
    dsk['rays'] = rays
    dsk['outputfolder'] = outputfolder
    dsk['resettable'] = resettable
    dsk['stateFile'] = stateFile
    return dsk
class TrackingCallbacks(Callback):
    """Dask callback that prints when each task starts and how long it took."""

    def _start(self,dsk):
        # reference point for all later timestamps
        self.startTime = clock()

    def _pretask(self, key, dask, state):
        """Print the key of every task as it's started"""
        self.t1 = clock()
        elapsed = self.t1 - self.startTime
        print('Starting {} at {} seconds'.format(key,elapsed))

    def _posttask(self,key,result,dsk,state,id):
        # duration is measured from the most recent _pretask call
        duration = clock() - self.t1
        print("{!r} took {} seconds".format(key,duration))

    def _finish(self,dsk,state,errored):
        self.endTime = clock()
        total = self.endTime - self.startTime
        print("Approximate time to complete: {} time units".format(total))
if __name__ == '__main__':
    # Script entry point: build the graph for a simulated datapack and run
    # it under the dask profilers.
    from RealData import DataPack
    from AntennaFacetSelection import selectAntennaFacets
    from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler, visualize
    #from InitialModel import createTurbulentlModel

    i0 = 0
    L_ne = 25.
    sizeCell = 5.

    datapack = DataPack(filename="output/test/simulate/simulate_3/datapackSim.hdf5")
    #datapack = DataPack(filename="output/test/datapackObs.hdf5")
    #flags = datapack.findFlaggedAntennas()
    #datapack.flagAntennas(flags)
    datapackSel = selectAntennaFacets(20, datapack,
                                      antIdx=-1, dirIdx=-1, timeIdx=np.arange(1))
    #pertTCI = createTurbulentlModel(datapackSel,antIdx = -1, timeIdx = -1, dirIdx = -1, zmax = 1000.)

    dsk = createBFGSDask(True, "output/test/bfgs_3_2/", 5, datapackSel, L_ne, sizeCell, i0,
                         antIdx=-1, dirIdx=-1, timeIdx=np.arange(1))
    #dot_graph(dsk,filename="{}/BFGS_graph".format(outputfolder),format='png')
    #dot_graph(dsk,filename="{}/BFGS_graph".format(outputfolder),format='svg')
    #client = Client()
    #with TrackingCallbacks():
    with Profiler() as prof, ResourceProfiler(dt=0.25) as rprof, CacheProfiler() as cprof:
        get(dsk, ['plot_m5'])
    # render the collected profiles once the run has finished
    visualize([prof, rprof, cprof])