In [ ]:
# %load /Users/dsuess/Code/Pythonlibs/cluster_template.ipy
import ipyparallel
import numpy as np
from os import environ
# Connect to the running ipcluster (optionally selected through the
# CLUSTER_ID environment variable) and create a load-balanced view for tasks.
CLUSTER_ID = environ.get('CLUSTER_ID', None)
_CLIENTS = ipyparallel.Client(cluster_id=CLUSTER_ID)
_VIEW = _CLIENTS.load_balanced_view()
print("Kernels available: {}".format(len(_CLIENTS)))

# Master generator with a fixed seed: the per-engine seeds drawn from it make
# the whole experiment reproducible.
RGEN = np.random.RandomState(seed=1234)
for client_id, c in enumerate(_CLIENTS):
    seed = RGEN.randint(10000000)
    # Each engine gets its own RandomState and an index CLIENT_ID (used to
    # name its temporary output file).  block=True waits until the push has
    # completed, so tasks submitted right afterwards -- possibly routed through
    # the load-balanced scheduler rather than this engine's direct queue --
    # cannot start before RGEN/CLIENT_ID exist on the engine.
    c.push({'RGEN': np.random.RandomState(seed=seed),
            'CLIENT_ID': client_id},
           block=True)
In [ ]:
import numpy as np
import matplotlib.pyplot as pl
import itertools as it
import csalgs.lowrank as lr
from tools.helpers import AsyncTaskWatcher
import h5py
import functools as ft
from time import sleep
from h5merge import merge as h5merge
In [ ]:
%%px --local
from collections import Iterator
from scipy.linalg import svdvals
def new_h5group(group, name, overwrite=True):
    """Return the subgroup ``name`` of ``group``, creating it if needed.

    With ``overwrite=True`` (default) any existing member called ``name`` is
    deleted first and a fresh, empty group is created.  With
    ``overwrite=False`` an existing member is returned unchanged; a new group
    is only created when ``name`` does not exist yet.

    ``group`` must support ``group[name]`` / ``del group[name]`` (raising
    ``KeyError`` for missing names) and ``group.create_group(name)``, as an
    h5py ``File``/``Group`` does.
    """
    if overwrite:
        try:
            del group[name]
        except KeyError:
            pass  # nothing to overwrite
        return group.create_group(name)
    # Previously the member was deleted even here, so this branch always
    # raised KeyError; keep existing data instead.
    try:
        return group[name]
    except KeyError:
        return group.create_group(name)
def take(n, iterator):
    """Return the ``n``-th item (1-based) of ``iterator``.

    Exactly ``n`` items are consumed.  ``iterator`` may be any iterable;
    ``iter()`` returns an iterator unchanged, so an iterator passed in is
    advanced in place.

    Raises
    ------
    ValueError
        If ``n < 1`` (there is no such item).
    StopIteration
        If the iterable has fewer than ``n`` items.
    """
    if n < 1:
        raise ValueError("n must be >= 1, got {}".format(n))
    iterator = iter(iterator)
    for _ in range(n - 1):
        next(iterator)
    return next(iterator)
def condition_number(X, threshold=1e-10):
    """Return the condition number of ``X`` restricted to its numerical rank.

    Singular values ``<= threshold`` (an absolute cutoff) are treated as zero,
    so the result is the ratio of the largest singular value to the smallest
    one above ``threshold``.  This ratio stays finite for rank-deficient ``X``.

    Raises
    ------
    ValueError
        If no singular value exceeds ``threshold`` (e.g. ``X`` is zero); the
        ratio is undefined then.
    """
    sigma = svdvals(X)  # scipy returns them in decreasing order
    sigma_kept = sigma[sigma > threshold]
    if sigma_kept.size == 0:
        raise ValueError(
            "no singular value above threshold={}".format(threshold))
    return sigma_kept[0] / sigma_kept[-1]
In [ ]:
%%px --local
import numpy as np
import csalgs.lowrank as lr
import os
import h5py
def recover_random(index, dim, rank, nr_measurements, cnr_scale, iterations, tmpfile):
    """Run one randomized recovery experiment on the current engine.

    Using the engine-local ``RGEN``, draws a sensing tensor ``A`` with
    ``lr.sensingmat_rank1`` (``max(nr_measurements)`` measurements) and a
    target ``X`` with ``lr.random_lowrank_matrix_cnr(dim, rank, ...)``,
    normalized by ``np.linalg.norm(X)``.  For every ``m`` in
    ``nr_measurements`` the first ``m`` measurements are fed to
    ``lr.altmin_estimator`` and its ``iterations``-th iterate ``(U, V)`` gives
    the estimate ``X_sharp = U @ V.T``.

    Results are appended to the HDF5 file ``tmpfile.format(CLIENT_ID)``:
    ``d={dim}/X_{index}`` holds ``X`` plus the RNG state (attributes
    ``RGEN_*``) taken before ``A`` and ``X`` were drawn, and
    ``d={dim}/X_{index}/m={m}`` holds ``X_SHARP``.

    Parameters
    ----------
    index : int
        Sample index; only used in the group name.
    dim, rank : int
        Forwarded to the ``lr`` generators and estimator.
    nr_measurements : sequence of int
        Measurement counts to try; iterated twice (``max`` and the loop).
    cnr_scale
        Forwarded as ``condition_scale`` to ``lr.random_lowrank_matrix_cnr``.
    iterations : int
        Which iterate of ``lr.altmin_estimator`` to keep.
    tmpfile : str
        Format string with one ``{}`` placeholder filled with ``CLIENT_ID``.

    Returns
    -------
    list
        ``np.linalg.norm(X - X_sharp)`` for each entry of ``nr_measurements``.
    """
    # Explicit append mode: since h5py 3.0 the default mode is read-only.
    with h5py.File(tmpfile.format(CLIENT_ID), 'a') as dump:
        # Several samples of the same dimension share this group: get-or-create.
        dimgroup = dump.require_group('d={}'.format(dim))
        h5group = new_h5group(dimgroup, 'X_{}'.format(index))
        h5group.attrs['DIM'] = dim
        h5group.attrs['RANK'] = rank
        # Record the *complete* legacy RandomState state (type, keys, pos,
        # has_gauss, cached_gaussian) before A and X are drawn.  All five parts
        # are needed for RandomState.set_state to reproduce A later; storing
        # only type and keys is not enough.
        (rgen_type, rgen_keys, rgen_pos,
         rgen_has_gauss, rgen_cached_gaussian) = RGEN.get_state()[:5]
        h5group.attrs['RGEN_TYPE'] = rgen_type
        h5group.attrs['RGEN_STATE'] = rgen_keys
        h5group.attrs['RGEN_POS'] = rgen_pos
        h5group.attrs['RGEN_HAS_GAUSS'] = rgen_has_gauss
        h5group.attrs['RGEN_CACHED_GAUSSIAN'] = rgen_cached_gaussian

        A = lr.sensingmat_rank1(max(nr_measurements), dim, hermitian=False, rgen=RGEN)
        X = lr.random_lowrank_matrix_cnr(dim, rank, condition_scale=cnr_scale,
                                         hermitian=False, rgen=RGEN)
        X /= np.linalg.norm(X)
        # y[i] = sum_{j,k} A[i, j, k] * X[j, k]
        y = np.tensordot(A, X, axes=((1, 2), (0, 1)))
        # A is deliberately not stored: it can be regenerated from the RGEN_*
        # attributes and quickly becomes too large.
        h5group['X'] = X

        dists = []
        for m in nr_measurements:
            U, V = take(iterations, lr.altmin_estimator(A[:m], y[:m], rank))
            X_sharp = U @ V.T
            dists.append(np.linalg.norm(X - X_sharp))
            mgroup = new_h5group(h5group, 'm={}'.format(m))
            mgroup.attrs['NR_MEASUREMENTS'] = m
            mgroup['X_SHARP'] = X_sharp
    return dists
In [ ]:
# Parameter sweep: for every dimension in DIMS, launch SAMPLES independent
# recovery experiments (recover_random) on the load-balanced view.
# Grid of factors c; a run uses m = int(c * dim * RANK) measurements.
CS = np.linspace(3.5, 6.0, 15)
DIMS = range(20, 251, 20)
RANK = 2
SAMPLES = 100
# Final merged result file.
OUTFILE = 'altmin_condition_nr.h5'
# Per-engine temporary files.  The remaining '{}' placeholder is filled with
# the engine's CLIENT_ID inside recover_random, and with '*' for cleanup.
TMPFILE_MASK = 'altmin_cr={}'.format(RANK) + '_{}.h5'
atw = AsyncTaskWatcher()
for dim in DIMS:
    # One list of measurement counts per dimension; recover_random draws a
    # single sensing tensor of size max(nr_measurements) and uses its first
    # m entries for each m in this list.
    nr_measurements = [int(c * dim * RANK) for c in CS]
    # NOTE(review): cnr_scale is forwarded as condition_scale to
    # lr.random_lowrank_matrix_cnr; confirm a one-element list is intended.
    recover = ft.partial(recover_random,
                         dim=dim, rank=RANK, nr_measurements=nr_measurements,
                         cnr_scale=[1.0], iterations=30, tmpfile=TMPFILE_MASK)
    # Each sample index becomes the 'X_{index}' group name in the output.
    atw.append(_VIEW.map_async(recover, range(SAMPLES)))
# NOTE(review): AsyncTaskWatcher is a project helper; presumably block()
# waits for all appended async results -- confirm.
atw.block()
In [ ]:
# Merge the per-engine temporary files into OUTFILE under 'rank={RANK}'.
# ExitStack guarantees every opened input file is closed, even if opening a
# later file or the merge itself raises.
from contextlib import ExitStack

with ExitStack() as _stack:
    infiles = [_stack.enter_context(h5py.File(TMPFILE_MASK.format(client_id), 'r'))
               for client_id, _ in enumerate(_CLIENTS)]
    with h5py.File(OUTFILE, 'w') as outfile:
        root = new_h5group(outfile, 'rank={}'.format(RANK))
        h5merge(infiles, root)
# Shell glob matching all per-engine temporary files (used for cleanup).
infile_sel = TMPFILE_MASK.format('*')
In [ ]:
# Delete the per-engine temporary files; infile_sel is TMPFILE_MASK with '*'
# in place of the client id, expanded by the shell.
!rm $infile_sel
In [ ]:
def recons_error(root, dim):
    """Collect reconstruction errors for one dimension, grouped by measurement count.

    Parameters
    ----------
    root : h5py.Group
        Group containing a ``d={dim}`` subgroup whose children ``X_<i>`` hold a
        dataset ``X`` and subgroups ``m=<m>`` with a dataset ``X_SHARP``.
    dim : int
        Selects the ``d={dim}`` subgroup.

    Returns
    -------
    dict
        ``{'m=<m>': [np.linalg.norm(X - X_SHARP) for each sample]}``.  Keys
        are the HDF5 group names (strings), not integers.
    """
    results = {}
    for x_group in root['d={}'.format(dim)].values():
        # ds[()] reads the full dataset; Dataset.value was removed in h5py 3.0.
        X = x_group['X'][()]
        for name, m_group in x_group.items():
            if name.startswith('m='):
                X_sharp = m_group['X_SHARP'][()]
                results.setdefault(name, []).append(np.linalg.norm(X - X_sharp))
    return results
def recons_stat(root, dim, thresh, cs=None, rank=None):
    """Return the empirical recovery rate for every factor in ``cs``.

    Parameters
    ----------
    root : h5py.Group
        Passed to ``recons_error``.
    dim : int
        Dimension to evaluate.
    thresh : float
        A sample counts as recovered if its error is below ``thresh``.
    cs : array_like, optional
        Grid of factors c from which the measurement counts were built as
        ``m = int(c * dim * rank)``.  Defaults to the notebook-global ``CS``.
    rank : int, optional
        Defaults to the notebook-global ``RANK``.

    Returns
    -------
    numpy.ndarray, shape (len(cs),)
        Entry i is the fraction of samples with error < ``thresh`` for the
        measurement count whose ratio ``m / (rank * dim)`` is closest to
        ``cs[i]``; entries without data stay 0.
    """
    cs = CS if cs is None else np.asarray(cs)
    rank = RANK if rank is None else rank
    errors = recons_error(root, dim)
    result = np.zeros(len(cs))
    for key, errs in errors.items():
        m = int(key.partition('=')[2])  # keys look like 'm=<m>'
        # m = int(c * dim * rank) was truncated, so map back to the nearest c.
        index = np.argmin(np.abs(cs - m / (rank * dim)))
        result[index] = np.mean(np.asarray(errs) < thresh)
    return result
# A sample counts as successfully recovered if its reconstruction error is
# below THRESH.
THRESH = 1e-6
with h5py.File(OUTFILE, 'r') as source:
    root = source['rank={}'.format(RANK)]
    # Shape (len(CS), len(DIMS)): rows index the factor c, columns the
    # dimension.  NOTE: despite its name this holds success rates (fraction
    # of errors < THRESH), not errors.
    recons_errors = np.array([recons_stat(root, dim, THRESH) for dim in DIMS]).T
In [ ]:
from tools.plot import matshow
# Recovery-rate diagram: columns are dimensions, rows are factors c.  Rows are
# reversed so that the largest c ends up on top, assuming matshow draws row 0
# at the top (tools.plot.matshow is a project helper -- confirm).
ax = pl.gca()
matshow(recons_errors[::-1], ax=ax, show=False, cmap='gray')
ax.set_xticks(range(len(DIMS)))
ax.set_xticklabels(DIMS)
ax.set_yticks(range(len(CS)))
# CS comes from np.linspace and has long float reprs; round for readability.
ax.set_yticklabels(['{:.2f}'.format(c) for c in CS[::-1]])
ax.set_xlabel('dim')
ax.set_ylabel('c = m / (rank * dim)')
pl.show()