ML for Causal Inference: High-Dimensional Controls
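This notebook demonstrates post-double-selection (Belloni, Chernozhukov, and Hansen, 2014) for estimating treatment effects when there are far more candidate control variables than observations. The recipe: use the LASSO once to find controls that predict the outcome and once to find controls that predict the treatment, then estimate the treatment effect with the union of the two selected sets as controls. We apply it to simulated data, first in a randomized A/B test and then in a sharp regression discontinuity design (RDD).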


In [10]:
library(simstudy)
library(glmnet)
library(stats)
library(rdd)
library(ggplot2)

Simulate Data


In [2]:
set.seed(1)
# Number of Observations
N <- 1e3
total.covar <- 50 + 1e3
# Number of covariates (excluding W and unobservable)
p <- total.covar - 2
# Simulate Data
mu.vector <- rep(0, total.covar)
variance.vector <- abs(rnorm(total.covar, mean = 1, sd = .5))
simulated.data <- as.data.frame.matrix(genCorGen(n = N, nvars = total.covar, params1 = mu.vector, params2 = variance.vector,
                                                 dist = 'normal', rho = .5, corstr = 'ar1', wide = TRUE))[2:(total.covar+1)]
colnames(simulated.data)[1] <- 'W' # Running variable for RDD
colnames(simulated.data)[total.covar] <- 'C' # Unobservable variable
# Random assignment for A/B test
# (note: the name T masks R's built-in shorthand for TRUE)
T <- rep(0, N)
T[1:(N/2)] <- 1
T <- sample(T)
X <- simulated.data[, 2:(total.covar-1)]
covariate.names <- colnames(X)
# Independent error terms
error <- rnorm(n = N)
# Make W a function of the X's and unobservable (for RDD)
simulated.data$W <- simulated.data$W + .5 * simulated.data$C + 3 * simulated.data$V80 - 6 * simulated.data$V81
# Assign treatment, based on a threshold along W
treated <- (simulated.data$W > 0) * 1.0
# True coefficients on controls: sparse, with only the first 29 nonzero
beta.true.linear <- rnorm(p, mean = 5, sd = 5)
beta.true.linear[30:p] <- 0
# Functional form of Y for A/B test:
Y.ab <- 2.0 * T + data.matrix(X) %*% beta.true.linear + .6 * simulated.data$C + error
# Functional form of Y for RDD (function of treatment, W, X's, and unobservable C)
Y.rdd <- 1.2 * treated - 4 * simulated.data$W  + data.matrix(X) %*% beta.true.linear + .6 * simulated.data$C + error
df <- cbind(Y.ab, Y.rdd, T, simulated.data)
colnames(df)[1:3] <- c('Y.ab', 'Y.rdd', 'T')
X.colnames <- colnames(X)

A/B Test
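Treatment T is randomly assigned, so no controls are needed for identification; good controls only serve to shrink the standard errors. Double selection runs two LASSOs: one of the outcome Y on the candidate controls X to select the set H that predicts the outcome, and one of the treatment T on X to select the set K that predicts treatment. The final OLS of Y on T then controls for the union H ∪ K.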

Use LASSO of Y on X to select H


In [3]:
lasso.fit.outcome <- cv.glmnet(data.matrix(X), df$Y.ab, alpha = 1)
# Indices of nonzero coefficients at the CV-selected lambda
# (named `selected` rather than `coef` to avoid masking stats::coef)
selected <- predict(lasso.fit.outcome, type = "nonzero")
H <- X.colnames[unlist(selected)]
# Variables selected by LASSO:
H


  1. 'V2'
  2. 'V3'
  3. 'V4'
  4. 'V6'
  5. 'V7'
  6. 'V8'
  7. 'V9'
  8. 'V10'
  9. 'V11'
  10. 'V12'
  11. 'V13'
  12. 'V14'
  13. 'V15'
  14. 'V16'
  15. 'V17'
  16. 'V18'
  17. 'V19'
  18. 'V20'
  19. 'V21'
  20. 'V22'
  21. 'V23'
  22. 'V24'
  23. 'V25'
  24. 'V26'
  25. 'V27'
  26. 'V28'
  27. 'V29'
  28. 'V30'
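The LASSO recovers the support of the outcome model almost exactly: the true coefficient vector is nonzero only for the first 29 controls (V2 through V30), and 28 of them are selected. V5 is dropped, presumably because its randomly drawn coefficient happens to be small.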

Use LASSO of T on X to select K


In [4]:
lasso.fit.propensity <- cv.glmnet(data.matrix(X), df$T, alpha = 1)
selected <- predict(lasso.fit.propensity, type = "nonzero")
K <- X.colnames[unlist(selected)]
# Variables selected by LASSO:
K


  1. 'V106'
  2. 'V150'
  3. 'V170'
  4. 'V194'
  5. 'V462'
  6. 'V743'
  7. 'V745'
  8. 'V754'
  9. 'V843'
  10. 'V985'
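Because T is randomly assigned independently of the X's, none of these ten variables truly predicts treatment; they are chance correlations picked up by the cross-validated LASSO, and they enter the final regression with insignificant coefficients (see below).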

Perform OLS of Y on T, Controlling for H union K


In [5]:
# Union of selected variables:
H_union_K.names <- unique(c(H, K))
H_union_K.names
sum.H_union_K <- paste(H_union_K.names, collapse = " + ")
eq.H_union_K <- paste("Y.ab ~ T + ", sum.H_union_K)

# OLS regression, using all covariates selected by double selection
fit.double <- lm(eq.H_union_K, data = df)
T.double <- fit.double$coefficients[2]
ci.double <- confint(fit.double, 'T', level = 0.95)

# Results:
T.double
ci.double
summary(fit.double)


  1. 'V2'
  2. 'V3'
  3. 'V4'
  4. 'V6'
  5. 'V7'
  6. 'V8'
  7. 'V9'
  8. 'V10'
  9. 'V11'
  10. 'V12'
  11. 'V13'
  12. 'V14'
  13. 'V15'
  14. 'V16'
  15. 'V17'
  16. 'V18'
  17. 'V19'
  18. 'V20'
  19. 'V21'
  20. 'V22'
  21. 'V23'
  22. 'V24'
  23. 'V25'
  24. 'V26'
  25. 'V27'
  26. 'V28'
  27. 'V29'
  28. 'V30'
  29. 'V106'
  30. 'V150'
  31. 'V170'
  32. 'V194'
  33. 'V462'
  34. 'V743'
  35. 'V745'
  36. 'V754'
  37. 'V843'
  38. 'V985'
T: 2.04577777861444

       2.5 %   97.5 %
T   1.878637 2.212918
Call:
lm(formula = eq.H_union_K, data = df)

Residuals:
    Min      1Q  Median      3Q     Max 
-4.2479 -0.7888 -0.0127  0.8226  4.3081 

Coefficients:
             Estimate Std. Error  t value Pr(>|t|)    
(Intercept) -0.024681   0.059058   -0.418    0.676    
T            2.045778   0.085170   24.020   <2e-16 ***
V2           9.659512   0.045164  213.875   <2e-16 ***
V3           6.454994   0.070120   92.056   <2e-16 ***
V4           6.845535   0.036900  185.516   <2e-16 ***
V6           1.815218   0.062778   28.915   <2e-16 ***
V7           6.672606   0.046696  142.896   <2e-16 ***
V8           7.484199   0.046266  161.765   <2e-16 ***
V9           3.073507   0.046074   66.708   <2e-16 ***
V10          3.419821   0.058722   58.237   <2e-16 ***
V11          3.547168   0.039822   89.075   <2e-16 ***
V12         -4.823915   0.046970 -102.702   <2e-16 ***
V13          9.315493   0.064399  144.653   <2e-16 ***
V14          8.510116   0.158479   53.699   <2e-16 ***
V15          8.720919   0.043183  201.954   <2e-16 ***
V16          6.351646   0.052601  120.751   <2e-16 ***
V17          1.755003   0.053645   32.715   <2e-16 ***
V18          3.695994   0.043067   85.820   <2e-16 ***
V19          4.002368   0.045420   88.119   <2e-16 ***
V20          8.848228   0.045442  194.714   <2e-16 ***
V21         -2.658985   0.043853  -60.634   <2e-16 ***
V22         14.186505   0.044077  321.855   <2e-16 ***
V23          5.140462   0.051819   99.199   <2e-16 ***
V24          7.139716   0.715824    9.974   <2e-16 ***
V25          4.227116   0.047177   89.601   <2e-16 ***
V26         -5.499456   0.055262  -99.517   <2e-16 ***
V27          4.096075   0.056024   73.113   <2e-16 ***
V28         13.361963   0.105050  127.197   <2e-16 ***
V29         13.257553   0.059211  223.905   <2e-16 ***
V30         11.774925   0.041322  284.958   <2e-16 ***
V106         0.012364   0.029585    0.418    0.676    
V150        -0.117295   0.099411   -1.180    0.238    
V170         0.022119   0.038805    0.570    0.569    
V194         0.001614   0.034374    0.047    0.963    
V462        -0.021618   0.027503   -0.786    0.432    
V743         0.024642   0.033463    0.736    0.462    
V745         0.008998   0.058456    0.154    0.878    
V754         0.014578   0.035224    0.414    0.679    
V843         0.025185   0.025732    0.979    0.328    
V985         0.039353   0.037452    1.051    0.294    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 1.272 on 960 degrees of freedom
Multiple R-squared:  0.9994,	Adjusted R-squared:  0.9994 
F-statistic: 4.204e+04 on 39 and 960 DF,  p-value: < 2.2e-16
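The double-selection estimate of 2.046 (95% CI [1.879, 2.213]) brackets the true effect of 2.0, and the propensity-selected variables (V106 through V985) are individually insignificant, as expected under random assignment.

Since the same two-LASSOs-plus-union pattern recurs in the RDD section below, here is a minimal sketch of it as a reusable helper; the function name and arguments are mine, not part of the original analysis:

double.select <- function(X.mat, y, d) {
  # LASSO of the outcome on the candidate controls
  sel.y <- unlist(predict(cv.glmnet(X.mat, y, alpha = 1), type = "nonzero"))
  # LASSO of the treatment (or running variable) on the candidate controls
  sel.d <- unlist(predict(cv.glmnet(X.mat, d, alpha = 1), type = "nonzero"))
  # Union of the two selected sets of column names
  union(colnames(X.mat)[sel.y], colnames(X.mat)[sel.d])
}

# Equivalent to the cells above:
# H_union_K.names <- double.select(data.matrix(X), df$Y.ab, df$T)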

Comparison Plots


In [6]:
# Alternate methods:
#    OLS of Y on T
fit.simple <- lm('Y.ab ~ T', data = df)
#summary(fit.simple)
T.simple <- fit.simple$coefficients[2]
ci.simple <- confint(fit.simple, 'T', level = 0.95)

#    OLS of Y on T, controlling for all of X
sum.X <- paste(X.colnames, collapse = " + ")
eq.control.all <- paste("Y.ab ~ T + ", sum.X)
fit.allX <- lm(eq.control.all, data = df)
#summary(fit.allX)
T.allX <- fit.allX$coefficients[2]
ci.allX <- confint(fit.allX, 'T', level = 0.95)

#    OLS of Y on T, controlling for (almost) all of X
#    (p > N here, so this is roughly the largest subset OLS can still fit)
sum.Xmost <- paste(X.colnames[1:(N-10)], collapse = " + ")
eq.control.almost <- paste("Y.ab ~ T + ", sum.Xmost)
fit.mostX <- lm(eq.control.almost, data = df)
#summary(fit.mostX)
T.mostX <- fit.mostX$coefficients[2]
ci.mostX <- confint(fit.mostX, 'T', level = 0.95)

#    OLS of Y on T, with a subset of X
sum.X.subset <- paste(c(X.colnames[5:15], X.colnames[80:90]), collapse = " + ")
eq.control.subset <- paste("Y.ab ~ T + ", sum.X.subset)
fit.subsetX <- lm(eq.control.subset, data = df)
#summary(fit.subsetX)
T.subsetX <- fit.subsetX$coefficients[2]
ci.subsetX <- confint(fit.subsetX, 'T', level = 0.95)

# Collect the T coefficients and confidence intervals for each model
model.names <- c('True Effect', 'No Controls', 'All X', 'Largest Subset of X', 'Limited Controls', 'Double Selection')
dat <- data.frame(
  coefs.ab = c(2.0, T.simple, T.allX, T.mostX, T.subsetX, T.double),
  ci.low   = c(NaN, ci.simple[1], ci.allX[1], ci.mostX[1], ci.subsetX[1], ci.double[1]),
  ci.high  = c(NaN, ci.simple[2], ci.allX[2], ci.mostX[2], ci.subsetX[2], ci.double[2]),
  Model    = factor(model.names, levels = model.names)
)

# Create bar graph with 95% confidence intervals
dodge <- position_dodge(width = 0.9)
p <- ggplot(data = dat, aes(x = Model, y = coefs.ab, fill = Model))
p + geom_bar(stat = "identity", position = dodge) +
  geom_errorbar(aes(ymin = ci.low, ymax = ci.high), position = dodge, width = 0.25) +
  ylab("Coefficient on Treatment") +
  ggtitle("Estimated Causal Effect of T on Y, for various models") +
  theme(axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank())


Warning message in qt(a, object$df.residual):
"NaNs produced"
Warning message:
"Removed 2 rows containing missing values (geom_errorbar)."

RDD
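In this design, the running variable W is a function of the observables V80 and V81 and of the unobservable C, and treatment switches on sharply at W > 0. Double selection proceeds exactly as in the A/B test, except that the second LASSO regresses the running variable W, rather than the treatment indicator, on the candidate controls.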

Use LASSO of Y on X to select H


In [7]:
lasso.fit.outcome <- cv.glmnet(data.matrix(X), df$Y.rdd, alpha = 1)
selected <- predict(lasso.fit.outcome, type = "nonzero")
H <- X.colnames[unlist(selected)]
# Variables selected by LASSO:
H


  1. 'V2'
  2. 'V3'
  3. 'V4'
  4. 'V6'
  5. 'V7'
  6. 'V8'
  7. 'V9'
  8. 'V10'
  9. 'V11'
  10. 'V12'
  11. 'V13'
  12. 'V14'
  13. 'V15'
  14. 'V16'
  15. 'V17'
  16. 'V18'
  17. 'V19'
  18. 'V20'
  19. 'V21'
  20. 'V22'
  21. 'V23'
  22. 'V24'
  23. 'V25'
  24. 'V26'
  25. 'V27'
  26. 'V28'
  27. 'V29'
  28. 'V30'
  29. 'V48'
  30. 'V80'
  31. 'V81'
  32. 'V126'
  33. 'V133'
  34. 'V186'
  35. 'V243'
  36. 'V307'
  37. 'V381'
  38. 'V419'
  39. 'V679'
  40. 'V680'
  41. 'V745'
  42. 'V787'
  43. 'V832'
  44. 'V835'
  45. 'V880'
  46. 'V987'
  47. 'V999'
  48. 'V1043'
  49. 'V1049'
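Beyond the outcome controls found in the A/B test, the LASSO now also selects V80 and V81, which enter Y.rdd through the running variable W, along with a handful of chance correlates.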

Use LASSO of W on X to select K


In [8]:
lasso.fit.propensity <- cv.glmnet(data.matrix(X), df$W, alpha = 1)
selected <- predict(lasso.fit.propensity, type = "nonzero")
K <- X.colnames[unlist(selected)]
# Variables selected by LASSO:
K


  1. 'V2'
  2. 'V80'
  3. 'V81'
  4. 'V1049'
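The LASSO of W on X correctly picks out V80 and V81, the two observables that directly drive the running variable. V2 and V1049 plausibly enter as proxies: under the AR(1) correlation structure, V2 is correlated with W's own base draw and V1049 with the unobservable C.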

Perform RDD of Y on W, Controlling for H union K


In [9]:
# Union of selected variables:
H_union_K.names <- unique(c(H, K))
H_union_K.names
sum.H_union_K <- paste(H_union_K.names, collapse = " + ")
eq.H_union_K <- paste("Y.rdd ~ W | ", sum.H_union_K)

# Sharp RDD, using all covariates selected by double selection
# (RDestimate's default cutpoint of 0 matches the W > 0 threshold above)
fit.rdd <- RDestimate(eq.H_union_K, data = df)
summary(fit.rdd)


  1. 'V2'
  2. 'V3'
  3. 'V4'
  4. 'V6'
  5. 'V7'
  6. 'V8'
  7. 'V9'
  8. 'V10'
  9. 'V11'
  10. 'V12'
  11. 'V13'
  12. 'V14'
  13. 'V15'
  14. 'V16'
  15. 'V17'
  16. 'V18'
  17. 'V19'
  18. 'V20'
  19. 'V21'
  20. 'V22'
  21. 'V23'
  22. 'V24'
  23. 'V25'
  24. 'V26'
  25. 'V27'
  26. 'V28'
  27. 'V29'
  28. 'V30'
  29. 'V48'
  30. 'V80'
  31. 'V81'
  32. 'V126'
  33. 'V133'
  34. 'V186'
  35. 'V243'
  36. 'V307'
  37. 'V381'
  38. 'V419'
  39. 'V679'
  40. 'V680'
  41. 'V745'
  42. 'V787'
  43. 'V832'
  44. 'V835'
  45. 'V880'
  46. 'V987'
  47. 'V999'
  48. 'V1043'
  49. 'V1049'
Call:
RDestimate(formula = eq.H_union_K, data = df)

Type:
sharp 

Estimates:
           Bandwidth  Observations  Estimate  Std. Error  z value  Pr(>|z|) 
LATE       2.514      403           1.209     0.2502      4.831    1.358e-06
Half-BW    1.257      218           1.353     0.3349      4.040    5.353e-05
Double-BW  5.028      727           1.226     0.1821      6.730    1.694e-11
              
LATE       ***
Half-BW    ***
Double-BW  ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

F-statistics:
           F      Num. DoF  Denom. DoF  p
LATE       13880  52        350         0
Half-BW     7309  52        165         0
Double-BW  26782  52        674         0
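The LATE estimate of 1.209 (s.e. 0.250) is consistent with the true discontinuity of 1.2, even though the unobservable C enters both the running variable and the outcome.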
