We are interested in inference after a selection procedure, say, $S$, for which the selection event is the intersection of a list of quadratic inequalities:

$$ \begin{aligned} \{S(y)=s\} = \cap_{i \in I(s)} \{y: y^TQ_iy + a_i^Ty\leq b_i \}. \end{aligned} $$

The quadratic forms are not assumed non-negative definite, but we can, without loss of generality assume they are symmetric.

Example: first step of forward stepwise

If we consider forward stepwise with groups the selection procedure is just choosing the first group, $g^*$ and we have $$ \{g^*(y)=g\} = \cap_{h \neq g} \{y: y^T(X_hX_h^T/w_h^2 - X_gX_g^T/w_g^2)y \leq 0\}. $$

We want to slice through the selection event along a ray with direction $\eta$ that passes through $y$. That is, we need to find $$ \begin{aligned} \left\{t: S(y+t\eta)=s \right\} &= \cap_{i \in I(s)} \left\{t: (y+t\eta)^T Q_i (y+t\eta) + a_i^T(y+t\eta) \leq b_i\right\}. \\ \end{aligned} $$

For any given $i \in I(s)$ we see that $$ \begin{aligned} \left\{t: (y+t\eta)^T Q_i (y+t\eta) + a_i^T(y+t\eta) \leq b_i\right\} &= \left\{t: t^2 \cdot \eta^TQ_i\eta + t \cdot(2 y^TQ_i\eta + a_i^T\eta) + y^TQ_iy + a_i^Ty - b_i \leq 0 \right\} \\ &= \text{Intervals}(Q_i,a_i,b_i,y,\eta) \end{aligned} $$

Each value of $\text{Intervals}(Q_i,a_i,b_i,y,\eta)$ above is one of:

  • $\emptyset$, in which case the selection event is also the $\emptyset$ (so this never occurs);
  • $(-\infty, L(Q_i,a_i,b_i,y,\eta)] \cap [U(Q_i,a_i,b_i,y,\eta), \infty)$ if $\eta^TQ_i\eta < 0$;
  • $[L(Q_i,a_i,b_i,y,\eta), U(Q_i,a_i,b_i,y,\eta)]$ if $\eta^TQ_i\eta > 0$;
  • $(-\infty, L(Q_i,a_i,b_i,y,\eta)]$ if $\eta^TQ_i\eta=0$ and $2 y^TQ_i\eta + a_i^T\eta > 0$;
  • $[U(Q_i,a_i,b_i,y,\eta), \infty)$ if $\eta^TQ_i\eta=0$ and $2 y^TQ_i\eta + a_i^T\eta < 0$;
  • $(-\infty,\infty)$ if $\eta^TQ_i\eta=0, 2 y^TQ_i\eta + a_i^T\eta = 0$ and $y^TQ_iy + a_i^Ty - b_i \leq 0$.

In [15]:
import numpy as np
%load_ext rmagic

DEBUG = False


The rmagic extension is already loaded. To reload it, use:
  %reload_ext rmagic

In [16]:
def Intervals(Q, a, b, y, eta, tol=1.e-8):
    quad_term = (eta * np.dot(Q, eta)).sum()
    linear_term = 2 * (y * np.dot(Q, eta)).sum() + (a * eta).sum()
    constant_term = (y * np.dot(Q, y)).sum() + (a * y).sum() - b

    if DEBUG: print quad_term, linear_term, constant_term

    discr = linear_term**2 - 4 * quad_term * constant_term 
    if discr < 0:
        if DEBUG: print 'case5'
        return []
    elif np.fabs(quad_term / constant_term) > tol:
        L, U = sorted(((- linear_term - np.sqrt(discr)) / (2 * quad_term),
                    (- linear_term + np.sqrt(discr)) / (2 * quad_term)))
        if quad_term < 0:
            if DEBUG: print 'case1'
            return [(-np.inf,L), (U, np.inf)]
        else:
            if DEBUG: print 'case2', discr
            return [(L, U)]
    elif np.fabs(linear_term / constant_term) > tol:
        if linear_term > 0:
            if DEBUG: print 'case3'
            return [(-np.inf, -constant_term / linear_term)]
        else:
            if DEBUG: print 'case4'
            return [(-constant_term, np.inf)]
    elif constant_term < 0:
        if DEBUG: print 'case5'
        return [(-np.inf,np.inf)]
    else:
        if DEBUG: print 'case6'
        return []

In [17]:
Q1, Q2, Q3, Q4 = np.identity(4)
y1, eta1, a1 = np.random.standard_normal((3,4))
b1 = 0
Intervals(Q1, a1, b1, y1, eta1)


Out[17]:
[(-inf, -2.0706897009098504), (0.77985629493237207, inf)]

In [18]:
Q2 = -np.identity(4)
y2, eta2, a2 = np.random.standard_normal((3,4))
b2 = 0
Intervals(Q2, a2, b2, y2, eta2)


Out[18]:
[]

In [19]:
Q = np.zeros((4,4))
y, eta, a = np.random.standard_normal((3,4))
b = 0
Intervals(Q, a, b, y, eta)


Out[19]:
[(0.30949914672354295, inf)]

In [20]:
Q = np.diag([-1,-1,1,1])
y, eta, a = np.random.standard_normal((3,4))
b = 0
Intervals(Q, a, b, y, eta)


Out[20]:
[]

To find the intersection of a set of intervals, one only needs to track the upper and lower bounds. In our case, there will be some intervals and some complements of intervals.

This can be handled by handling the union of the complements and then taking the complement.


In [21]:
def find_intersection(y, eta, *quadratic_list):
    '''
    Find the intersection of 
    
    [Interval(Q, a, b, y, eta) for Q, a, b in quadratic_list]
    '''
    
    upper_int, lower_int = (np.inf, -np.inf)
    upper_union, lower_union = (-np.inf, np.inf)

    for Q, a, b in quadratic_list:
        intervals = Intervals(Q, a, b, y, eta)
        if DEBUG: print intervals
        if len(intervals) == 1: # a single interval
            upper_int = min(upper_int, intervals[0][1])
            lower_int = max(upper_int, intervals[0][0])
        elif len(intervals) == 2: # complement of an interval
            L, U = intervals[0][1], intervals[1][0] # by construction our intervals always have this form
            # could also use sorted(intervals[np.isfinite(intervals)])
            upper_union = max(upper_union, U)
            lower_union = min(lower_union, L)
    complement = [(-np.inf,lower_union),(upper_union, np.inf)]
    
    if upper_int < upper_union and lower_int > lower_union:
        return [] # the intersection of the intervals is in the complement
    elif upper_int >= upper_union:
        if lower_int >= lower_union:
            if DEBUG: print 'case1'
            return [(max(lower_int, lower_union), upper_int)]
        else:
            if DEBUG: print 'case2'
            return [(lower_int, lower_union), (upper_union, upper_int)]
    else: # lower_int <= lower_union
        if upper_int <= upper_union:
            if DEBUG: print 'case3'
            return [(lower_int, min(upper_int, upper_union))]
        else:
            if DEBUG: print 'case4'
            return [(lower_int, lower_union), (upper_int, upper_union)]

In [28]:
%%R -o X,Y,groups,weights,sigma
source('http://statweb.stanford.edu/~jtaylo/notebooks/group_lasso.R')
set.seed(0)
n = 20
p = 10
sigma = 1.3
X = matrix(rnorm(n*p),n,p)
Y = rnorm(n)*sigma
groups = c(1,1,2,2,2,3,3,4,4,5)
weights = c(2,2.5,2,2,1.4)
results = group_lasso_knot(X, Y, groups, weights)
Z = results$L / (sqrt(results$var)*sigma)
Zlower = results$lower_bound / (sqrt(results$var)*sigma)
print(data.frame(Z,Zlower))
print(pvalue(results$L, results$lower_bound, results$upper_bound, sqrt(results$var), results$k, sigma=sigma))


         Z   Zlower
1 2.650602 1.296633
[1] 0.06909894

In [29]:
%%R
print(results)


$L
[1] 8.502982

$lower_bound
[1] 4.159526

$upper_bound
[1] Inf

$var
[1] 6.089294

$k
[1] 2

$chi_max
[1] 3.445782

$gmax
[1] 4


In [30]:
Qs = []
Xs = []
norms = []
for g, w in zip(np.unique(groups), weights):
    Xg = X[:,groups==g] / w
    Qs.append(np.dot(Xg, Xg.T))
    Xs.append(Xg)
    norms.append(np.linalg.norm(np.dot(Xg.T,Y)))
max(norms)


Out[30]:
8.5029821317815024

In [31]:
imax = np.argmax(norms)
P = np.dot(Xs[imax], np.linalg.pinv(Xs[imax]))
eta = np.dot(P, Y)
eta /= np.linalg.norm(eta)
W = (eta*Y).sum() 
final_Qs = []
for i in range(len(Qs)):
    if i != imax:
        final_Qs.append((Qs[i] - Qs[imax]).copy())
I1, I2 = find_intersection(Y, eta, *[(Q, 0, 0) for Q in final_Qs])

For the group LASSO first step, what we called Zlower can be recoverd from W and I2[0]. I'm not quite sure what can be done with I1, if anything. It can probably be used to get a little more power, but I'm not 100% sure.


In [26]:
V1, V2 = W / sigma, (W+I2[0]) / sigma
%R -i V1,V2
V1, V2


Out[26]:
(array([ 2.65060154]), array([ 1.29663281]))

In [27]:
%%R
print((1 - pchisq(V1^2, results$k)) / (1 - pchisq(V2^2, results$k)))


[1] 0.06909894

This is the same $p$-value as above.


In [27]:


In [27]: