We are interested in inference after a selection procedure, say, $S$, for which the selection event is the intersection of a list of quadratic inequalities:
$$ \begin{aligned} \{S(y)=s\} = \cap_{i \in I(s)} \{y: y^TQ_iy + a_i^Ty\leq b_i \}. \end{aligned} $$
The quadratic forms are not assumed non-negative definite, but we can, without loss of generality, assume they are symmetric, since $y^TQ_iy = y^T\left(\frac{Q_i+Q_i^T}{2}\right)y$.
If we consider forward stepwise with groups, the selection procedure at the first step is just choosing the first group, $g^*$, and we have $$ \{g^*(y)=g\} = \cap_{h \neq g} \{y: y^T(X_hX_h^T/w_h^2 - X_gX_g^T/w_g^2)y \leq 0\}. $$
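As a quick sanity check on the display above, here is a minimal sketch that builds the matrices $X_gX_g^T/w_g^2$ for a made-up design and confirms that the group maximizing $\|X_g^Ty\|/w_g$ satisfies every inequality. The data, the group labels and the weights are hypothetical and chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 20, 10
X = rng.standard_normal((n, p))                      # hypothetical design
y = rng.standard_normal(n)                           # hypothetical response
groups = np.array([1, 1, 2, 2, 2, 3, 3, 4, 4, 5])    # hypothetical group labels
weights = {1: 2.0, 2: 2.5, 3: 2.0, 4: 2.0, 5: 1.4}   # hypothetical weights w_g

# Q_g = X_g X_g^T / w_g^2, so that y^T Q_g y = ||X_g^T y||^2 / w_g^2
Q = {g: X[:, groups == g] @ X[:, groups == g].T / w**2
     for g, w in weights.items()}
scores = {g: y @ Q[g] @ y for g in Q}
g_star = max(scores, key=scores.get)                 # first group chosen

# every inequality defining {g^*(y) = g_star} should hold at y
holds = all(y @ (Q[h] - Q[g_star]) @ y <= 0 for h in Q if h != g_star)
print(g_star, holds)
```

The differences `Q[h] - Q[g_star]` are in general indefinite, which is why the quadratic forms are not assumed non-negative definite.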
We want to slice through the selection event along a ray with direction $\eta$ that passes through $y$. That is, we need to find $$ \begin{aligned} \left\{t: S(y+t\eta)=s \right\} &= \cap_{i \in I(s)} \left\{t: (y+t\eta)^T Q_i (y+t\eta) + a_i^T(y+t\eta) \leq b_i\right\}. \\ \end{aligned} $$
For any given $i \in I(s)$ we see that $$ \begin{aligned} \left\{t: (y+t\eta)^T Q_i (y+t\eta) + a_i^T(y+t\eta) \leq b_i\right\} &= \left\{t: t^2 \cdot \eta^TQ_i\eta + t \cdot(2 y^TQ_i\eta + a_i^T\eta) + y^TQ_iy + a_i^Ty - b_i \leq 0 \right\} \\ &= \text{Intervals}(Q_i,a_i,b_i,y,\eta) \end{aligned} $$
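The expansion above can be checked numerically. The following sketch, using a hypothetical helper `quadratic_coefficients` and random inputs, compares the three coefficients with a direct evaluation of the quadratic at $y + t\eta$.

```python
import numpy as np

def quadratic_coefficients(Q, a, b, y, eta):
    """Return (A, B, C) with z^T Q z + a^T z - b = A t^2 + B t + C for z = y + t eta.

    Q is assumed symmetric, as in the text.
    """
    A = eta @ Q @ eta
    B = 2 * (y @ Q @ eta) + a @ eta
    C = y @ Q @ y + a @ y - b
    return A, B, C

rng = np.random.default_rng(1)
M = rng.standard_normal((4, 4))
Q = (M + M.T) / 2                      # symmetric, not necessarily non-negative definite
a, y, eta = rng.standard_normal((3, 4))
b = 0.5

A, B, C = quadratic_coefficients(Q, a, b, y, eta)
for t in (-2.0, 0.3, 1.7):
    z = y + t * eta
    direct = z @ Q @ z + a @ z - b     # left-hand side evaluated directly
    assert np.isclose(direct, A * t**2 + B * t + C)
```

The set of $t$ satisfying the inequality is then the set where this scalar quadratic is at most zero.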
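Each value of $\text{Intervals}(Q_i,a_i,b_i,y,\eta)$ above is one of: the empty set; a single bounded interval $[L, U]$; the complement of an interval, $(-\infty, L] \cup [U, \infty)$; a half-line; or the whole real line.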
In [15]:
import numpy as np
%load_ext rpy2.ipython
DEBUG = False
In [16]:
def Intervals(Q, a, b, y, eta, tol=1.e-8):
    # solve quad_term * t**2 + linear_term * t + constant_term <= 0 for t
    quad_term = (eta * np.dot(Q, eta)).sum()
    linear_term = 2 * (y * np.dot(Q, eta)).sum() + (a * eta).sum()
    constant_term = (y * np.dot(Q, y)).sum() + (a * y).sum() - b
    if DEBUG: print(quad_term, linear_term, constant_term)
    discr = linear_term**2 - 4 * quad_term * constant_term
    if np.fabs(quad_term / constant_term) > tol: # genuinely quadratic in t
        if discr < 0: # no real roots, so the quadratic never changes sign
            if quad_term < 0:
                if DEBUG: print('case5')
                return [(-np.inf, np.inf)]
            else:
                if DEBUG: print('case6')
                return []
        L, U = sorted(((- linear_term - np.sqrt(discr)) / (2 * quad_term),
                       (- linear_term + np.sqrt(discr)) / (2 * quad_term)))
        if quad_term < 0:
            if DEBUG: print('case1')
            return [(-np.inf, L), (U, np.inf)]
        else:
            if DEBUG: print('case2', discr)
            return [(L, U)]
    elif np.fabs(linear_term / constant_term) > tol: # linear in t
        if linear_term > 0:
            if DEBUG: print('case3')
            return [(-np.inf, -constant_term / linear_term)]
        else:
            if DEBUG: print('case4')
            return [(-constant_term / linear_term, np.inf)]
    elif constant_term <= 0: # constant: holds for every t
        if DEBUG: print('case5')
        return [(-np.inf, np.inf)]
    else: # constant: holds for no t
        if DEBUG: print('case6')
        return []
In [17]:
Q1 = np.identity(4)
y1, eta1, a1 = np.random.standard_normal((3,4))
b1 = 0
Intervals(Q1, a1, b1, y1, eta1)
Out[17]:
In [18]:
Q2 = -np.identity(4)
y2, eta2, a2 = np.random.standard_normal((3,4))
b2 = 0
Intervals(Q2, a2, b2, y2, eta2)
Out[18]:
In [19]:
Q = np.zeros((4,4))
y, eta, a = np.random.standard_normal((3,4))
b = 0
Intervals(Q, a, b, y, eta)
Out[19]:
In [20]:
Q = np.diag([-1,-1,1,1])
y, eta, a = np.random.standard_normal((3,4))
b = 0
Intervals(Q, a, b, y, eta)
Out[20]:
To find the intersection of a set of intervals, one only needs to track the upper and lower bounds. In our case, there will be some intervals and some complements of intervals.
The complements can be handled by taking the union of the excluded intervals and then removing that union from the intersection of the ordinary intervals.
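Tracking only the smallest lower endpoint and the largest upper endpoint of the excluded intervals is exact when those intervals overlap. As a hedged alternative that does not rely on that assumption, the sketch below represents each set as a sorted list of disjoint `(lo, hi)` pairs and intersects them two at a time. The helper names `intersect_two` and `intersect_all` are hypothetical and not part of this notebook.

```python
inf = float('inf')

def intersect_two(A, B):
    """Intersect two unions of intervals, each a sorted list of disjoint (lo, hi) pairs."""
    out = []
    i = j = 0
    while i < len(A) and j < len(B):
        lo = max(A[i][0], B[j][0])
        hi = min(A[i][1], B[j][1])
        if lo <= hi:                 # endpoints are kept, matching the "<=" constraints
            out.append((lo, hi))
        # advance whichever interval ends first
        if A[i][1] < B[j][1]:
            i += 1
        else:
            j += 1
    return out

def intersect_all(interval_lists):
    """Intersect any number of interval unions, starting from the whole real line."""
    result = [(-inf, inf)]
    for intervals in interval_lists:
        result = intersect_two(result, intervals)
    return result

example = intersect_all([
    [(-3.0, 4.0)],                   # an ordinary interval
    [(-inf, -1.0), (1.0, inf)],      # complement of (-1, 1)
    [(-inf, 2.0), (3.0, inf)],       # complement of (2, 3)
])
print(example)  # [(-3.0, -1.0), (1.0, 2.0), (3.0, 4.0)]
```

In this example the excluded intervals $(-1, 1)$ and $(2, 3)$ do not overlap, so the result has three pieces rather than two.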
In [21]:
def find_intersection(y, eta, *quadratic_list):
    '''
    Find the intersection of
    [Intervals(Q, a, b, y, eta) for Q, a, b in quadratic_list]
    '''
    upper_int, lower_int = (np.inf, -np.inf)     # intersection of the single intervals
    upper_union, lower_union = (-np.inf, np.inf) # union of the excluded intervals
    for Q, a, b in quadratic_list:
        intervals = Intervals(Q, a, b, y, eta)
        if DEBUG: print(intervals)
        if len(intervals) == 0: # this constraint holds for no t
            return []
        elif len(intervals) == 1: # a single interval
            upper_int = min(upper_int, intervals[0][1])
            lower_int = max(lower_int, intervals[0][0])
        elif len(intervals) == 2: # complement of an interval
            L, U = intervals[0][1], intervals[1][0] # by construction our intervals always have this form
            # the excluded intervals (L, U) are assumed to overlap,
            # so that their union is (lower_union, upper_union)
            upper_union = max(upper_union, U)
            lower_union = min(lower_union, L)
    if lower_int > upper_int:
        return [] # the single intervals do not intersect
    if lower_union >= upper_union:
        return [(lower_int, upper_int)] # nothing was excluded
    # remove (lower_union, upper_union) from [lower_int, upper_int]
    pieces = []
    if lower_int <= lower_union:
        pieces.append((lower_int, min(upper_int, lower_union)))
    if upper_int >= upper_union:
        pieces.append((max(lower_int, upper_union), upper_int))
    return pieces
In [28]:
%%R -o X,Y,groups,weights,sigma
source('http://statweb.stanford.edu/~jtaylo/notebooks/group_lasso.R')
set.seed(0)
n = 20
p = 10
sigma = 1.3
X = matrix(rnorm(n*p),n,p)
Y = rnorm(n)*sigma
groups = c(1,1,2,2,2,3,3,4,4,5)
weights = c(2,2.5,2,2,1.4)
results = group_lasso_knot(X, Y, groups, weights)
Z = results$L / (sqrt(results$var)*sigma)
Zlower = results$lower_bound / (sqrt(results$var)*sigma)
print(data.frame(Z,Zlower))
print(pvalue(results$L, results$lower_bound, results$upper_bound, sqrt(results$var), results$k, sigma=sigma))
In [29]:
%%R
print(results)
In [30]:
Qs = []
Xs = []
norms = []
for g, w in zip(np.unique(groups), weights):
    Xg = X[:,groups==g] / w
    Qs.append(np.dot(Xg, Xg.T))
    Xs.append(Xg)
    norms.append(np.linalg.norm(np.dot(Xg.T,Y)))
max(norms)
Out[30]:
In [31]:
imax = np.argmax(norms)
P = np.dot(Xs[imax], np.linalg.pinv(Xs[imax]))
eta = np.dot(P, Y)
eta /= np.linalg.norm(eta)
W = (eta*Y).sum()
final_Qs = []
for i in range(len(Qs)):
    if i != imax:
        final_Qs.append((Qs[i] - Qs[imax]).copy())
I1, I2 = find_intersection(Y, eta, *[(Q, 0, 0) for Q in final_Qs])
For the group LASSO first step, what we called Zlower can be recovered from W and I2[0], as (W + I2[0]) / sigma.
I'm not quite sure what can be done with I1, if anything. It can probably be used to get a little more power,
but I'm not 100% sure.
In [26]:
V1, V2 = W / sigma, (W+I2[0]) / sigma
%R -i V1,V2
V1, V2
Out[26]:
In [27]:
%%R
print((1 - pchisq(V1^2, results$k)) / (1 - pchisq(V2^2, results$k)))
This is the same $p$-value as above.