In [ ]:
# %load /Users/dsuess/Code/Pythonlibs/cluster_template.ipy
import ipyparallel
import numpy as np
from os import environ

# Connect to the ipyparallel cluster; CLUSTER_ID may be unset (default cluster).
CLUSTER_ID = environ.get('CLUSTER_ID')
_CLIENTS = ipyparallel.Client(cluster_id=CLUSTER_ID)
_VIEW = _CLIENTS.load_balanced_view()
print("Kernels available: {}".format(len(_CLIENTS)))

# Give every engine its own RandomState, seeded reproducibly from a master
# generator, plus its position in the client list as CLIENT_ID (the task code
# uses CLIENT_ID to pick a per-engine temporary file).
RGEN = np.random.RandomState(seed=1234)
for client_id, c in enumerate(_CLIENTS):
    seed = RGEN.randint(10000000)
    # block=True: wait until the push has completed, so no task submitted later
    # can start on an engine that does not yet have RGEN / CLIENT_ID.
    c.push({'RGEN': np.random.RandomState(seed=seed),
            'CLIENT_ID': client_id},
           block=True)

In [ ]:
import numpy as np
import matplotlib.pyplot as pl
import itertools as it

import csalgs.lowrank as lr
from tools.helpers import AsyncTaskWatcher
import h5py
import functools as ft
from time import sleep

from h5merge import merge as h5merge

In [ ]:
%%px --local

from collections.abc import Iterator
from scipy.linalg import svdvals


def new_h5group(group, name, overwrite=True):
    """Return the child group ``name`` of ``group``.

    Parameters
    ----------
    group : h5py.Group (or h5py.File)
        Parent group.
    name : str
        Name of the child group.
    overwrite : bool, default True
        If True, any existing child called ``name`` is deleted and a fresh,
        empty group is created. If False, an existing child is returned
        unchanged and a new group is only created when none exists.

    Returns
    -------
    The (new or existing) child group.
    """
    if overwrite:
        try:
            del group[name]
        except KeyError:
            # nothing to overwrite
            pass
        return group.create_group(name)
    # Only delete when overwriting; deleting unconditionally made this branch
    # always fail with KeyError.
    return group[name] if name in group else group.create_group(name)

def take(n, iterator):
    """Advance ``iterator`` ``n`` times and return the n-th item (1-based).

    ``iterator`` may be any iterable; an actual iterator is consumed in place,
    so it continues after the returned item.

    Raises
    ------
    ValueError
        If ``n < 1`` (there is no item to return).
    StopIteration
        If ``iterator`` yields fewer than ``n`` items.
    """
    if n < 1:
        raise ValueError("take() needs n >= 1, got {}".format(n))
    # iter() returns an iterator unchanged and wraps any other iterable.
    iterator = iter(iterator)
    for _ in range(n):
        item = next(iterator)
    return item


def condition_number(X, threshold=1e-10):
    """Largest singular value of ``X`` divided by its smallest one above ``threshold``.

    Singular values not exceeding ``threshold`` are treated as numerical zeros
    and ignored, so rank-deficient matrices get a finite value. Raises
    IndexError if no singular value exceeds ``threshold``.
    """
    singular_values = svdvals(X)  # non-increasing order
    kept = singular_values[singular_values > threshold]
    largest, smallest_kept = singular_values[0], kept[-1]
    return largest / smallest_kept

In [ ]:
%%px --local

import numpy as np
import csalgs.lowrank as lr
import os
import h5py

def recover_random(index, dim, rank, nr_measurements, cnr_scale, iterations, tmpfile):
    """Sample one random low-rank problem, recover it for several m and store the results.

    Meant to run on an ipyparallel engine: it reads the engine-global ``RGEN``
    and ``CLIENT_ID`` pushed by the cluster-setup cell and uses ``new_h5group``.

    Parameters
    ----------
    index : int
        Sample index; results go to the HDF5 group ``d={dim}/X_{index}``.
    dim, rank : int
        Passed to ``lr.sensingmat_rank1`` / ``lr.random_lowrank_matrix_cnr``
        and ``lr.altmin_estimator``.
    nr_measurements : sequence of int
        Numbers of measurements m to try. The sensing operator ``A`` is drawn
        once for ``max(nr_measurements)`` measurements and ``A[:m]`` is used
        for each m.
    cnr_scale :
        Forwarded as ``condition_scale`` to ``lr.random_lowrank_matrix_cnr``.
    iterations : int
        Number of iterates of ``lr.altmin_estimator`` to advance; the last one is used.
    tmpfile : str
        File-name pattern with one ``{}`` slot filled with ``CLIENT_ID``, so
        every engine writes to its own file.

    Returns
    -------
    list of float
        ``np.linalg.norm(X - U @ V.T)`` for each entry of ``nr_measurements``.
    """
    # Mode 'a' (read/write, create if missing) must be explicit: since h5py 3.0
    # the default mode is read-only, which makes every write below fail.
    with h5py.File(tmpfile.format(CLIENT_ID), 'a') as dump:
        dimgroup = dump.require_group('d={}'.format(dim))

        h5group = new_h5group(dimgroup, 'X_{}'.format(index))
        h5group.attrs['DIM'] = dim
        h5group.attrs['RANK'] = rank
        # Save the generator state *before* drawing A. RandomState.set_state
        # needs the position (and gauss cache) as well as the key array, so
        # store every field of the state tuple.
        rgen_state = RGEN.get_state()
        h5group.attrs['RGEN_TYPE'], h5group.attrs['RGEN_STATE'] = rgen_state[:2]
        (h5group.attrs['RGEN_POS'], h5group.attrs['RGEN_HAS_GAUSS'],
         h5group.attrs['RGEN_CACHED_GAUSSIAN']) = rgen_state[2:5]

        A = lr.sensingmat_rank1(max(nr_measurements), dim, hermitian=False, rgen=RGEN)
        X = lr.random_lowrank_matrix_cnr(dim, rank, condition_scale=cnr_scale,
                                         hermitian=False, rgen=RGEN)
        X /= np.linalg.norm(X)
        # y[i] = sum_{j,k} A[i, j, k] * X[j, k]
        y = np.tensordot(A, X, axes=((1, 2), (0, 1)))
        h5group['X'] = X
        # A is not stored: it quickly becomes too large and can be regenerated
        # from the saved RGEN state.

        dists = []
        for m in nr_measurements:
            U, V = take(iterations, lr.altmin_estimator(A[:m], y[:m], rank))
            X_sharp = U @ V.T
            dists.append(np.linalg.norm(X - X_sharp))

            mgroup = new_h5group(h5group, 'm={}'.format(m))
            mgroup.attrs['NR_MEASUREMENTS'] = m
            mgroup['X_SHARP'] = X_sharp

    return dists

In [ ]:
# Oversampling factors c: a run uses m = int(c * dim * RANK) measurements.
CS = np.linspace(3.5, 6.0, 15)
DIMS = range(20, 251, 20)
RANK = 2
SAMPLES = 100          # random problem instances per dimension
ITERATIONS = 30        # iterates of lr.altmin_estimator per recovery
CNR_SCALE = [1.0]      # forwarded as condition_scale to lr.random_lowrank_matrix_cnr
OUTFILE = 'altmin_condition_nr.h5'
# One temporary HDF5 file per engine; '{}' is filled with the engine's CLIENT_ID.
TMPFILE_MASK = 'altmin_cr={}'.format(RANK) + '_{}.h5'

atw = AsyncTaskWatcher()

for dim in DIMS:
    nr_measurements = [int(c * dim * RANK) for c in CS]
    recover = ft.partial(recover_random,
                         dim=dim, rank=RANK, nr_measurements=nr_measurements,
                         cnr_scale=CNR_SCALE, iterations=ITERATIONS,
                         tmpfile=TMPFILE_MASK)
    atw.append(_VIEW.map_async(recover, range(SAMPLES)))

atw.block()

In [ ]:
import os
from contextlib import ExitStack

# Per-engine temporary files written by recover_random. An engine that never
# received a task has no file, so skip missing paths instead of failing.
tmp_paths = [path for path in (TMPFILE_MASK.format(client_id)
                               for client_id in range(len(_CLIENTS)))
             if os.path.exists(path)]

# ExitStack closes every input file even if opening a later one or h5merge raises.
with ExitStack() as stack:
    infiles = [stack.enter_context(h5py.File(path, 'r')) for path in tmp_paths]
    with h5py.File(OUTFILE, 'w') as outfile:
        root = new_h5group(outfile, 'rank={}'.format(RANK))
        h5merge(infiles, root)

# Glob pattern matching all per-engine temporary files (removed in the next cell).
infile_sel = TMPFILE_MASK.format('*')

In [ ]:
import glob
import os

# Delete the per-engine temporary files in pure Python: unlike `!rm $infile_sel`
# this does not go through the shell and does not error when nothing matches.
for tmp_path in glob.glob(infile_sel):
    os.remove(tmp_path)

In [ ]:
def recons_error(root, dim):
    """Collect reconstruction errors for one dimension.

    Parameters
    ----------
    root : h5py.Group
        Group containing ``d={dim}`` subgroups laid out as written by
        ``recover_random`` (``X_*/X`` and ``X_*/m=*/X_SHARP``).
    dim : int
        Dimension to evaluate.

    Returns
    -------
    dict
        ``{'m=<nr_measurements>': [np.linalg.norm(X - X_sharp), ...]}`` with one
        entry per sample that has a recovery for that number of measurements.
    """
    results = {}
    for x_group in root["d={}".format(dim)].values():
        # ``dataset[()]`` reads the whole dataset; ``Dataset.value`` was removed in h5py 3.0.
        X = x_group['X'][()]
        for name, m_group in x_group.items():
            if not name.startswith('m='):
                continue
            X_sharp = m_group['X_SHARP'][()]
            results.setdefault(name, []).append(np.linalg.norm(X - X_sharp))
    return results

def recons_stat(root, dim, thresh):
    """Fraction of samples recovered with error below ``thresh``, one value per entry of ``CS``.

    Each ``'m=<k>'`` key returned by ``recons_error`` is mapped back to the
    ratio ``k / (RANK * dim)`` and assigned to the closest value in the global
    ``CS``. Entries of ``CS`` without a matching key stay 0.
    """
    errors = recons_error(root, dim)
    success_rate = np.zeros(len(CS))
    for key, errs in errors.items():
        ratio = int(key[2:]) / (RANK * dim)   # strip the 'm=' prefix
        nearest = np.argmin(np.abs(CS - ratio))
        success_rate[nearest] = np.mean(np.array(errs) < thresh)
    return success_rate

THRESH = 1e-6  # a sample counts as recovered if its reconstruction error is below this

with h5py.File(OUTFILE, 'r') as source:
    root = source['rank={}'.format(RANK)]
    # Rows follow CS (oversampling factor), columns follow DIMS.
    # NB: despite its name, ``recons_errors`` holds success rates in [0, 1].
    recons_errors = np.transpose([recons_stat(root, dim, THRESH) for dim in DIMS])

In [ ]:
from tools.plot import matshow

fig, ax = pl.subplots()

# Rows are reversed here; the y tick labels below are reversed to match.
matshow(recons_errors[::-1], ax=ax, show=False, cmap='gray')

ax.set_xticks(range(len(DIMS)))
ax.set_xticklabels(DIMS)
ax.set_yticks(range(len(CS)))
# CS comes from np.linspace, so round the labels instead of printing every digit.
ax.set_yticklabels(['{:.2f}'.format(c) for c in CS[::-1]])
ax.set_xlabel('dimension')
ax.set_ylabel('c = m / (rank * dim)')
ax.set_title('Fraction of samples with reconstruction error < {:g}'.format(THRESH))

pl.show()