In [1]:
# Enable logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

A simple prediction workflow

Drain workflows consist of drain.step.Step objects. Take, for example, the drain.data.ClassificationData step:


In [2]:
import drain.data
data = drain.data.ClassificationData(target=True, n_samples=1000, n_features=100)

This step calls the sklearn.datasets.make_classification function to generate a dataset with a binary outcome.
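Under the hood, this is roughly equivalent to the following sklearn call (a sketch; the dataframe wrapping and the train/test split are drain's additions):

import sklearn.datasets
X, y = sklearn.datasets.make_classification(n_samples=1000, n_features=100)

We can run the step by calling its execute method: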


In [3]:
data.execute()


INFO:root:Running
	ClassificationData(n_features=100, n_samples=1000)
Out[3]:
{'X':            0         1         2         3         4         5         6   \
 0   -1.254740  1.568295  0.859270 -0.663371  0.356787 -0.638139  0.429711   
 1   -1.239955 -2.459949  0.264857  0.192786 -1.293184 -0.803018  1.693877   
 2   -0.485544  0.852361  2.351547 -0.438737 -1.006562 -0.383894  0.911859   
 3    1.407038 -0.930342  2.072200  0.270287  0.978798 -2.089569  0.444492   
 4    1.355452  0.994878  0.590868 -0.154992  1.435196  1.124019  1.160561   
 5    0.660716  2.346761 -1.220935 -1.001100 -0.943435  1.055353 -0.098777   
 6    0.821727  0.507412  0.020194  0.618511  0.420213 -0.889522 -0.822467   
 7    0.560941 -2.194889 -2.288383  0.137653 -0.608870 -0.321122 -2.200869   
 8    0.826213  0.502941  1.617814 -0.395790  1.123574 -0.477804  0.544032   
 9   -0.859297  0.507568 -0.173425  0.478020 -0.120310  1.010425 -0.465142   
 10  -0.181034 -1.480597  1.076036  0.297779  2.611382 -1.110131 -0.792575   
 11   0.363537 -0.978450 -0.600857 -0.646410 -0.098534 -0.409467  2.197751   
 12   0.993590  1.295459  0.795243 -0.014217 -0.997072  0.302877 -1.127455   
 13  -0.279930  0.991420  0.345867  0.862786 -0.991938  0.533575  0.366574   
 14  -1.081655 -0.440487 -1.961790 -0.609297 -1.387298 -1.895199 -0.994840   
 15   1.492068  0.157100  0.759088 -1.610228 -0.939882  1.123338 -0.938158   
 16  -0.101319 -0.058831 -0.097279  0.384583  0.054213  0.787314 -0.738299   
 17   0.634146 -1.511758  1.043389  1.037083 -0.694244 -2.113977 -0.735913   
 18   0.530260 -0.821921  2.889919  0.128075 -0.089201 -1.135251 -1.109900   
 19   0.383313  1.288586  0.630922 -0.278463 -0.576125  0.527565  0.155685   
 20   0.619015 -0.218414  0.381265 -0.450987  0.320198  0.551846 -0.760200   
 21   0.590202 -0.349106  0.104717 -0.094581 -0.812437  1.137110 -0.558482   
 22  -0.792744  0.625676 -1.219969 -1.086694  0.015019  0.654400 -0.093755   
 23   0.090251  0.573517 -0.392694 -0.900419 -0.632813  0.021403 -0.150404   
 24   1.595944  1.393577  1.196274 -1.040089 -0.301914  0.252948  0.449126   
 25  -1.010943  0.844200  2.811122 -1.014319 -0.454213 -0.135180  0.544227   
 26   0.611729  1.130267  1.297125 -0.278936 -0.696657 -0.399766  1.109386   
 27   2.030651 -0.197337  0.511531 -0.052648  0.374089 -0.184860 -0.391595   
 28   1.553720 -0.907226 -0.530793  0.408559  0.304652  1.444174 -0.447947   
 29  -0.132642  1.204381  0.321965  1.453492 -0.277177  0.360333  1.902768   
 ..        ...       ...       ...       ...       ...       ...       ...   
 970  1.427974  0.199819 -0.411418  1.498040  0.194273  1.383671 -0.863180   
 971 -0.375810  0.708675  0.472280  0.557209 -1.220460 -0.318778  1.554513   
 972  1.259005 -2.024813 -0.166037 -1.117739 -1.174340  0.000942 -0.189614   
 973 -0.082218  0.618799 -0.785486 -0.993933  0.083065  0.342710 -0.406177   
 974  0.573386  0.583677  0.456756  1.093332  0.440058 -0.662806 -0.879922   
 975  1.396891 -0.173375  0.316951  0.933275 -1.138031 -0.266611  0.359347   
 976 -0.802807  0.811760  0.315390  1.027934 -0.383143 -0.289900  0.564009   
 977  0.572079 -0.525685 -0.969152 -0.168889  0.134876  0.094651  0.706020   
 978 -1.010922 -0.137452  0.805744  1.052817  0.564659  0.023036  1.262760   
 979  1.662119 -1.446498  1.107810  0.856095  2.265605 -0.901305 -0.717911   
 980  1.399542 -2.517785  1.670611  0.993417 -0.638496 -1.284510  0.178300   
 981  0.856778  0.973065 -0.293322  1.148104 -1.982842 -0.395618  0.737356   
 982 -1.123017 -0.452388  0.966455  0.463871  2.083240 -0.282247  2.820722   
 983  1.020074  0.632911  0.326278 -1.382778  0.522737  0.360319 -1.327267   
 984 -0.399732 -1.112838  0.297023 -1.671867 -0.465824  0.433053  1.253564   
 985 -0.578588 -1.041905 -0.513800  0.593854 -0.945068 -0.228836 -0.230977   
 986 -1.600228  0.002392 -1.151025 -0.511848  0.847090  0.474955 -0.490107   
 987 -0.153341 -0.707808  1.036295  0.465421 -0.048909 -0.869606 -0.773781   
 988 -0.224466  0.819665  0.234562  1.742887 -0.733612  1.056107  0.077367   
 989  1.673015 -0.703000 -0.912050 -0.346017  1.415573  0.972629 -1.073629   
 990  0.375088 -2.157032 -0.962757 -0.016289  0.379584  0.724160 -0.729455   
 991  0.471328 -0.536188  1.377669 -0.419758 -0.565780  0.624298  0.002458   
 992 -0.817988 -1.055492  0.580356 -1.444357  1.006731 -1.187316  0.541494   
 993 -0.211229  1.535913  0.689260  0.766016  0.424826 -0.947540 -0.560268   
 994 -0.218786 -1.943115  0.389745 -0.177236 -0.244490 -0.232373  0.186143   
 995  0.197082  0.149167 -0.396804  0.409610  1.961823 -0.063628  0.163441   
 996 -1.145868 -1.140588  0.035049  0.941297  0.528853 -0.285471  1.943755   
 997 -0.436684 -0.539361 -0.377773 -1.406364 -0.884543  0.200418  1.271856   
 998 -0.408406 -0.058132 -0.210153  1.311349 -0.561259  0.054425  1.618026   
 999  1.086343 -0.507914  1.941163  1.794463 -0.025451  0.037245  0.399774   
 
            7         8         9     ...           90        91        92  \
 0   -0.317344  0.649796  0.435950    ...     1.692325 -0.800594  0.291897   
 1    1.397676  0.582380  1.845770    ...    -1.346165 -0.515097 -0.986927   
 2   -0.301803 -0.966162  0.635666    ...     0.821978  1.543093  1.440337   
 3    0.510186  0.612589 -0.647749    ...    -1.200731  0.286586 -0.019840   
 4    1.137675 -0.919961  0.457479    ...    -0.732394  0.842354  0.257231   
 5   -0.212021  1.854793  0.091716    ...    -1.147652  0.220629 -0.875983   
 6    0.677341 -0.694924 -0.439465    ...     0.945848 -1.082315 -0.465580   
 7   -1.135885  2.231805  1.238442    ...    -0.809847  1.044542  1.007487   
 8   -1.552876  0.425295  0.747201    ...     0.270946  0.278733 -1.100260   
 9   -0.616774 -0.215404  0.236690    ...    -0.622256 -0.221961 -0.197931   
 10  -1.365648 -1.447118  0.561395    ...     0.398376  1.077897 -0.321834   
 11  -0.629822 -2.140771  1.535139    ...    -1.900833  0.874642  0.567859   
 12   0.049380 -0.531111  1.098333    ...    -0.560815  0.077198  1.000120   
 13   1.579323  0.128987 -0.642869    ...    -2.101945  0.514262 -0.840356   
 14   0.498721 -0.079305  0.549055    ...     0.646890  1.128414  0.482017   
 15  -2.083341  0.622965 -0.990498    ...     1.337750  0.619105 -0.709034   
 16  -0.320129 -1.452961  0.194205    ...     0.583090  0.327822 -0.228025   
 17  -2.205644  0.176696  0.061404    ...     1.011616 -0.390736  0.571668   
 18  -0.224665 -1.211889  0.673980    ...     1.988059 -0.763888 -0.044132   
 19  -0.149820  0.534449  0.515180    ...    -0.380617 -1.285820  0.065645   
 20   2.469470 -0.365862 -0.489380    ...    -1.189763  0.235829  0.929609   
 21  -0.345152 -1.156759  0.511396    ...    -0.239752 -1.003209  1.015141   
 22  -1.563144 -0.783295  0.195164    ...    -0.646349  0.807345 -0.889528   
 23  -1.105089  0.125311 -0.441498    ...     0.317951 -1.686191  2.279539   
 24  -0.371268  0.716161  0.500363    ...     0.698846  0.263003  1.193694   
 25  -1.843775 -0.708209 -0.390841    ...     0.120863  1.577266  0.608153   
 26  -0.183411  0.943761 -0.539840    ...    -2.392502 -1.724615 -0.790046   
 27  -0.685548  0.658868 -0.129957    ...     0.789021  0.480574  0.485108   
 28   0.394668  0.927867  0.474594    ...     1.017249  0.233735 -0.865090   
 29   0.875193 -0.648333 -0.618216    ...     1.154525 -1.282005  0.426993   
 ..        ...       ...       ...    ...          ...       ...       ...   
 970  1.584451  0.898695  0.950156    ...    -1.198896 -1.500364  0.147468   
 971 -1.516772  0.942546 -0.306998    ...    -0.838506 -1.165612 -0.009636   
 972 -1.303565 -1.562939 -0.525488    ...    -0.811399 -0.908523  1.115758   
 973  0.923340 -0.346517  1.274084    ...     0.140783 -0.202236 -0.550950   
 974 -1.416193 -0.376203 -0.435894    ...    -0.377619  1.307645  0.067711   
 975 -0.720587  1.449158  0.140012    ...     0.779804  0.350238 -1.225247   
 976  0.880073 -1.031125  0.438087    ...     0.890204  0.179396 -2.010895   
 977 -2.360070  0.105652  1.200624    ...     1.799642  0.361594 -0.252854   
 978 -0.364856  1.518637  0.861621    ...     1.076710  0.410516 -0.179108   
 979  0.081220 -0.436517 -0.313258    ...    -0.378749 -0.365762  0.165117   
 980 -1.173198 -0.139173 -1.866092    ...     0.225117  0.652381 -1.784381   
 981 -0.109078 -0.928463 -1.327946    ...     1.048008 -0.455142 -1.138962   
 982 -1.072128  0.403995  0.863175    ...     0.399199  0.355422 -0.725796   
 983 -0.450487  0.572312 -0.441018    ...    -0.067267  1.245170 -0.174082   
 984  0.061257 -0.276346 -1.053986    ...     0.681992  0.289140  0.091927   
 985  0.158640 -0.543643  0.802841    ...    -1.078982  0.500068  0.456649   
 986 -1.074230  0.185078 -0.513902    ...    -0.231571  0.325234 -0.160442   
 987 -1.414793  0.466396 -0.789620    ...    -0.314411  2.439750 -0.029362   
 988 -1.895999 -1.315336  1.379091    ...     0.228738  1.346434 -1.150442   
 989 -0.991179  0.789371  0.982003    ...    -0.794618 -1.037488  0.548931   
 990  0.022600  0.354721 -0.526454    ...    -0.499464 -0.026963  0.854360   
 991  1.111153 -0.106708 -0.752531    ...     1.489186 -0.196638 -0.542070   
 992 -0.752222  0.103705  0.011176    ...    -0.270127  2.142073 -1.645180   
 993  0.839112 -0.945992 -0.722982    ...    -2.953682 -0.393865 -0.660461   
 994 -1.661020  0.616278 -1.296305    ...     1.159081  0.023756  0.302573   
 995  0.561612 -1.107007  0.915307    ...    -0.305321 -0.547401 -0.441435   
 996 -1.472651  0.024061 -0.866942    ...     0.612108  0.413757 -1.376710   
 997  0.520302  1.444479  0.158747    ...     0.195213  0.755401  0.130984   
 998  0.226956  0.317800 -0.299309    ...     0.930140  0.782400  0.267767   
 999  0.528914 -0.559269  1.971265    ...     0.799492 -0.659607 -0.256106   
 
            93        94        95        96        97        98        99  
 0   -0.453697 -0.178751 -0.607489  1.037445 -1.749688 -0.607751  0.357977  
 1    0.102077  1.496402 -0.646635 -0.417379 -0.302278  0.201746  1.240888  
 2    0.181615  0.403379 -0.271970  0.417755 -2.832542  0.968731 -1.794433  
 3    0.763633  0.140741 -0.008938  0.925198  0.656512 -0.018097 -0.824847  
 4   -0.763106 -1.773781  0.791774 -1.135354 -0.736330  0.631639 -0.533620  
 5   -0.411940 -0.822743 -1.680231 -1.000032  0.551767 -2.749766  0.482842  
 6    0.241179 -2.002802 -1.462483  2.470792  1.039311 -0.951879 -1.283762  
 7    0.518928 -0.603663  0.734483  2.258332 -0.339975 -0.028964  0.222319  
 8   -1.517364 -0.020546 -0.394022  1.323604  0.396881 -1.868242 -0.152760  
 9    0.881865  2.110516  0.054782  3.510844  0.186379  0.709308 -0.409238  
 10  -2.125828  1.675130 -0.137430  0.488037  0.817900 -0.906300  1.371698  
 11   0.283778  0.153505 -0.870782  0.128697 -1.057270 -0.460268  0.930424  
 12  -0.755670  1.283407 -0.748315  1.118766 -0.035667 -0.466580  0.525206  
 13   0.354591 -0.704720 -0.994755  0.346154 -1.992765 -0.043006  1.363481  
 14   0.056618  1.191785 -0.414376  1.003905 -0.273869 -1.568518  0.147835  
 15   0.460069 -0.235582  0.908541  1.604067 -0.144376 -0.360808 -1.438284  
 16   1.328314  0.417165  0.457554  1.241104 -0.243334  0.143554 -1.314256  
 17  -0.488609 -0.661437  2.468581 -1.194021 -1.317828  0.072991  0.580017  
 18   1.049466  1.329531  0.851378 -0.872630  0.569834  0.170138 -0.014328  
 19   1.793206  0.144106  1.687074  1.116450  1.026708  0.690760 -1.218642  
 20   0.187670 -1.265645  1.562305 -0.974613 -1.144446  1.858976 -0.304615  
 21   0.648009  1.479887  0.237014 -0.102091 -0.424306  1.105144 -0.792205  
 22   0.974942  1.950779 -0.075844 -1.002870 -0.325791 -0.823237 -0.953522  
 23  -1.527603 -0.706660  0.701941  1.497003  0.135767 -1.375465  0.360233  
 24   0.372183  0.251492  1.196615 -0.812162  0.812581 -0.590464 -0.097109  
 25   1.218969  0.994976  0.375506 -0.806442 -0.129328  0.944796 -0.001709  
 26  -0.357680  1.151953  0.490046 -1.087264  0.352795  0.227000  0.287900  
 27  -0.551473  1.765073 -0.341247 -2.017771  1.163188 -1.083413 -0.466292  
 28  -0.134374 -0.593132  1.456376 -1.980204 -1.131998  1.495568  0.937576  
 29   0.894516  0.276828 -0.061231  0.576754 -1.328837 -0.590606 -2.520944  
 ..        ...       ...       ...       ...       ...       ...       ...  
 970  0.028337 -1.988826 -0.077648 -0.918449  1.025052 -0.180295 -0.197297  
 971 -0.072396 -0.193522 -1.668951 -0.846518 -0.328954 -1.821519  1.059601  
 972  1.146515 -0.535107  1.639795 -1.297755  0.657985  0.863315 -0.990740  
 973 -0.969484 -0.485421 -1.656809  1.046791 -1.699698  0.708455 -0.156831  
 974  0.352147 -0.116751 -0.154777  0.960183  0.395414 -0.061042 -0.322580  
 975 -1.216062 -1.094429  2.290265 -0.128994  1.468547  0.114044  0.177380  
 976 -1.172977  1.258391  1.052376  0.796369  0.147978 -2.088108  0.608715  
 977 -0.812450  0.653344  0.481369  0.710590 -1.171188 -1.003806  0.124105  
 978 -0.487519 -0.390283  0.117963  0.377637  0.518965 -0.221915 -0.648131  
 979  0.318640  0.972550 -2.528320  1.189220 -0.507969 -0.093103  0.576090  
 980  0.514738  0.415845  1.662958 -0.438644  0.738547 -0.911023  0.609567  
 981 -0.382158 -0.035499  1.452173 -0.635157 -1.345215 -0.443101  0.377928  
 982 -0.540355  0.842463 -0.695370  0.916631  0.544897  0.032422  1.361608  
 983 -0.184585  1.064541  0.345514 -0.952976 -0.908573  2.086482 -1.229222  
 984  0.882285  0.440914  0.568569  1.774436 -0.690365 -0.088919 -0.078799  
 985  2.000569  1.672179 -0.086324  2.108751  0.610264  2.142055 -0.452081  
 986  0.728624 -1.659569  0.232187  1.553900  0.576782  0.217307  1.031375  
 987 -0.979757 -0.633185 -1.637807  0.558194  0.663296  0.903608  1.204646  
 988 -1.841558  0.198859  1.359167 -0.731091 -1.143660  1.205318 -0.058567  
 989 -1.051403  1.199599  1.590223  1.407460  0.131064  0.539702  0.463315  
 990  0.745798  1.746454 -0.634516 -0.038540 -1.046814  0.375747 -0.899132  
 991  0.049980  0.766890  0.012255 -1.300009 -1.902284  0.538790 -1.448589  
 992 -0.777510  0.112427  0.905064  0.365634  1.001759  0.830781 -0.570362  
 993  1.744678 -0.908464 -1.152979 -0.519524  0.032816 -0.640255  0.919545  
 994 -0.467906  0.267300 -0.775554 -0.915739 -0.877088  0.902944 -0.503776  
 995  2.032371  0.418168 -0.484227 -0.877188  1.578590 -0.809888  0.328456  
 996  0.477274 -0.387270  1.050338 -1.148986 -1.548766  0.354266  0.059252  
 997 -0.434422  0.889280  0.608129 -1.123611 -0.241521 -0.333948  0.764316  
 998  0.863556 -1.517694  1.218220  0.787357 -1.124006 -0.150231  0.058815  
 999  1.044219 -1.045599  1.871140 -0.047166  1.134148  3.001915  0.952727  
 
 [1000 rows x 100 columns], 'test': 0       True
 1       True
 2      False
 3       True
 4       True
 5      False
 6      False
 7       True
 8       True
 9       True
 10     False
 11      True
 12      True
 13      True
 14     False
 15     False
 16      True
 17      True
 18      True
 19      True
 20     False
 21     False
 22      True
 23     False
 24     False
 25      True
 26     False
 27     False
 28     False
 29      True
        ...  
 970     True
 971    False
 972     True
 973     True
 974    False
 975    False
 976     True
 977     True
 978     True
 979    False
 980     True
 981    False
 982     True
 983     True
 984     True
 985    False
 986     True
 987     True
 988     True
 989     True
 990    False
 991     True
 992     True
 993     True
 994    False
 995    False
 996    False
 997     True
 998    False
 999     True
 dtype: bool, 'train': 0      False
 1      False
 2       True
 3      False
 4      False
 5       True
 6       True
 7      False
 8      False
 9      False
 10      True
 11     False
 12     False
 13     False
 14      True
 15      True
 16     False
 17     False
 18     False
 19     False
 20      True
 21      True
 22     False
 23      True
 24      True
 25     False
 26      True
 27      True
 28      True
 29     False
        ...  
 970    False
 971     True
 972    False
 973    False
 974     True
 975     True
 976    False
 977    False
 978    False
 979     True
 980    False
 981     True
 982    False
 983    False
 984    False
 985     True
 986    False
 987    False
 988    False
 989    False
 990     True
 991    False
 992    False
 993    False
 994     True
 995     True
 996     True
 997    False
 998     True
 999    False
 dtype: bool, 'y': 0      0
 1      1
 2      0
 3      0
 4      1
 5      0
 6      0
 7      0
 8      0
 9      0
 10     0
 11     0
 12     0
 13     0
 14     0
 15     0
 16     0
 17     1
 18     1
 19     0
 20     1
 21     0
 22     1
 23     0
 24     1
 25     1
 26     1
 27     1
 28     1
 29     0
       ..
 970    1
 971    1
 972    1
 973    0
 974    0
 975    0
 976    0
 977    0
 978    0
 979    0
 980    1
 981    1
 982    0
 983    1
 984    0
 985    0
 986    0
 987    0
 988    1
 989    0
 990    0
 991    1
 992    0
 993    1
 994    1
 995    1
 996    1
 997    0
 998    0
 999    0
 dtype: int64}

The result is a dictionary containing a standard set of objects that drain uses for machine learning workflows (the sketch below shows how to slice them):

  • X is a matrix of features, also called a design matrix
  • y is a vector of outcomes
  • train is a boolean vector indicating which rows of X are in the training set
  • test is a boolean vector indicating which rows of X are in the test set
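Since train and test are boolean masks over the rows of X, the training and test sets can be sliced out with ordinary pandas indexing (a minimal sketch, reusing the dict returned by execute above):

result = data.execute()  # the dict shown above
X_train, y_train = result['X'][result['train']], result['y'][result['train']]
X_test, y_test = result['X'][result['test']], result['y'][result['test']]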

Let's add another step to our workflow to construct a random forest estimator:


In [4]:
import drain.model, drain.step
estimator = drain.step.Construct('sklearn.ensemble.RandomForestClassifier', n_estimators=1)

The Construct step simply constructs an instance of the specified class with the given arguments.
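In other words, the step above is roughly equivalent to instantiating the class directly (a sketch; Construct defers the construction until the step is run):

from sklearn.ensemble import RandomForestClassifier
estimator_object = RandomForestClassifier(n_estimators=1)

Executing the step performs the construction: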


In [5]:
estimator.execute()


INFO:root:Running
	Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1)
Out[5]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=1, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)

Next we add another step to fit this estimator on our previously generated dataset:


In [6]:
fit = drain.model.Fit(inputs=[estimator, data], return_estimator=True, return_feature_importances=True)

Note the special inputs argument: it is a collection of steps whose results Fit takes as input.
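Conceptually, Fit pulls the results of its input steps and fits the estimator on the training rows, roughly like this (a hedged sketch, not drain's actual implementation):

est = estimator.execute()  # the unfitted RandomForestClassifier
d = data.execute()         # the dict with X, y, train and test
est.fit(d['X'][d['train']], d['y'][d['train']])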


In [7]:
fit.execute()


INFO:root:Running
	Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False)
INFO:root:Fitting with 392 examples, 100 features
Out[7]:
{'estimator': RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
             max_depth=None, max_features='auto', max_leaf_nodes=None,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, n_estimators=1, n_jobs=1,
             oob_score=False, random_state=None, verbose=0,
             warm_start=False), 'feature_importances':     feature  importance
 38       38    0.152378
 96       96    0.111899
 11       11    0.061071
 79       79    0.046286
 84       84    0.044512
 48       48    0.041138
 82       82    0.037886
 53       53    0.037471
 18       18    0.034083
 26       26    0.032706
 76       76    0.030827
 0         0    0.029203
 22       22    0.028594
 34       34    0.026507
 99       99    0.026000
 78       78    0.025110
 64       64    0.020386
 37       37    0.020012
 77       77    0.018746
 6         6    0.018398
 89       89    0.018124
 23       23    0.017522
 59       59    0.016851
 39       39    0.015019
 42       42    0.014395
 27       27    0.009911
 29       29    0.009858
 58       58    0.009626
 74       74    0.009402
 73       73    0.008307
 ..      ...         ...
 30       30    0.000000
 28       28    0.000000
 25       25    0.000000
 24       24    0.000000
 13       13    0.000000
 21       21    0.000000
 20       20    0.000000
 19       19    0.000000
 14       14    0.000000
 17       17    0.000000
 9         9    0.000000
 41       41    0.000000
 61       61    0.000000
 43       43    0.000000
 60       60    0.000000
 16       16    0.000000
 56       56    0.000000
 55       55    0.000000
 54       54    0.000000
 7         7    0.000000
 52       52    0.000000
 51       51    0.000000
 1         1    0.000000
 49       49    0.000000
 8         8    0.000000
 47       47    0.000000
 46       46    0.000000
 45       45    0.000000
 44       44    0.000000
 50       50    0.000000
 
 [100 rows x 2 columns]}

The Fit step returns the fitted estimator object as well as a dataframe containing the names of features and their importances.
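Because the result is a plain dictionary, the importances are easy to inspect with ordinary pandas (a sketch):

importances = fit.execute()['feature_importances']
importances.sort_values('importance', ascending=False).head()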

Let's add one final step to our pipeline to generate predictions on the test set of our classification data:


In [8]:
predict = drain.model.Predict(inputs=[fit, data])
predict.execute()


INFO:root:Running
	Predict(inputs=[Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False), ClassificationData(n_features=100, n_samples=1000)],
	    prefit=True, return_estimator=False, return_feature_importances=False,
	    return_predictions=True)
INFO:root:Predicting 608 examples
Out[8]:
{'y':      true  score
 0       0      0
 1       1      1
 3       0      0
 4       1      1
 7       0      0
 8       0      0
 9       0      0
 11      0      0
 12      0      0
 13      0      0
 16      0      0
 17      1      0
 18      1      0
 19      0      1
 22      1      1
 25      1      1
 29      0      0
 33      0      0
 34      0      1
 35      0      0
 36      0      0
 38      1      0
 40      0      0
 43      1      1
 47      0      0
 48      1      1
 49      1      0
 50      0      0
 51      0      1
 56      1      1
 ..    ...    ...
 946     0      0
 948     0      1
 954     1      1
 957     1      0
 958     1      1
 959     0      0
 960     1      1
 961     0      1
 963     0      1
 964     1      1
 969     1      1
 970     1      0
 972     1      0
 973     0      0
 976     0      0
 977     0      0
 978     0      1
 980     1      0
 982     0      0
 983     1      1
 984     0      0
 986     0      1
 987     0      0
 988     1      1
 989     0      0
 991     1      0
 992     0      0
 993     1      0
 997     0      0
 999     0      1
 
 [608 rows x 2 columns]}

The Predict step returns a dataframe with a score column containing the estimator's predictions and a true column containing the true outcomes.
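With a single tree, the score column here is just a 0/1 prediction, so simple diagnostics can be computed by hand (a sketch):

y = predict.execute()['y']
accuracy = (y['true'] == y['score']).mean()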

The drain.model module contains a variety of metrics that can be run directly on the predict step:


In [9]:
drain.model.auc(predict)


Out[9]:
0.63894663894663895

In [10]:
drain.model.baseline(predict)


Out[10]:
0.48684210526315791

In [11]:
drain.model.precision(predict, k=10)


Out[11]:
0.80000000000000004

We can retrieve the result of any step that has already been run by calling its get_result method:


In [12]:
predict.get_result()


Out[12]:
{'y':      true  score
 0       0      0
 1       1      1
 3       0      0
 4       1      1
 7       0      0
 8       0      0
 9       0      0
 11      0      0
 12      0      0
 13      0      0
 16      0      0
 17      1      0
 18      1      0
 19      0      1
 22      1      1
 25      1      1
 29      0      0
 33      0      0
 34      0      1
 35      0      0
 36      0      0
 38      1      0
 40      0      0
 43      1      1
 47      0      0
 48      1      1
 49      1      0
 50      0      0
 51      0      1
 56      1      1
 ..    ...    ...
 946     0      0
 948     0      1
 954     1      1
 957     1      0
 958     1      1
 959     0      0
 960     1      1
 961     0      1
 963     0      1
 964     1      1
 969     1      1
 970     1      0
 972     1      0
 973     0      0
 976     0      0
 977     0      0
 978     0      1
 980     1      0
 982     0      0
 983     1      1
 984     0      0
 986     0      1
 987     0      0
 988     1      1
 989     0      0
 991     1      0
 992     0      0
 993     1      0
 997     0      0
 999     0      1
 
 [608 rows x 2 columns]}

More on workflow execution

Let's redefine the above workflow using a function:


In [13]:
def prediction_workflow():
    # generate the data including a training and test split
    data = drain.data.ClassificationData(target=True, n_samples=1000, n_features=100)
    # construct a random forest estimator
    estimator = drain.step.Construct('sklearn.ensemble.RandomForestClassifier', n_estimators=1)
    # fit the estimator
    fit = drain.model.Fit(inputs=[estimator, data], return_estimator=True, return_feature_importances=True)
    # make predictions
    return drain.model.Predict(inputs=[fit, data])

In [14]:
predict2 = prediction_workflow()

Note that step execution is recursive: the execute method ensures that all inputs (and inputs of inputs, and so on) have been run before running the given step:


In [15]:
predict2.execute()


INFO:root:Running
	Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1)
INFO:root:Running
	ClassificationData(n_features=100, n_samples=1000)
INFO:root:Running
	Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False)
INFO:root:Fitting with 396 examples, 100 features
INFO:root:Running
	Predict(inputs=[Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False), ClassificationData(n_features=100, n_samples=1000)],
	    prefit=True, return_estimator=False, return_feature_importances=False,
	    return_predictions=True)
INFO:root:Predicting 604 examples
Out[15]:
{'y':      true  score
 0       0      1
 2       0      1
 3       1      1
 4       0      0
 8       0      1
 9       0      1
 12      1      1
 13      0      0
 14      1      0
 16      0      0
 21      1      0
 22      1      1
 23      1      0
 24      1      0
 25      0      0
 27      0      1
 30      1      1
 31      1      0
 33      1      0
 34      1      1
 35      0      0
 36      1      1
 37      1      0
 38      1      0
 43      1      0
 44      1      0
 45      0      1
 46      0      1
 48      0      1
 49      0      0
 ..    ...    ...
 962     1      1
 963     1      0
 964     0      0
 967     0      1
 968     1      0
 969     0      0
 970     1      1
 971     1      1
 972     0      0
 973     1      0
 974     0      0
 975     1      0
 976     1      1
 977     0      0
 978     1      1
 979     0      1
 983     1      0
 984     1      0
 985     1      0
 986     0      1
 987     0      0
 988     0      1
 989     1      1
 992     0      1
 993     0      0
 995     0      0
 996     0      0
 997     0      0
 998     0      1
 999     1      1
 
 [604 rows x 2 columns]}

The steps of a workflow form a network (a directed acyclic graph or DAG, to be precise).
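We can make this structure visible by walking a step's inputs (a small hypothetical helper; it assumes each step exposes the inputs it was constructed with as an inputs attribute):

def walk(step, depth=0):
    # print each step's class name, indented by its depth in the DAG
    print('  ' * depth + type(step).__name__)
    for s in getattr(step, 'inputs', []):
        walk(s, depth + 1)

walk(predict2)

Note that ClassificationData would be printed under both Fit and Predict: it is a single shared node, which is what makes the graph a DAG rather than a tree.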

A more complicated workflow

In practice we want to train many models on a given dataset. Let's define a workflow that searches over the number of trees in the random forest model:


In [16]:
def n_estimator_search():
    data = drain.data.ClassificationData(target=True, n_samples=1000, n_features=100)
    
    predict = []
    for n_estimators in range(1, 4):
        estimator = drain.step.Construct('sklearn.ensemble.RandomForestClassifier', n_estimators=n_estimators, name='estimator')
        fit = drain.model.Fit(inputs=[estimator, data], return_estimator=True, return_feature_importances=True)
        predict.append(drain.model.Predict(inputs=[fit, data]))
        
    return predict

In [17]:
predictions = n_estimator_search()

In [18]:
for p in predictions:
    p.execute()


INFO:root:Running
	Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1)
INFO:root:Running
	ClassificationData(n_features=100, n_samples=1000)
INFO:root:Running
	Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False)
INFO:root:Fitting with 400 examples, 100 features
INFO:root:Running
	Predict(inputs=[Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=1), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False), ClassificationData(n_features=100, n_samples=1000)],
	    prefit=True, return_estimator=False, return_feature_importances=False,
	    return_predictions=True)
INFO:root:Predicting 600 examples
INFO:root:Running
	Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=2)
INFO:root:Running
	Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=2), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False)
INFO:root:Fitting with 400 examples, 100 features
INFO:root:Running
	Predict(inputs=[Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=2), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False), ClassificationData(n_features=100, n_samples=1000)],
	    prefit=True, return_estimator=False, return_feature_importances=False,
	    return_predictions=True)
INFO:root:Predicting 600 examples
INFO:root:Running
	Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=3)
INFO:root:Running
	Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=3), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False)
INFO:root:Fitting with 400 examples, 100 features
INFO:root:Running
	Predict(inputs=[Fit(inputs=[Construct(__class_name__='sklearn.ensemble.RandomForestClassifier',
	     n_estimators=3), ClassificationData(n_features=100, n_samples=1000)],
	  prefit=False, return_estimator=True, return_feature_importances=True,
	  return_predictions=False), ClassificationData(n_features=100, n_samples=1000)],
	    prefit=True, return_estimator=False, return_feature_importances=False,
	    return_predictions=True)
INFO:root:Predicting 600 examples

Note that the ClassificationData step was only run once: all three pipelines share the same data step, so its result is computed once and reused.
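We can check the sharing directly (a sketch, again assuming the inputs argument is stored as an attribute):

predictions[0].inputs[1] is predictions[1].inputs[1]  # True: one shared data step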

Drain provides some additional utilities for model exploration in the drain.explore module:


In [19]:
from drain import explore
df = explore.to_dataframe(predictions)
df


Out[19]:
   n_estimators                                               step
0             1  Predict(inputs=[Fit(inputs=[Construct(__class_...
1             2  Predict(inputs=[Fit(inputs=[Construct(__class_...
2             3  Predict(inputs=[Fit(inputs=[Construct(__class_...

In [20]:
from drain import model
explore.apply(df, model.auc)


Out[20]:
n_estimators
1    0.750700
2    0.809657
3    0.778523
Name: step, dtype: float64

In [21]:
%matplotlib inline
explore.apply(df, model.precision_series).plot()


Out[21]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fe129261dd0>