This notebook contains analysis related to a paper on marriage patterns in the U.S., based on data from the National Survey of Family Growth (NSFG).
It is based on Chapter 13 of Think Stats, 2nd Edition, by Allen Downey, available from thinkstats2.com
In [1]:
%matplotlib inline
import pandas as pd
import numpy as np
import seaborn as sns
import math
import matplotlib.pyplot as plt
from matplotlib import pylab
from scipy.interpolate import interp1d
from scipy.misc import derivative
import thinkstats2
import thinkplot
from thinkstats2 import Cdf
import survival
import marriage
In [2]:
%time df = pd.read_hdf('FemMarriageData.hdf', 'FemMarriageData')
df.shape
Out[2]:
Make a table showing the number of respondents in each cycle:
In [3]:
df['cycle'].value_counts().sort_index()  # respondents per cycle, ordered by cycle
Out[3]:
In [4]:
def format_date_range(array):
    """Format a two-element numeric array as 'first--last' (LaTeX en-dash).

    Both endpoints are converted with `astype(int)`, which truncates
    fractional years toward zero.
    """
    endpoints = array.astype(int)
    first, last = endpoints
    return '%d--%d' % (first, last)
def SummarizeCycle(cycle, df):
    """Print one LaTeX table row summarizing a survey cycle.

    Columns: cycle, range of interview years, number of respondents,
    range of birth years, then a LaTeX row terminator.

    cycle: label printed in the first column
    df: respondents from one cycle; needs `cmintvw` and `cmbirth` columns,
        which are converted to years with /12 + 1900 (i.e. presumably months
        counted from January 1900 -- confirm against the codebook)
    """
    intvws = np.array([df.cmintvw.min(), df.cmintvw.max()]) / 12 + 1900
    births = np.array([df.cmbirth.min(), df.cmbirth.max()]) / 12 + 1900
    print(cycle, ' & ', format_date_range(intvws), '&', len(df), '&',
          format_date_range(births), r'\\')
In [5]:
# One LaTeX summary row per survey cycle.
for cycle_id, cycle_df in df.groupby('cycle'):
    SummarizeCycle(cycle_id, cycle_df)
Check for missing values in `agemarry` among ever-married respondents (one count per cycle):
In [ ]:
def CheckAgeVars(df):
    """Print how many ever-married respondents in `df` have no `agemarry` value."""
    ever_married = df[df.evrmarry]
    print(ever_married.agemarry.isnull().sum())
for cycle_id, cycle_df in df.groupby('cycle'):
    CheckAgeVars(cycle_df)  # one count per cycle, in cycle order
Generate a table with the number of respondents in each cohort:
In [6]:
# DigitizeResp presumably adds the *_index bin columns used below (confirm in marriage.py).
# Then print one LaTeX row per birth cohort:
# cohort & respondents & age range & ever married & missing
marriage.DigitizeResp(df)
grouped = df.groupby('birth_index')
for name, group in grouped:
    youngest = int(group.age.min())
    # NOTE(review): the upper bound uses age_index while the lower bound uses
    # age; equivalent only if age_index is age floored to whole years -- confirm.
    oldest = int(group.age_index.max())
    age_range = '%d--%d' % (youngest, oldest)
    n_married = len(group[group.evrmarry])
    print(name, '&', len(group), '&', age_range,
          '&', n_married, '&', sum(group.missing), r'\\')
In [7]:
def ComputeCutoffs(df):
    """Map each birth cohort to the oldest age observed in it.

    df: respondents with `birth_index` and `age` columns
    returns: dict {birth_index: int(max age)}; int() truncates toward zero
    """
    # groupby sorts group keys by default, so insertion order is ascending by cohort.
    return {name: int(group.age.max())
            for name, group in df.groupby('birth_index')}
In [8]:
# Oldest observed age (truncated to an int) in each birth cohort.
cutoffs = ComputeCutoffs(df=df)
cutoffs
Out[8]:
Estimate the hazard function for the 80s cohort, to see what is happening during the so-called "marriage strike":
In [9]:
# 80s cohort: drop ever-married respondents with no marriage age, then split
# into completed durations (age at marriage) and ongoing ones (current age).
cohort = grouped.get_group(80)
missing = cohort.evrmarry & cohort.agemarry.isnull()
cohort = cohort[~missing]
married = cohort.evrmarry
complete = cohort.loc[married, 'agemarry_index']
ongoing = cohort.loc[~married, 'age_index']
hf = survival.EstimateHazardFunction(complete, ongoing, verbose=True)
Run the same analysis for the 70s cohort (to extract $\lambda(33)$).
In [10]:
# Same procedure for the 70s cohort.
cohort = grouped.get_group(70)
missing = cohort['evrmarry'] & cohort['agemarry'].isnull()
cohort = cohort[~missing]
complete = cohort.loc[cohort['evrmarry'], 'agemarry_index']
ongoing = cohort.loc[~cohort['evrmarry'], 'age_index']
hf = survival.EstimateHazardFunction(complete, ongoing, verbose=True)
Use the 30s cohort to demonstrate the simple way to do survival analysis, by computing the survival function directly.
In [11]:
# Direct survival function for the 30s cohort. Missing marriage ages
# (presumably never-married respondents) are treated as infinite.
cohort = grouped.get_group(30)
agemarry_filled = cohort['agemarry_index'].fillna(np.inf)
sf = survival.MakeSurvivalFromSeq(agemarry_filled)
ts, ss = sf.Render()
print(ss)
thinkplot.Plot(ts, ss)
thinkplot.Config(xlim=[12, 42])
Then use the SurvivalFunction to compute the HazardFunction:
In [12]:
# Derive the hazard function from the survival function above and plot it.
hf = sf.MakeHazardFunction()
ts, lams = hf.Render()
print(lams)
thinkplot.Plot(ts, lams)
thinkplot.Config(xlim=[12, 42])
Make the first figure, showing sf and hf for the 30s cohort:
In [13]:
options = {'formats': ['pdf', 'png'], 'clf': False}  # shared thinkplot.Save settings
In [14]:
# Two stacked panels for the 30s cohort: survival function on top,
# hazard function below. Each axis label is set exactly once; the hazard
# panel previously got 'age(years)'/'Hazard function' labels that plt calls
# immediately overwrote.
thinkplot.PrePlot(rows=2)
thinkplot.Plot(sf, label='survival')
thinkplot.Config(xlim=[13, 41], ylim=[0, 1.05], ylabel='Survival Function')

thinkplot.SubPlot(2)
thinkplot.Plot(hf, label='hazard')
thinkplot.Config(xlabel='Age (years)', ylabel='Hazard Function', xlim=[13, 41])
thinkplot.Save(root='figs/marriage1', **options)
In [15]:
# Survival function for the 30s cohort on its own.
axis_font = 14
thinkplot.Plot(sf, label='30s')
thinkplot.Config(xlim=[13, 41], ylim=[0, 1.05])
plt.xlabel('Age (years)', fontsize=axis_font)
plt.ylabel('Survival function', fontsize=axis_font)
thinkplot.Save(root='figs/marriage2', **options)
In [16]:
# Hazard function for the 30s cohort on its own.
axis_font = 14
thinkplot.Plot(hf, label='30s')
thinkplot.Config(xlim=[13, 41])
plt.xlabel('Age (years)', fontsize=axis_font)
plt.ylabel('Hazard function', fontsize=axis_font)
thinkplot.Save(root='figs/marriage3', **options)
Make some pivot tables, just to see where the data are:
In [17]:
# Respondent counts by birth cohort (rows) and current-age bin (columns).
pt = pd.pivot_table(df, index='birth_index', columns='age_index',
                    values='age', aggfunc=len, fill_value=0)
pt
Out[17]:
The following pivot table is not as helpful as it could be, since it doesn't show the number at risk.
In [18]:
# Counts by birth cohort and age-at-marriage bin.
pd.pivot_table(df, index='birth_index', columns='agemarry_index',
               values='age', aggfunc=len, fill_value=0)
Out[18]:
Estimate the survival curve for each cohort:
In [19]:
# Generic event/duration columns, presumably read by
# marriage.EstimateSurvivalByCohort (confirm), plus flags for rows whose
# relevant duration is missing.
df['complete'] = df['evrmarry']
df['complete_var'] = df['agemarry_index']
df['ongoing_var'] = df['age_index']
df['complete_missing'] = df['complete'] & df['complete_var'].isnull()
df['ongoing_missing'] = ~df['complete'] & df['ongoing_var'].isnull()
In [20]:
# For some marriages we don't have the date of marriage. Per cycle, print:
# cycle, ever-married with no marriage age, never-married with no age.
for cycle, group in df.groupby('cycle'):
    print(cycle, group['complete_missing'].sum(), group['ongoing_missing'].sum())
In [21]:
# One DataFrame per survey cycle, in cycle order.
resps = [cycle_df for _, cycle_df in df.groupby('cycle')]
iters = 101  # iteration count passed to EstimateSurvivalByCohort
In [22]:
%time sf_map = marriage.EstimateSurvivalByCohort(resps, iters=iters, cutoffs=cutoffs)
In [23]:
# Drop the 30s cohort (must be present) and the 100s cohort if present.
del sf_map[30]
sf_map.pop(100, None);  # trailing ';' keeps the popped value from being displayed
Spot-check the first survival function in the 90s and 80s cohorts:
In [24]:
# Spot check: first survival function for the 90s cohort and its Prob at age 34.
for sf in sf_map[90]:
    print(sf.ss, sf.Prob(34), sep='\n')
    break
In [25]:
# Same spot check for the 80s cohort.
for sf in sf_map[80]:
    print(sf.ss, sf.Prob(34), sep='\n')
    break
Make the figure showing estimated survival curves:
In [26]:
def PlotSurvivalFunctions(root, sf_map, sf_map_pred=None, **options):
    """Plot survival curves by cohort, optionally with projections, and save.

    root: output path without extension; saved as PDF and PNG
    sf_map: cohort -> sequence of survival functions (observed)
    sf_map_pred: like sf_map, from the predict_flag=True run; drawn first
        when truthy
    options: extra keyword arguments for thinkplot.config (e.g. title)
    """
    if sf_map_pred:
        marriage.PlotSurvivalFunctions(sf_map_pred, predict_flag=True)
    marriage.PlotSurvivalFunctions(sf_map)

    axis_settings = dict(xlabel='Age (years)',
                         ylabel='Fraction never married',
                         xlim=[13, 50],
                         ylim=[0, 1.05],
                         loc='upper right',
                         frameon=False)
    thinkplot.config(**axis_settings, **options)
    plt.tight_layout()
    thinkplot.save(root=root, formats=['pdf', 'png'])
In [27]:
def set_palette(*args, **kwds):
    """Install a seaborn palette as the current axes' color cycle.

    args, kwds: forwarded to sns.color_palette, except for the extra boolean
        keyword `reverse` (default False), which flips the color order.
    returns: the list of colors installed
    """
    flip = kwds.pop('reverse', False)
    colors = list(sns.color_palette(*args, **kwds))
    if flip:
        colors = colors[::-1]
    plt.gca().set_prop_cycle(plt.cycler(color=colors))
    return colors
In [56]:
def draw_age_lines(ages):
    """Draw a faint dotted gray vertical line at each x position in `ages`."""
    line_style = dict(color='gray', linestyle='dotted', alpha=0.3)
    for x in ages:
        plt.axvline(x, **line_style)
In [57]:
# Survival curves for women by decade of birth.
# `ages` is defined here because this cell runs before the percentage-table
# cell that also sets it; without this, Restart & Run All raises NameError.
ages = [24, 34, 44]
palette = set_palette('hls', 6)
draw_age_lines(ages)
options_w = dict(title='Women in the U.S. by decade of birth')
PlotSurvivalFunctions('figs/marriage4', sf_map, None, **options_w)
Make a table of marriage rates for each cohort at each age:
In [29]:
def MakeTable(sf_map, ages):
    """Build [(cohort, [value at each age]), ...] sorted by cohort.

    For each cohort, marriage.MakeSurvivalCI(seq, [50]) is called and its
    first curve (presumably the 50th-percentile curve -- confirm) is
    linearly interpolated at each age. Ages beyond the curve's last time
    give NaN (np.interp right=np.nan).
    """
    table = []
    for cohort in sorted(sf_map):
        ts, ss = marriage.MakeSurvivalCI(sf_map[cohort], [50])
        curve = ss[0]
        table.append((cohort, [np.interp(age, ts, curve, right=np.nan)
                               for age in ages]))
    return table
In [30]:
def MakePercentageTable(sf_map, ages):
    """Print the percentage unmarried for each cohort at each age.

    One LaTeX row per cohort (sorted): cohort, then whole-number percentages
    joined by ' & ', then a row terminator. Ages past the end of a cohort's
    curve print as 'nan'.

    sf_map: cohort -> sequence of survival functions
    ages: ages (years) at which to read the curves
    """
    # Reuse MakeTable rather than recomputing the same interpolated values;
    # previously its result was computed, discarded, and duplicated inline.
    for name, vals in MakeTable(sf_map, ages):
        print(name, '&', ' & '.join('%0.0f' % (val * 100) for val in vals), r'\\')
ages = [24, 34, 44]  # ages at which tables and figures report fraction unmarried
MakePercentageTable(sf_map, ages)
Generate projections by rerunning the estimate with `predict_flag=True`:
In [31]:
%time sf_map_pred = marriage.EstimateSurvivalByCohort(resps, iters=iters, cutoffs=cutoffs, predict_flag=True)
del sf_map_pred[30]
try:
del sf_map[100]
except KeyError:
pass
In [32]:
# Median age at first marriage for cohorts through the 90s: the median over
# each cohort's projected survival functions of sf.MakeCdf().Value(0.5).
# Iterate in sorted order so `break` does not depend on dict insertion order.
for cohort, seq in sorted(sf_map_pred.items()):
    if cohort > 90:
        break
    medians = [sf.MakeCdf().Value(0.5) for sf in seq]
    print(cohort, np.median(medians))
And make the figure showing projections:
In [58]:
# Observed survival curves for women with projections overlaid.
palette = set_palette('hls', 6)
draw_age_lines(ages)
PlotSurvivalFunctions('figs/marriage5', sf_map, sf_map_pred=sf_map_pred, **options_w)
Make the table again with the projections filled in.
In [34]:
MakePercentageTable(sf_map_pred, ages=ages)  # table with projections filled in
In [35]:
def PlotFractions(sf_map, ages, label_flag=False, **plot_options):
    """Plot, for each age in `ages`, the fraction unmarried versus cohort.

    sf_map: cohort -> sequence of survival functions (see MakeTable)
    ages: ages at which to read the curves; one line is drawn per age
    label_flag: if True, label each line 'at age N'; otherwise no label
    plot_options: extra keyword arguments forwarded to thinkplot.Plot
    """
    table = MakeTable(sf_map, ages)
    if not table or not ages:
        return  # nothing to plot
    cohorts, cols = zip(*table)
    # rows[i] holds the values at ages[i], one per cohort.
    rows = list(zip(*cols))
    # One color per line, instead of a hard-coded 3.
    thinkplot.PrePlot(len(ages))
    # Plot the oldest age first, matching the original drawing order.
    for age, row in reversed(list(zip(ages, rows))):
        label = 'at age %d' % age if label_flag else ''
        thinkplot.Plot(cohorts, row, label=label, **plot_options)
In [36]:
# Projected fractions (dashed gray) underneath the observed fractions.
PlotFractions(sf_map_pred, ages, color='gray', linestyle='dashed', linewidth=2)
PlotFractions(sf_map, ages, label_flag=True, alpha=1)

# Hand-placed percentage labels; positions and values are hard-coded, so
# they must be updated by hand if the estimates change.
fontsize = 12
women_labels = [(36, 0.26, '24'), (37, 0.13, '9'), (37, 0.07, '7'),
                (90, 0.85, '80'), (90, 0.56, '51'), (89.5, 0.47, '42'),
                (80, 0.42, '35'), (70, 0.18, '18')]
for x, y, text in women_labels:
    thinkplot.Text(x, y, text, fontsize=fontsize)

thinkplot.Config(xlim=[34, 97], ylim=[0, 1], legend=True, loc='upper left',
                 xlabel='cohort (decade)', ylabel='Fraction unmarried',
                 title='Women in the U.S.')
thinkplot.Save(root='figs/marriage6', **options)
In [37]:
%time df2 = pd.read_hdf('MaleMarriageData.hdf', 'MaleMarriageData')
df2.shape
Out[37]:
In [38]:
# One LaTeX summary row per cycle for the male respondents.
for cycle_id, cycle_df in df2.groupby('cycle'):
    SummarizeCycle(cycle_id, cycle_df)
In [39]:
sum(df2['missing'])  # total of the `missing` column across male respondents
Out[39]:
In [40]:
# Bin the male respondents, then print one LaTeX row per birth cohort:
# cohort & respondents & age range & ever married & missing
marriage.DigitizeResp(df2)
grouped = df2.groupby('birth_index')
for name, group in grouped:
    # NOTE(review): lower bound from age, upper from age_index (same as the
    # women's table); confirm these agree.
    age_range = '%d--%d' % (int(group.age.min()), int(group.age_index.max()))
    ever_married = group[group.evrmarry]
    print(name, '&', len(group), '&', age_range,
          '&', len(ever_married), '&', sum(group.missing), r'\\')
In [41]:
# Oldest observed age in each male birth cohort.
cutoffs2 = ComputeCutoffs(df=df2)
cutoffs2
Out[41]:
In [42]:
resps2 = [cycle_df for _, cycle_df in df2.groupby('cycle')]  # one frame per cycle
In [43]:
%time sf_map_male = marriage.EstimateSurvivalByCohort(resps2, iters=iters, cutoffs=cutoffs2)
del sf_map_male[100]
In [59]:
# Survival curves for men by decade of birth.
options_m = dict(title='Men in the U.S. by decade of birth')
palette = set_palette('hls', 6)
draw_age_lines(ages)
PlotSurvivalFunctions('figs/marriage7', sf_map_male, None, **options_m)
In [45]:
%time sf_map_male_pred = marriage.EstimateSurvivalByCohort(resps2, iters=iters, cutoffs=cutoffs2, predict_flag=True)
del sf_map_male_pred[100]
In [46]:
# Median age at first marriage for male cohorts through the 90s: the median
# over each cohort's projected survival functions of sf.MakeCdf().Value(0.5).
# Iterate in sorted order so `break` does not depend on dict insertion order.
for cohort, seq in sorted(sf_map_male_pred.items()):
    if cohort > 90:
        break
    medians = [sf.MakeCdf().Value(0.5) for sf in seq]
    print(cohort, np.median(medians))
In [60]:
# Observed survival curves for men with projections overlaid.
palette = set_palette('hls', 6)
draw_age_lines(ages)
PlotSurvivalFunctions('figs/marriage8', sf_map_male, sf_map_pred=sf_map_male_pred, **options_m)
In [48]:
MakePercentageTable(sf_map_male, ages=ages)  # observed male percentages
In [49]:
MakePercentageTable(sf_map_male_pred, ages=ages)  # male percentages with projections
In [50]:
# Projected fractions (dashed gray) underneath the observed fractions, men.
PlotFractions(sf_map_male_pred, ages, color='gray', linestyle='dashed', linewidth=2)
PlotFractions(sf_map_male, ages, label_flag=True, alpha=1)

# Hand-placed percentage labels; positions and values are hard-coded, so
# they must be updated by hand if the estimates change.
fontsize = 12
men_labels = [(46, 0.69, '68'), (46, 0.30, '26'), (46, 0.20, '18'),
              (70, 0.18, '19'), (80, 0.43, '43'), (90, 0.89, '86'),
              (90, 0.56, '52'), (90, 0.40, '38')]
for x, y, text in men_labels:
    thinkplot.Text(x, y, text, fontsize=fontsize)

thinkplot.Config(xlim=[34, 97], ylim=[0, 1], legend=True, loc='upper left',
                 xlabel='cohort (decade)', ylabel='Fraction unmarried',
                 title='Men in the U.S.')
thinkplot.Save(root='figs/marriage9', **options)
In [ ]: