Copyright (C) 2017 Constantine Savenkov
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
In [1]:
import pandas as pd
import numpy as np
In [2]:
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')
answers = pd.read_csv('gender_submission.csv')
# NOTE(review): gender_submission.csv is presumably Kaggle's sex-only baseline
# submission, so the test 'Survived' labels copied below are baseline
# predictions, not true outcomes -- confirm before reading test metrics.
# insert() aligns on the row index, so both files must list the same
# passengers in the same order; fail loudly instead of mislabeling rows.
if not np.array_equal(answers['PassengerId'].values, test['PassengerId'].values):
    raise ValueError('gender_submission.csv and test.csv list different PassengerIds')
test.insert(1, 'Survived', answers['Survived'])
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
# concat with ignore_index gives the same 0..n-1 row labels.
dataframe = pd.concat([train, test], ignore_index=True)
del train, test, answers
In [3]:
# Preview the first rows of the combined train + test frame.
dataframe.head()
Out[3]:
In [4]:
from re import split
def split_name(name):
    """Split a full passenger name into [surname, title, given names].

    All spaces are removed first, then the string is cut at the first two
    ',' / '.' characters; everything after the second cut stays together.

    Example:
        split_name('Kelly, Mr. James') -> ['Kelly', 'Mr', 'James']
    """
    compact = name.replace(' ', '')
    return split(r'[,.]', compact, maxsplit=2)
def unwrap_names(names):
    """Turn an iterable of full names into a Surname/Title/Name DataFrame.

    Each name is parsed with split_name; row order follows `names`.

    Example:
        ['Kelly, Mr. James', ..., 'Wirz, Mr. Albert'] ->
        pd.DataFrame({'Surname': ['Kelly', ..., 'Wirz'],
                      'Title':   ['Mr',    ..., 'Mr'],
                      'Name':    ['James', ..., 'Albert']})
    """
    parts = [split_name(full_name) for full_name in names]
    return pd.DataFrame({
        'Surname': [piece[0] for piece in parts],
        'Title': [piece[1] for piece in parts],
        'Name': [piece[2] for piece in parts],
    })
In [5]:
# Replace the raw Name column with its Surname/Title/Name parts and drop the
# PassengerId column.
unwrapped_names = unwrap_names(dataframe['Name'].values)
dataframe.drop(['PassengerId', 'Name'], axis=1, inplace=True)
dataframe = pd.concat([unwrapped_names, dataframe], axis=1)
# Encode Sex as 0 (male) / 1 (female). Assign the result back: calling
# replace(..., inplace=True) on dataframe['Sex'] is chained assignment, which
# warns in pandas 2.2 and silently does nothing under Copy-on-Write.
# Any value other than 'male'/'female' would become NaN and show up in the
# missing-value counts below.
dataframe['Sex'] = dataframe['Sex'].map({'male': 0, 'female': 1})
In [6]:
# Preview the frame after splitting names and encoding Sex.
dataframe.head()
Out[6]:
In [7]:
def entropy(probs):
    """Return the Shannon entropy (natural log, i.e. in nats) of `probs`.

    probs: sequence or array of probabilities of a discrete distribution.
    Zero entries are skipped, following the convention 0 * log(0) = 0;
    otherwise np.log(0) = -inf and 0 * -inf = nan would poison the sum.
    """
    p = np.asarray(probs, dtype=float)
    p = p[p > 0]
    return -np.sum(p * np.log(p))
def get_probabilities(data):
    """Return the relative frequency of each distinct value in `data`.

    data: any sized iterable of hashable values (pandas Series, 1-D array,
    list). Frequencies are listed in order of each value's first appearance.
    Empty input returns [].
    """
    from collections import Counter

    counts = Counter(data)
    total = len(data)  # equals data.shape[0] for 1-D arrays/Series, also works for lists
    return [count / total for count in counts.values()]
def calculate_entropy(data):
    """Return the entropy, as computed by entropy(), of the empirical
    distribution of the values in `data`."""
    probabilities = get_probabilities(data)
    return entropy(probabilities)
In [8]:
def show_unique_data(dataframe, col):
    """Print the number of distinct values in column `col`, that number as a
    percentage of all rows (rounded to 2 places), and the column's entropy
    as computed by calculate_entropy."""
    column = dataframe[col]
    n_unique = len(column.unique())
    n_rows = dataframe.shape[0]
    unique_pct = round(n_unique / n_rows * 100, 2)
    column_entropy = calculate_entropy(column)
    print('Unique', col, '\t:', n_unique, '\t/', n_rows, '=', unique_pct,
          '\t% | Entropy:', column_entropy)
# Report cardinality and entropy of the text columns to decide which to keep.
show_unique_data(dataframe, col='Surname')
show_unique_data(dataframe, col='Title' )
show_unique_data(dataframe, col='Name' )
show_unique_data(dataframe, col='Ticket' )
In [9]:
# Drop Surname, Name and Ticket (presumably because the report above shows
# them to be near-unique per passenger); Title is kept.
dataframe.drop(['Surname', 'Name', 'Ticket'], axis=1, inplace=True)
In [10]:
# Preview the remaining columns.
dataframe.head()
Out[10]:
In [11]:
# Count missing values per column.
pd.isnull(dataframe).sum()
Out[11]:
In [12]:
# For every title: mean age of survivors (lst_s) and of non-survivors (lst_d),
# with 0 when a group is empty or has no known ages. surv/died keep the
# per-title sub-frames. Used by the bar chart in the next cell.
titles = dataframe['Title'].unique()
lst_s, lst_d, surv, died = [], [], [], []
for title in titles:
    title_group = dataframe.loc[dataframe['Title'] == title]
    # Mask the group with its own column so the boolean index always has
    # exactly the group's rows (not the whole frame's) to align against.
    surv.append(title_group.loc[title_group['Survived'] == 1])
    died.append(title_group.loc[title_group['Survived'] == 0])
    mean_survived = surv[-1]['Age'].mean()
    mean_died = died[-1]['Age'].mean()
    # mean() of an empty / all-NaN group is NaN; plot it as 0.
    lst_s.append(0 if np.isnan(mean_survived) else mean_survived)
    lst_d.append(0 if np.isnan(mean_died) else mean_died)
In [13]:
import matplotlib.pyplot as plt
# Grouped bar chart per title: mean age of survivors (blue) vs. non-survivors (red).
ind = np.arange(len(titles))
width = 0.25
fig = plt.figure(figsize=(20, 13))
ax = fig.add_subplot(111)
rects1 = ax.bar(ind, lst_s, width, color='b')
rects2 = ax.bar(ind + width, lst_d, width, color='r')
ax.set_xlim(-width, len(ind) - 2*width)
ax.set_ylabel('Mean Age')
# Fix the tick positions first, then label them: set_xticklabels on the
# automatic locator pairs labels with whatever ticks happen to exist (and
# recent matplotlib warns about it). With the default align='center' the
# pair of bars is centred on ind + width / 2.
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(titles)
ax.legend((rects1[0], rects2[0]), ('Survived', 'Died'))


def autolabel(rects):
    """Write each bar's height, truncated to int, just above the bar on `ax`."""
    for rect in rects:
        h = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2, h, '%d' % int(h),
                ha='center', va='bottom')


autolabel(rects1)
autolabel(rects2)
plt.show()
In [14]:
# Fill missing ages with draws from a Gaussian fitted per title (one distribution per title).
In [15]:
import random as rd
a = (3, (2147483648, 1851096110, 2866142454, 2213168853, 2966376414, 2570123344, 1643847870, 3373511583, 361183140, 3069980908, 3636819335, 253239505, 1684944788, 3708537091, 3917540782, 483541571, 12963668, 1510925890, 3869002024, 2295777454, 2517400042, 1231147172, 293668687, 3284055208, 1275802902, 4210188983, 3324165861, 810397729, 1736959953, 1607355422, 1790074471, 523531271, 2525950907, 196564544, 4100046118, 1861426332, 1058643867, 2739850505, 1257838175, 3201235128, 1122154264, 515034481, 1698407528, 2979783520, 4039517870, 3991424769, 2178234569, 1529750483, 3800438920, 3825342818, 3370945083, 255516219, 1296178156, 3660020804, 2403055490, 3106490304, 2731761528, 3572369258, 1077848335, 598782711, 1915187406, 2899262893, 1623376048, 513352569, 802695695, 3262223229, 2617218553, 2373618481, 927315393, 1612731212, 1215228152, 1259867251, 4046013551, 1465964379, 2065896275, 476619040, 2090849094, 1207685863, 2799072133, 3505568762, 399644573, 2750542585, 379414458, 49845401, 2702551842, 3592150869, 2969666520, 965689978, 373439941, 3733949036, 3910438904, 1556287377, 3174627080, 1830416034, 3620516674, 3303347409, 1417941999, 3575627710, 520404421, 2187152975, 3160509627, 3242698107, 2985596868, 4045833491, 3969232631, 3649830173, 3482792659, 228464335, 2594280737, 2531046782, 3424142655, 1463838103, 2577472740, 717492212, 1864459576, 3315298693, 860618256, 926218304, 978192224, 3190123887, 2103959098, 2433114975, 2870394674, 1575796004, 2345824609, 2590904042, 3135301217, 1934139310, 489978034, 807666185, 6198185, 1415219703, 1393357770, 4013303514, 320568052, 367187709, 68384051, 838337513, 3093960785, 1731576017, 2006658223, 163265980, 3537459368, 2352847384, 3734940595, 1596785745, 1227477689, 1595992791, 3308788244, 3157363196, 2342767054, 920536711, 2858842910, 3900828880, 2070914661, 282356093, 2978494016, 3889150260, 1646325634, 565437305, 1257723669, 812100453, 3716243518, 1547576279, 3275380123, 4075920914, 1636694649, 996686704, 1292015169, 
1434940969, 4210157270, 3234964744, 3474995895, 2015362814, 1997140158, 444163970, 3040958786, 1122522000, 2694010072, 3519203842, 1427400119, 1103867560, 2140203324, 4246724675, 3967256062, 3610162365, 1134776907, 810489795, 140416414, 48146902, 1367698488, 1930214249, 4024699266, 1314202724, 151516031, 178868182, 975132952, 570201168, 2151213259, 619093549, 2534782806, 3949116174, 3235450187, 395723937, 3401578423, 3365288046, 2232111578, 1948280286, 334470981, 1825395609, 346790854, 1506452000, 2856114362, 1898274816, 2426532148, 3763900684, 660018356, 2082647080, 1106825082, 1145493737, 3786046995, 721753713, 1706412558, 3665992710, 3728569966, 3763802596, 4221653099, 874238897, 2804311013, 2173964574, 3718342184, 721929199, 1837325800, 2867228568, 1512553016, 2934846378, 402340193, 3304615674, 2443107883, 382867916, 3520120155, 3308529562, 494991602, 3269024639, 1102518009, 1017651323, 2576925203, 688467332, 903302392, 3918416639, 2937546826, 2355114688, 2369088610, 2708178851, 3335341503, 1550179514, 1434093098, 2109857984, 1431966022, 2277733043, 4203924888, 173332474, 3631517909, 360122722, 1922029588, 4042065610, 1562463235, 4184639225, 2325763455, 231293130, 4064563100, 1295399897, 2203100826, 4089594839, 1329537548, 3770845255, 1018162052, 1930164032, 3319374071, 2897003899, 3952551522, 2244200809, 3016050003, 1385774615, 2109951023, 2287636568, 1160699074, 1499369569, 3175500101, 2836786697, 293295834, 1462670785, 513081501, 3623937853, 893046559, 1088997512, 1424573124, 74222180, 1059420964, 3255659025, 3994342006, 1221741356, 1178645963, 2687547604, 1095362291, 1159995337, 3843035753, 4232435832, 692214755, 67145351, 4077388299, 1450006464, 1463895000, 1712801339, 687410850, 255348805, 2437575397, 3750485616, 4286770999, 690917800, 3840352946, 3801005538, 3088232912, 3998217216, 1451277982, 2724865526, 933800497, 3762698500, 771849517, 3544472781, 20590831, 2055487270, 3892669907, 474881811, 1022474553, 1266953327, 1334847368, 1990790832, 3787673331, 
3013604374, 300260481, 2792816160, 2604073031, 3612163102, 3370797141, 850072622, 3027011327, 1852347007, 3206579784, 1379037088, 3180197818, 2365640985, 3298404605, 1166908852, 99387529, 1911097757, 165914281, 3499164933, 2795792263, 2874449873, 1733671752, 1448009971, 3989084812, 115758874, 419256522, 3225152822, 2749470030, 3048821854, 2903637276, 2933823410, 4278668212, 347009877, 3328078451, 3550544448, 3595071986, 2586816878, 281606575, 2590292604, 810035454, 163086913, 1375235201, 3845477138, 3428143782, 4248159450, 3688247131, 1425303907, 2254140000, 970792606, 67136468, 2418110611, 3259343967, 3312536366, 3868037744, 2772460016, 3072247602, 2799575492, 3878856994, 2351914550, 413583233, 3451433732, 959988019, 2327262686, 1653453300, 344229475, 4239798473, 3266648260, 2776132514, 3629280166, 357734457, 520760266, 1873898163, 3432750638, 2339551008, 2033687963, 2733289682, 3765876541, 60110637, 4223430296, 1428771378, 3440330824, 3031579078, 3757613275, 1910038544, 1630697409, 1270850587, 3590995967, 731475498, 1262879721, 87888878, 4167918999, 2471838068, 4006692924, 3871793773, 81388916, 2038613557, 105286741, 588581665, 2547189435, 1829679123, 2689865512, 3463608759, 2363303283, 3483330925, 2817141274, 3308080009, 3173336352, 1820318949, 1758033120, 4183085892, 3415929440, 2300902855, 190592094, 2501798924, 843034062, 4007623538, 643124501, 638723647, 1674873218, 1001209983, 3060006866, 1166842004, 1842080074, 589014670, 4101739823, 3638660223, 3440143429, 1028052655, 3759702877, 1245712888, 2049230930, 3748272401, 1766267060, 3096319793, 1015457773, 2254384852, 1278448505, 772232066, 1505225651, 2401667252, 1738251827, 476802743, 259004363, 416323369, 3728622920, 2726950744, 2182259277, 1323931528, 1045758372, 3866213417, 375784818, 725259365, 3771062226, 56816842, 3383642409, 1939029599, 3938172910, 387180472, 3769924566, 463982306, 2927673395, 1258465483, 3524491614, 309782204, 2157923838, 4194512009, 2745259159, 1010978254, 805601070, 3400845721, 
1702573363, 3233995859, 3101504047, 1145765763, 1267607357, 2183356969, 865932884, 1501723768, 131678027, 3308491029, 4043686617, 1713253551, 2396999722, 2485903608, 3345819704, 235879223, 2806767472, 1955311863, 2950779552, 2116493229, 3386283169, 3766983501, 292090566, 1939583653, 802015122, 348564500, 4043123438, 1085017301, 3133325061, 2476797483, 1707132836, 1562272277, 903989299, 1503051657, 1571353147, 1629018081, 1125663084, 3304977672, 1013481085, 2051062067, 908171141, 2706953072, 2972685580, 2403884518, 587434615, 253875000, 1104335682, 3505528706, 2322625435, 782887538, 17709396, 269745224, 349275670, 1707387597, 2247254750, 1690122194, 3104524065, 2466356487, 2090759605, 552898264, 2475849156, 2132950934, 2402131602, 114354437, 2064388775, 1058771957, 1311542068, 1263987498, 2379555943, 2897622409, 1374763661, 1303815319, 1216758141, 1080637683, 1183937886, 2743840792, 3356823178, 4106312040, 36271905, 2720404876, 517454891, 376219569, 2207126829, 3721775191, 2686221308, 610052558, 3449296193, 3791809490, 763746567, 2263164886, 35795338, 3635640757, 3162754875, 3303410065, 77015595, 41180192, 1200825208, 2207515568, 2014104617, 1433812313, 1924737785, 1452731281, 2622602865, 2625770785, 2409258955, 3263023456, 3596679101, 1132383820, 4171043513, 1089308150, 2478788414, 1789550697, 1808932299, 97630476, 1658849159, 624), None)
# Restore the fixed Mersenne Twister state stored in `a` (presumably captured
# once with the commented-out snippet below) so that the random.gauss draws
# used for age imputation are reproducible across runs. random.gauss uses
# this same module-level generator.
rd.setstate(a)
#with open('random_state.txt', 'w') as fout:
# fout.write(str(rd.getstate()))
In [16]:
from random import gauss
def get_gauss_config(age):
    """Return (mu, sigma): the mean and population standard deviation of `age`.

    age: non-empty sequence of numbers; NaNs are not removed here (the caller
    filters them). sigma divides by len(age) (ddof=0).

    Raises:
        ValueError: if `age` is empty (no Gaussian can be fitted).
    """
    ages = np.asarray(age, dtype=float)
    if ages.size == 0:
        raise ValueError('cannot fit a Gaussian to an empty age sample')
    mu = np.mean(ages)
    # Builtin sum keeps the original left-to-right summation, so the seeded
    # imputation downstream reproduces bit-for-bit.
    sigma = np.sqrt(sum((ages - mu) ** 2) / ages.size)
    return mu, sigma
def make_gauss_config_dict(df):
    """Fit one Gaussian age model per title.

    Returns a dict mapping each distinct value of df['Title'] to the
    (mu, sigma) pair from get_gauss_config over that title's non-NaN ages.
    A title whose ages are all missing gets the distribution of every known
    age in `df` instead of passing an empty sample to get_gauss_config.
    """
    config = dict()
    fallback = None  # computed lazily, only if some title has no known ages
    for title in df['Title'].unique():
        known = [age for age in df.loc[df['Title'] == title, 'Age']
                 if not np.isnan(age)]
        if known:
            config[title] = get_gauss_config(known)
            continue
        if fallback is None:
            fallback = get_gauss_config(
                [age for age in df['Age'] if not np.isnan(age)])
        config[title] = fallback
    return config
def fill_age_nans(data):
    """Replace every missing Age with a positive draw from its title's Gaussian.

    The per-title (mu, sigma) come from make_gauss_config_dict(data), fitted
    before any value is filled. Missing rows are visited in positional order
    and draws <= 0 are rejected and redrawn, so the sequence of random.gauss
    calls matches a plain top-to-bottom row scan (reproducible under a fixed
    random state). Works for any index, not only labels 0..n-1.

    Mutates `data` in place and also returns it.

    Raises:
        ValueError: if a title's (mu, sigma) can never yield a positive finite
            draw (non-finite parameters, or sigma == 0 with mu <= 0), which
            would otherwise loop forever.
    """
    titles_and_ages = make_gauss_config_dict(data)
    age_col = data.columns.get_loc('Age')
    title_col = data.columns.get_loc('Title')
    for pos in np.flatnonzero(data['Age'].isna().values):
        mu, sigma = titles_and_ages[data.iat[pos, title_col]]
        if not (np.isfinite(mu) and np.isfinite(sigma)) or (sigma == 0 and mu <= 0):
            raise ValueError('cannot draw a positive age from mu=%r, sigma=%r'
                             % (mu, sigma))
        value = gauss(mu, sigma)
        while value <= 0.0:
            value = gauss(mu, sigma)
        data.iat[pos, age_col] = value
    return data
In [17]:
# Impute missing ages (fill_age_nans mutates the frame and returns it).
dataframe = fill_age_nans(dataframe)
In [18]:
# Show rows that still have a missing value in any column other than Cabin.
dataframe[pd.isnull(dataframe.drop(['Cabin'], axis=1)).any(axis=1)]
Out[18]:
In [19]:
# Fill each missing Fare with the mean fare of that passenger's own class,
# rather than always using the 3rd-class mean (correct only if every missing
# fare belongs to a 3rd-class passenger). Assign back instead of calling
# fillna(inplace=True) on a column selection: that chained assignment is a
# silent no-op under pandas Copy-on-Write.
class_mean_fare = dataframe.groupby('Pclass')['Fare'].transform('mean')
dataframe['Fare'] = dataframe['Fare'].fillna(class_mean_fare)
In [20]:
# Re-count missing values per column after filling Age and Fare.
pd.isnull(dataframe).sum()
Out[20]:
In [21]:
# Despite the names: `loading_ports` holds the sorted passenger classes
# (Pclass values), and for each class in that order class_1_passengers,
# class_2_passengers and class_3_passengers hold the number of passengers who
# embarked at C, Q and S respectively. Rows with a missing Embarked are not
# counted. Names kept because the plotting cell below uses them.
loading_ports = np.sort(dataframe['Pclass'].unique())
class_1_passengers, class_2_passengers, class_3_passengers = [], [], []
for pclass in loading_ports:
    # Mask the class subset with its own column so the boolean index aligns.
    embarked = dataframe.loc[dataframe['Pclass'] == pclass, 'Embarked']
    class_1_passengers.append(int((embarked == 'C').sum()))
    class_2_passengers.append(int((embarked == 'Q').sum()))
    class_3_passengers.append(int((embarked == 'S').sum()))
In [22]:
# Grouped bar chart: for each passenger class (x axis), the number of
# passengers who embarked at C (red), Q (green) and S (blue).
ind = np.arange(len(loading_ports))
width = 0.2
fig = plt.figure(figsize=(20, 10))
ax = fig.add_subplot(111)
rects1 = ax.bar(ind, class_1_passengers, width, color='r')
rects2 = ax.bar(ind + width, class_2_passengers, width, color='g')
rects3 = ax.bar(ind + 2 * width, class_3_passengers, width, color='b')
ax.set_xlim(-width, len(ind) - 2*width)
ax.set_ylabel('Number of Passengers')
# Fix the tick positions first, then label them (set_xticklabels on the
# automatic locator mislabels/warns). With the default align='center' the
# middle of the three bars sits at ind + width.
ax.set_xticks(ind + width)
ax.set_xticklabels(loading_ports)
ax.legend((rects1[0], rects2[0], rects3[0]), ('C', 'Q', 'S'))


def autolabel(rects):
    """Write each bar's height, truncated to int, just above the bar on `ax`."""
    for rect in rects:
        h = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2, h, '%d' % int(h),
                ha='center', va='bottom')


autolabel(rects1)
autolabel(rects2)
autolabel(rects3)
plt.show()
In [23]:
# Show the rows still missing a value outside Cabin (expected: missing Embarked).
dataframe[pd.isnull(dataframe.drop(['Cabin'], axis=1)).any(axis=1)]
Out[23]:
In [24]:
# Fill the missing ports with 'S' (presumably chosen as the most common port
# from the plot above). Assign back: fillna(inplace=True) on a column
# selection is chained assignment and a silent no-op under Copy-on-Write.
dataframe['Embarked'] = dataframe['Embarked'].fillna('S')
In [25]:
# Cabin is not used as a feature; drop it instead of imputing it.
dataframe.drop(['Cabin'], axis=1, inplace=True)
In [26]:
# Confirm no missing values remain.
pd.isnull(dataframe).sum()
Out[26]:
In [ ]:
In [27]:
from sklearn.feature_extraction import DictVectorizer as DV
# One-hot encode the port of embarkation: one 0/1 column per distinct
# Embarked value, in the encoder's sorted feature order.
embarked_records = dataframe[['Embarked']].astype(str).to_dict(orient='records')
encoder = DV(sparse=False)
encoded_data = encoder.fit_transform(embarked_records)
# Name the columns after the encoder's features (e.g. 'Embarked=C') instead of
# hard-coding three anonymous columns, which silently dropped any extra category.
if hasattr(encoder, 'get_feature_names_out'):
    encoded_names = encoder.get_feature_names_out()
else:  # scikit-learn < 1.0
    encoded_names = encoder.get_feature_names()
tmp = pd.DataFrame(encoded_data, columns=list(encoded_names))
In [28]:
# Replace Embarked by its one-hot columns (tmp) and drop Title; concat aligns
# on the row index (both presumably 0..n-1 here -- confirm if the index changes).
dataframe = pd.concat([dataframe.drop(['Embarked', 'Title'], axis=1), tmp], axis=1)
In [29]:
def split_to_train_and_test(df, n_train=891):
    """Split the combined frame back into (train, test) by position.

    The first `n_train` rows are returned as train and the rest as test.
    The default 891 is the number of train rows this notebook places before
    the test rows.
    """
    return df.iloc[:n_train], df.iloc[n_train:]
# Recover the train and test parts of the combined frame.
train, test = split_to_train_and_test(dataframe)
def split_to_patterns_and_targets(data):
    """Return (features, target): `data` without the 'Survived' column, and
    that column on its own."""
    target = data['Survived']
    features = data.drop(columns='Survived')
    return features, target
# Separate the features from the 'Survived' target for both parts.
X_train, y_train = split_to_patterns_and_targets(train)
X_test , y_test = split_to_patterns_and_targets(test )
In [30]:
# Preview the final feature matrix fed to the classifier.
X_train.head()
Out[30]:
In [31]:
from sklearn.ensemble import RandomForestClassifier
# Small, depth- and leaf-limited forest; the fixed random_state makes the
# fit reproducible.
classifier = RandomForestClassifier(
    n_estimators=100,
    max_depth=5,
    max_leaf_nodes=12,
    min_samples_leaf=2,
    random_state=1912,
)
# fit() returns the estimator itself, so predict can be chained.
answers = classifier.fit(X_train, y_train).predict(X_test)
In [32]:
from sklearn.metrics import (precision_score, average_precision_score, accuracy_score,
                             recall_score, roc_auc_score, f1_score, log_loss,
                             mean_absolute_error, mean_squared_error, roc_curve)
# NOTE(review): y_test is the 'Survived' column copied from
# gender_submission.csv at load time, presumably Kaggle's sex-only baseline
# rather than true outcomes -- these numbers measure agreement with that
# baseline. Confirm before trusting them as accuracy.
# scikit-learn metrics take (y_true, y_pred / y_score); the labels go first.
# Ranking and probabilistic metrics need scores, not hard 0/1 predictions.
scores = classifier.predict_proba(X_test)[:, 1]  # P(Survived == 1)
fpr, tpr, _ = roc_curve(y_test, scores)
print('average_precision_score:', average_precision_score(y_test, scores))
print('accuracy_score :', accuracy_score(y_test, answers))
print('precision_score :', precision_score(y_test, answers))
print('recall_score :', recall_score(y_test, answers))
print('roc_auc_score :', roc_auc_score(y_test, scores))
print('log_loss :', log_loss(y_test, scores))
print('f1_score :', f1_score(y_test, answers))
print('mean_absolute_error :', mean_absolute_error(y_test, answers))
print('mean_squared_error :', mean_squared_error(y_test, answers))
# Averages of the ROC curve points; informal summaries only (see roc_auc_score).
print('roc_curve(fpr) :', fpr.mean())
print('roc_curve(tpr) :', tpr.mean())
In [33]:
# Pair each prediction with its passenger. PassengerId was dropped from
# `dataframe` earlier, so take it from test.csv (the same file the test rows
# were loaded from) instead of hard-coding range(892, 1310); a length mismatch
# now raises instead of silently misaligning ids.
predictions = pd.DataFrame({'PassengerId': pd.read_csv('test.csv')['PassengerId'],
                            'Survived': list(answers)})
In [34]:
# Write the submission file (comma-separated, without the row index).
predictions.to_csv('predictions.csv', index=False)