To create the training data, the SEEK package (https://github.com/cosmo-ethz/seek) must be installed.
In [1]:
%matplotlib inline
import glob
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Use the 'gist_earth' colormap as the default for every image plot in this notebook.
matplotlib.rc('image', cmap='gist_earth')
In [2]:
# Mirror the example data set into the current working directory.
#   -q            quiet, no progress output
#   -r -np        recurse below the given URL without ascending to its parent directory
#   -nH           do not create a directory named after the host
#   --cut-dirs=2  drop the two leading remote path components (~ipa/cosmo),
#                 so the files end up under ./bgs_example_data/
!wget -q -r -nH -np --cut-dirs=2 https://people.phys.ethz.ch/~ipa/cosmo/bgs_example_data/
In [3]:
# Create the directory that receives the SEEK post-processing output.
# Pure Python instead of `!mkdir -p` so the cell also works without a POSIX
# shell; exist_ok=True keeps the cell idempotent when it is re-run.
os.makedirs('bgs_example_data/seek_cache', exist_ok=True)
In [5]:
# Run the SEEK command-line tool with the `seek.config.process_survey_fft` configuration.
# The downloaded data directory is passed as --file-prefix and the seek_cache directory as
# --post-processing-prefix; the contents of seek_cache are what the notebook later globs
# and hands to the data provider.
# NOTE(review): --chi-1=20 and --overwrite=True are SEEK options whose exact semantics are
# not visible from this notebook — confirm against the SEEK documentation.
!seek --file-prefix='./bgs_example_data' --post-processing-prefix='bgs_example_data/seek_cache' --chi-1=20 --overwrite=True seek.config.process_survey_fft
In [3]:
from scripts.rfi_launcher import DataProvider
from tf_unet import unet
In [3]:
# Every entry in the SEEK cache directory becomes an input file for the data provider.
# glob.glob returns matches in arbitrary, filesystem-dependent order, so sort the
# result to give the provider the same file sequence on every run and machine.
files = sorted(glob.glob('bgs_example_data/seek_cache/*'))
In [5]:
# Data provider over the SEEK cache files.
# NOTE(review): 600 is DataProvider's first positional argument; its meaning is
# defined in scripts.rfi_launcher and is not visible here — confirm.
data_provider = DataProvider(600, files)

# U-Net whose channel and class counts are taken from the data provider;
# depth, root feature count and regularizer weight are fixed for this example.
net = unet.Unet(
    channels=data_provider.channels,
    n_class=data_provider.n_class,
    layers=3,
    features_root=64,
    cost_kwargs={"regularizer": 0.001},
)
In [6]:
# Trainer for `net` using the "momentum" optimizer with momentum 0.2.
trainer = unet.Trainer(
    net,
    optimizer="momentum",
    opt_kwargs={"momentum": 0.2},
)

# Short demonstration run: one epoch of 32 training iterations, with output
# written under ./unet_trained_bgs_example_data. The returned `path` is what
# the prediction cell later passes to net.predict.
path = trainer.train(
    data_provider,
    "./unet_trained_bgs_example_data",
    training_iters=32,
    epochs=1,
    dropout=0.5,
    display_step=2,
)
In [10]:
# Draw one sample for evaluation from a separately configured provider.
# It gets its own name so that `data_provider` (the provider used for training)
# is not silently rebound: re-running the training cell after this one would
# otherwise train on the provider built with 10000 instead of 600.
# NOTE(review): 10000 is DataProvider's first positional argument; its meaning is
# defined in scripts.rfi_launcher and is not visible here — confirm.
test_provider = DataProvider(10000, files)
x_test, y_test = test_provider(1)

# Run the network on the test sample, loading the model from `path`
# (the value returned by trainer.train).
prediction = net.predict(path, x_test)
In [11]:
# Show the test input, its target and the network output side by side.
# Each panel is titled with the exact slice it displays so the figure can be
# read on its own; ending the cell with a loop (rather than a bare imshow call)
# also avoids emitting the AxesImage repr as cell output.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ("Input x_test[0, ..., 0]", x_test[0, ..., 0]),
    ("Target y_test[0, ..., 1]", y_test[0, ..., 1]),
    ("Prediction prediction[0, ..., 1]", prediction[0, ..., 1]),
]
for axis, (title, image) in zip(ax, panels):
    axis.imshow(image, aspect="auto")
    axis.set_title(title)
Out[11]:
In [ ]: