TF-DNNRegressor - eLU - Spitzer Calibration Data
This script show a simple example of using [tf.contrib.learn][1] library to create our model.
The code is divided in following steps:
- Load CSVs data
- Continuous features
- Converting Data into Tensors
- Selecting and Engineering Features for the Model
- Defining The Regression Model
- Training and Evaluating Our Model
- Predicting output for test data
/Users/jfraine/anaconda3/lib/python3.6/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
df_train_ori = pd.read_csv('train.csv')
df_test_ori = pd.read_csv('test.csv')
Out[3]:
|
pix1 |
pix2 |
pix3 |
pix4 |
pix5 |
pix6 |
pix7 |
pix8 |
pix9 |
| 0 |
577.447021 |
3465.876709 |
1118.598145 |
550.165466 |
2460.376953 |
994.374207 |
141.741592 |
521.385254 |
694.330688 |
| 1 |
569.863098 |
3387.739258 |
1087.530762 |
556.717407 |
2552.070557 |
1021.892700 |
134.061081 |
511.347778 |
666.346069 |
| 2 |
552.641235 |
3405.411377 |
1082.131104 |
558.981445 |
2560.040771 |
1058.485352 |
146.488220 |
513.809570 |
691.019653 |
| 3 |
571.821167 |
3340.533691 |
1073.962036 |
568.324768 |
2643.155273 |
1024.431641 |
147.687546 |
525.451538 |
727.683472 |
| 4 |
538.292114 |
3248.569336 |
1021.301208 |
548.598145 |
2691.563965 |
1066.199707 |
154.170990 |
541.407532 |
718.537537 |
| 5 |
553.332214 |
3183.050293 |
1026.863281 |
578.003784 |
2679.043457 |
1085.607422 |
154.789886 |
517.848328 |
722.456909 |
| 6 |
541.202332 |
3137.938232 |
1035.948364 |
589.566528 |
2743.923584 |
1072.393555 |
158.401291 |
539.485718 |
724.184937 |
| 7 |
547.699829 |
3057.428467 |
1034.417603 |
590.498108 |
2818.238770 |
1118.098633 |
167.132004 |
540.160950 |
760.767334 |
| 8 |
534.724976 |
3015.335205 |
1007.406494 |
574.792358 |
2876.417236 |
1141.010498 |
158.403992 |
527.611633 |
765.304077 |
| 9 |
529.423950 |
3010.478027 |
973.444946 |
578.736755 |
2875.608887 |
1106.352905 |
163.936325 |
540.751587 |
746.721191 |
| 10 |
543.917847 |
2959.855225 |
1002.942017 |
600.163574 |
2921.819580 |
1124.441772 |
153.599976 |
548.917603 |
735.936890 |
| 11 |
415.543091 |
1085.521729 |
446.313171 |
1641.077881 |
4778.175781 |
319.484497 |
418.018860 |
810.050232 |
567.318176 |
| 12 |
386.626038 |
990.775635 |
455.084015 |
1653.742676 |
4831.918457 |
318.239410 |
428.828278 |
836.078308 |
576.660156 |
| 13 |
352.204071 |
909.555847 |
448.546478 |
1697.422363 |
4942.354004 |
315.904175 |
434.170258 |
822.891968 |
576.275024 |
| 14 |
356.336823 |
830.570557 |
446.085022 |
1725.667725 |
4915.351074 |
312.517944 |
455.358154 |
826.022034 |
565.244507 |
| 15 |
346.803467 |
827.159546 |
445.986084 |
1783.878418 |
4926.551270 |
299.384399 |
454.638336 |
856.649109 |
572.549500 |
| 16 |
319.600525 |
778.882324 |
452.693604 |
1766.618774 |
4900.985352 |
295.125763 |
455.190002 |
873.377808 |
581.590820 |
| 17 |
318.834503 |
772.429138 |
450.526428 |
1804.441284 |
4888.824219 |
292.672455 |
460.712219 |
873.199463 |
557.806458 |
| 18 |
304.534882 |
747.088501 |
433.972229 |
1812.826172 |
4949.355469 |
292.384491 |
459.964996 |
883.855896 |
556.133789 |
| 19 |
297.334167 |
711.377441 |
427.756256 |
1824.887939 |
4977.011230 |
288.222961 |
442.602753 |
883.110352 |
578.062012 |
| 20 |
291.785248 |
696.303345 |
434.420410 |
1808.774414 |
4967.067383 |
288.906006 |
465.871857 |
848.896790 |
584.660889 |
| 21 |
275.729004 |
684.407776 |
436.270630 |
1839.743286 |
4958.617188 |
292.503906 |
473.751465 |
874.837891 |
576.924805 |
| 22 |
267.109802 |
668.616150 |
428.174561 |
1800.920166 |
5008.872559 |
289.184784 |
474.868591 |
872.165161 |
579.913147 |
| 23 |
272.932220 |
664.508301 |
437.244934 |
1776.687744 |
5002.353027 |
288.002045 |
470.427643 |
882.427673 |
600.059509 |
| 24 |
269.932495 |
641.567139 |
432.846771 |
1789.313354 |
5064.815918 |
303.351807 |
484.210266 |
894.807068 |
602.440247 |
| 25 |
265.979431 |
648.039856 |
432.684937 |
1755.147705 |
5088.369141 |
288.072937 |
482.552338 |
882.221191 |
592.758179 |
| 26 |
564.362610 |
1541.839600 |
431.914764 |
1446.160400 |
4362.256348 |
345.700531 |
401.668152 |
821.724915 |
520.839478 |
| 27 |
506.248077 |
1415.009277 |
443.126648 |
1487.735840 |
4414.598633 |
347.017731 |
393.479889 |
795.969360 |
549.537842 |
| 28 |
515.429810 |
1315.711914 |
448.500977 |
1547.112671 |
4509.289062 |
323.252441 |
394.831543 |
821.549072 |
529.604004 |
| 29 |
476.979919 |
1244.519043 |
429.070221 |
1599.996094 |
4529.244629 |
330.517914 |
398.165985 |
820.598389 |
528.216064 |
| ... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
| 785255 |
475.147705 |
2842.766357 |
591.782959 |
633.689026 |
3837.167969 |
771.729736 |
238.382645 |
637.508911 |
712.187988 |
| 785256 |
475.464050 |
2835.344727 |
583.647827 |
640.641846 |
3873.379639 |
750.115234 |
248.780396 |
651.230103 |
716.854736 |
| 785257 |
481.847534 |
2817.340820 |
578.740479 |
637.754578 |
3877.218262 |
781.961548 |
241.087128 |
650.161316 |
733.283325 |
| 785258 |
465.787201 |
2804.904053 |
569.343262 |
647.821289 |
3857.104492 |
767.500122 |
240.323898 |
628.991760 |
726.848083 |
| 785259 |
473.557251 |
2804.657471 |
561.592529 |
633.638123 |
3901.351562 |
773.559692 |
231.294861 |
641.328003 |
733.939087 |
| 785260 |
476.991699 |
2791.731445 |
579.058228 |
643.239990 |
3923.991455 |
765.078247 |
235.408035 |
661.097107 |
718.653198 |
| 785261 |
476.635651 |
2750.070068 |
563.718689 |
635.713318 |
3949.495361 |
772.300659 |
246.790222 |
632.139221 |
730.945618 |
| 785262 |
481.483948 |
2736.324219 |
567.592407 |
660.586548 |
3944.277100 |
766.693787 |
241.954529 |
633.959961 |
734.623413 |
| 785263 |
472.956024 |
2736.230225 |
562.863098 |
659.623413 |
4004.276367 |
764.286011 |
239.283524 |
660.717346 |
713.998047 |
| 785264 |
466.867065 |
2732.458008 |
569.005798 |
666.330994 |
4024.462891 |
769.562134 |
245.346207 |
663.933350 |
725.252869 |
| 785265 |
480.046295 |
2715.175537 |
557.256165 |
645.355225 |
4037.946533 |
742.447998 |
258.840576 |
654.178589 |
706.549988 |
| 785266 |
475.449463 |
2769.531494 |
536.281189 |
655.740417 |
4031.584229 |
730.941650 |
255.209244 |
658.127136 |
719.966553 |
| 785267 |
502.837372 |
2716.243408 |
554.105591 |
677.264648 |
3980.764404 |
730.059265 |
259.389648 |
667.066284 |
708.509277 |
| 785268 |
480.399200 |
2769.769531 |
562.770813 |
679.593140 |
3983.617920 |
712.878052 |
243.168762 |
654.615112 |
689.989075 |
| 785269 |
480.807373 |
2759.118408 |
561.283997 |
699.134888 |
4009.483398 |
721.416626 |
262.486328 |
654.512329 |
713.022461 |
| 785270 |
523.245300 |
2958.067139 |
543.905029 |
697.710938 |
3776.409424 |
666.272888 |
263.427338 |
651.643799 |
670.494263 |
| 785271 |
495.875732 |
2955.598145 |
552.345215 |
657.297363 |
3747.275879 |
703.948242 |
250.976898 |
683.805420 |
657.604126 |
| 785272 |
516.237366 |
2964.318604 |
577.262634 |
668.211304 |
3793.730225 |
709.129517 |
256.492645 |
667.471802 |
683.404236 |
| 785273 |
501.681335 |
2919.767578 |
550.606995 |
638.708374 |
3802.873535 |
717.027649 |
250.493073 |
644.383301 |
702.473267 |
| 785274 |
510.000214 |
2995.055420 |
551.373047 |
660.349304 |
3745.079590 |
722.202332 |
247.733948 |
642.584961 |
695.485779 |
| 785275 |
521.931641 |
3013.854004 |
554.299988 |
631.953613 |
3710.329346 |
728.134094 |
243.888962 |
643.002563 |
691.181763 |
| 785276 |
495.807709 |
2989.654053 |
585.109985 |
639.402100 |
3714.461670 |
749.341003 |
235.572479 |
646.545593 |
676.286865 |
| 785277 |
512.751160 |
2958.394531 |
582.242310 |
652.946533 |
3707.906250 |
729.102295 |
234.076538 |
652.182556 |
688.491699 |
| 785278 |
504.160034 |
2989.499268 |
587.763611 |
658.194824 |
3695.002930 |
736.696594 |
231.465927 |
651.476196 |
681.635620 |
| 785279 |
516.517090 |
2976.608643 |
586.910278 |
637.642395 |
3714.598877 |
752.355835 |
235.402084 |
637.375977 |
700.220703 |
| 785280 |
506.839264 |
3055.761230 |
596.532349 |
647.959351 |
3610.522217 |
731.254211 |
234.202530 |
658.012695 |
675.393433 |
| 785281 |
495.875641 |
3055.138428 |
590.330444 |
623.768982 |
3686.792236 |
732.337524 |
234.055603 |
654.529114 |
690.857361 |
| 785282 |
510.378479 |
3120.925781 |
585.978516 |
623.555176 |
3621.902100 |
728.794128 |
230.795486 |
644.155518 |
694.674377 |
| 785283 |
520.521362 |
3118.752197 |
582.462891 |
632.206604 |
3600.458984 |
756.452515 |
224.326355 |
641.459534 |
696.739502 |
| 785284 |
520.870972 |
3144.105225 |
609.523743 |
617.985229 |
3561.716309 |
752.712402 |
225.764816 |
620.627991 |
680.224121 |
785285 rows × 9 columns
Out[5]:
|
pix1 |
pix2 |
pix3 |
pix4 |
pix5 |
pix6 |
pix7 |
pix8 |
pix9 |
| 0 |
0.054868 |
0.329321 |
0.106287 |
0.052276 |
0.233781 |
0.094484 |
0.013468 |
0.049541 |
0.065974 |
| 1 |
0.054337 |
0.323024 |
0.103697 |
0.053084 |
0.243342 |
0.097438 |
0.012783 |
0.048758 |
0.063537 |
| 2 |
0.052289 |
0.322207 |
0.102387 |
0.052889 |
0.242221 |
0.100150 |
0.013860 |
0.048615 |
0.065382 |
| 3 |
0.053828 |
0.314461 |
0.101097 |
0.053499 |
0.248813 |
0.096435 |
0.013903 |
0.049463 |
0.068500 |
| 4 |
0.051126 |
0.308546 |
0.097002 |
0.052105 |
0.255642 |
0.101267 |
0.014643 |
0.051422 |
0.068246 |
| 5 |
0.052693 |
0.303119 |
0.097787 |
0.055043 |
0.255123 |
0.103381 |
0.014740 |
0.049314 |
0.068799 |
| 6 |
0.051333 |
0.297631 |
0.098259 |
0.055920 |
0.260259 |
0.101716 |
0.015024 |
0.051170 |
0.068688 |
| 7 |
0.051502 |
0.287502 |
0.097271 |
0.055527 |
0.265011 |
0.105139 |
0.015716 |
0.050794 |
0.071538 |
| 8 |
0.050441 |
0.284439 |
0.095029 |
0.054221 |
0.271334 |
0.107632 |
0.014942 |
0.049770 |
0.072192 |
| 9 |
0.050299 |
0.286019 |
0.092485 |
0.054984 |
0.273205 |
0.105112 |
0.015575 |
0.051376 |
0.070944 |
| 10 |
0.051354 |
0.279453 |
0.094692 |
0.056664 |
0.275862 |
0.106164 |
0.014502 |
0.051826 |
0.069483 |
| 11 |
0.039645 |
0.103565 |
0.042581 |
0.156569 |
0.455867 |
0.030481 |
0.039882 |
0.077284 |
0.054126 |
| 12 |
0.036899 |
0.094558 |
0.043433 |
0.157831 |
0.461151 |
0.030372 |
0.040927 |
0.079794 |
0.055036 |
| 13 |
0.033545 |
0.086630 |
0.042721 |
0.161670 |
0.470731 |
0.030088 |
0.041352 |
0.078376 |
0.054887 |
| 14 |
0.034154 |
0.079609 |
0.042756 |
0.165402 |
0.471128 |
0.029954 |
0.043645 |
0.079173 |
0.054178 |
| 15 |
0.032986 |
0.078675 |
0.042420 |
0.169673 |
0.468588 |
0.028476 |
0.043243 |
0.081480 |
0.054458 |
| 16 |
0.030660 |
0.074720 |
0.043428 |
0.169475 |
0.470161 |
0.028312 |
0.043667 |
0.083785 |
0.055793 |
| 17 |
0.030600 |
0.074133 |
0.043239 |
0.173180 |
0.469202 |
0.028089 |
0.044217 |
0.083805 |
0.053535 |
| 18 |
0.029170 |
0.071559 |
0.041568 |
0.173640 |
0.474071 |
0.028006 |
0.044057 |
0.084660 |
0.053269 |
| 19 |
0.028507 |
0.068203 |
0.041011 |
0.174959 |
0.477166 |
0.027633 |
0.042434 |
0.084667 |
0.055421 |
| 20 |
0.028092 |
0.067038 |
0.041825 |
0.174144 |
0.478215 |
0.027815 |
0.044853 |
0.081729 |
0.056289 |
| 21 |
0.026480 |
0.065728 |
0.041898 |
0.176681 |
0.476205 |
0.028091 |
0.045497 |
0.084016 |
0.055405 |
| 22 |
0.025709 |
0.064353 |
0.041211 |
0.173335 |
0.482094 |
0.027833 |
0.045705 |
0.083944 |
0.055815 |
| 23 |
0.026257 |
0.063928 |
0.042064 |
0.170923 |
0.481243 |
0.027707 |
0.045257 |
0.084893 |
0.057728 |
| 24 |
0.025749 |
0.061199 |
0.041289 |
0.170683 |
0.483133 |
0.028937 |
0.046189 |
0.085356 |
0.057467 |
| 25 |
0.025487 |
0.062098 |
0.041461 |
0.168185 |
0.487587 |
0.027604 |
0.046240 |
0.084538 |
0.056800 |
| 26 |
0.054076 |
0.147736 |
0.041385 |
0.138568 |
0.417982 |
0.033124 |
0.038487 |
0.078736 |
0.049906 |
| 27 |
0.048900 |
0.136680 |
0.042803 |
0.143705 |
0.426419 |
0.033519 |
0.038007 |
0.076885 |
0.053081 |
| 28 |
0.049535 |
0.126447 |
0.043103 |
0.148685 |
0.433365 |
0.031066 |
0.037945 |
0.078955 |
0.050898 |
| 29 |
0.046052 |
0.120159 |
0.041427 |
0.154480 |
0.437299 |
0.031912 |
0.038443 |
0.079229 |
0.050999 |
| ... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
| 785255 |
0.044239 |
0.264681 |
0.055099 |
0.059001 |
0.357266 |
0.071853 |
0.022195 |
0.059356 |
0.066309 |
| 785256 |
0.044125 |
0.263130 |
0.054165 |
0.059454 |
0.359463 |
0.069613 |
0.023088 |
0.060436 |
0.066527 |
| 785257 |
0.044618 |
0.260880 |
0.053590 |
0.059055 |
0.359022 |
0.072408 |
0.022324 |
0.060203 |
0.067900 |
| 785258 |
0.043496 |
0.261929 |
0.053167 |
0.060495 |
0.360187 |
0.071671 |
0.022442 |
0.058737 |
0.067875 |
| 785259 |
0.044032 |
0.260779 |
0.052217 |
0.058916 |
0.362750 |
0.071926 |
0.021506 |
0.059631 |
0.068242 |
| 785260 |
0.044185 |
0.258607 |
0.053640 |
0.059585 |
0.363492 |
0.070872 |
0.021807 |
0.061240 |
0.066571 |
| 785261 |
0.044306 |
0.255635 |
0.052401 |
0.059093 |
0.367128 |
0.071790 |
0.022941 |
0.058761 |
0.067946 |
| 785262 |
0.044716 |
0.254128 |
0.052714 |
0.061350 |
0.366313 |
0.071204 |
0.022471 |
0.058877 |
0.068226 |
| 785263 |
0.043735 |
0.253021 |
0.052048 |
0.060996 |
0.370278 |
0.070674 |
0.022127 |
0.061097 |
0.066024 |
| 785264 |
0.042977 |
0.251533 |
0.052379 |
0.061338 |
0.370467 |
0.070841 |
0.022585 |
0.061118 |
0.066762 |
| 785265 |
0.044458 |
0.251456 |
0.051608 |
0.059767 |
0.373960 |
0.068759 |
0.023972 |
0.060584 |
0.065435 |
| 785266 |
0.043890 |
0.255661 |
0.049505 |
0.060533 |
0.372163 |
0.067475 |
0.023559 |
0.060753 |
0.066462 |
| 785267 |
0.046575 |
0.251592 |
0.051324 |
0.062732 |
0.368718 |
0.067622 |
0.024026 |
0.061787 |
0.065626 |
| 785268 |
0.044577 |
0.257012 |
0.052221 |
0.063061 |
0.369648 |
0.066149 |
0.022564 |
0.060743 |
0.064025 |
| 785269 |
0.044268 |
0.254033 |
0.051678 |
0.064370 |
0.369154 |
0.066421 |
0.024167 |
0.060261 |
0.065648 |
| 785270 |
0.048669 |
0.275139 |
0.050590 |
0.064896 |
0.351255 |
0.061972 |
0.024502 |
0.060611 |
0.062365 |
| 785271 |
0.046323 |
0.276102 |
0.051598 |
0.061403 |
0.350058 |
0.065761 |
0.023445 |
0.063879 |
0.061431 |
| 785272 |
0.047640 |
0.273556 |
0.053271 |
0.061664 |
0.350096 |
0.065440 |
0.023670 |
0.061596 |
0.063066 |
| 785273 |
0.046764 |
0.272163 |
0.051324 |
0.059536 |
0.354481 |
0.066837 |
0.023349 |
0.060065 |
0.065480 |
| 785274 |
0.047354 |
0.278096 |
0.051196 |
0.061315 |
0.347737 |
0.067058 |
0.023003 |
0.059665 |
0.064577 |
| 785275 |
0.048603 |
0.280657 |
0.051618 |
0.058849 |
0.345514 |
0.067805 |
0.022711 |
0.059878 |
0.064364 |
| 785276 |
0.046198 |
0.278569 |
0.054519 |
0.059578 |
0.346105 |
0.069822 |
0.021950 |
0.060244 |
0.063015 |
| 785277 |
0.047840 |
0.276019 |
0.054323 |
0.060920 |
0.345948 |
0.068025 |
0.021839 |
0.060849 |
0.064236 |
| 785278 |
0.046960 |
0.278458 |
0.054748 |
0.061308 |
0.344173 |
0.068620 |
0.021560 |
0.060682 |
0.063491 |
| 785279 |
0.048014 |
0.276697 |
0.054558 |
0.059273 |
0.345299 |
0.069937 |
0.021882 |
0.059249 |
0.065091 |
| 785280 |
0.047295 |
0.285146 |
0.055665 |
0.060464 |
0.336913 |
0.068236 |
0.021854 |
0.061402 |
0.063024 |
| 785281 |
0.046069 |
0.283838 |
0.054845 |
0.057951 |
0.342521 |
0.068038 |
0.021745 |
0.060809 |
0.064184 |
| 785282 |
0.047428 |
0.290018 |
0.054453 |
0.057945 |
0.336572 |
0.067724 |
0.021447 |
0.059859 |
0.064554 |
| 785283 |
0.048316 |
0.289487 |
0.054065 |
0.058682 |
0.334200 |
0.070215 |
0.020822 |
0.059541 |
0.064672 |
| 785284 |
0.048527 |
0.292924 |
0.056787 |
0.057575 |
0.331831 |
0.070127 |
0.021034 |
0.057821 |
0.063374 |
785285 rows × 9 columns
Out[10]:
Index(['xpos', 'xerr', 'ypos', 'yerr', 'xycov', 'flux', 'fluxerr', 'np',
'xfwhm', 'yfwhm', 'dn_peak', 'bmjd', 't_cernox', 'bg_flux',
'sigma_bg_flux', 'pix1', 'pix2', 'pix3', 'pix4', 'pix5', 'pix6', 'pix7',
'pix8', 'pix9'],
dtype='object')
Out[13]:
<matplotlib.legend.Legend at 0x10c5d4940>
[plt.plot(PLDpixels[key]) for key in PLDpixels.columns.values];
spitzerData = spitzerDataRaw.copy()
for key in spitzerDataRaw.columns:
if key in PLDpixels.columns:
spitzerData[key] = PLDpixels[key]
testPLD = np.array(pd.DataFrame({key:spitzerData[key] for key in spitzerData.columns.values if 'pix' in key}))
assert(not sum(abs(testPLD - np.array(PLDpixels))).all())
print('Confirmed that PLD Pixels have been Normalized to Spec')
notFeatures = ['flux', 'fluxerr', 'xerr', 'yerr', 'xycov', 't_cernox']
periodMax = spitzerData['bmjd'].values.max() - spitzerData['bmjd'].values.min()
periodMin = np.min(np.diff(spitzerData['bmjd'].values))
spitzerData['freq'] = np.linspace(np.pi/periodMax, 4*np.pi/periodMin, spitzerData['bmjd'].values.size)
feature_columns = spitzerData.drop(notFeatures,axis=1).columns.values
features = spitzerData.drop(notFeatures,axis=1).values
labels = spitzerData['flux'].values
stdScaler = StandardScaler()
features_scaled = stdScaler.fit_transform(features)
labels_scaled = labels#stdScaler.fit_transform(labels[:,None]).ravel()
x_valtest, x_train, y_valtest, y_train = train_test_split(features_scaled, labels_scaled, test_size=0.6, random_state=42)#, stratify=labels_scaled)
x_val, x_test, y_val, y_test = train_test_split(x_valtest, y_valtest, test_size=0.5, random_state=42)#, stratify=y_valtest)
print(x_val.shape , 'validation samples')
print(x_train.shape, 'train samples')
print(x_test.shape , 'test samples')
train_df = pd.DataFrame(np.c_[x_train, y_train], columns=list(feature_columns) + ['flux'])
test_df = pd.DataFrame(np.c_[x_test , y_test ], columns=list(feature_columns) + ['flux'])
evaluate_df = pd.DataFrame(np.c_[x_val , y_val ], columns=list(feature_columns) + ['flux'])
plt.scatter(train_df['xpos'].values, train_df['ypos'].values, c=train_df['flux'].values, alpha=0.1);
plt.colorbar();
We only take first 1000 rows for training/testing and last 500 row for evaluation.
This done so that this script does not consume a lot of kaggle system resources.
# train_df = df_train_ori.head(1000)
# evaluate_df = df_train_ori.tail(500)
# test_df = df_test_ori.head(1000)
# MODEL_DIR = "tf_model_spitzer/withNormalization_drop50/relu"
# MODEL_DIR = "tf_model_spitzer/adamOptimizer_with_drop50/relu"
MODEL_DIR = "tf_model_spitzer/adamOptimizer/drop50/elu/"
print("train_df.shape = " , train_df.shape)
print("test_df.shape = " , test_df.shape)
print("evaluate_df.shape = ", evaluate_df.shape)
## Filtering Categorical and Continuous features
We store Categorical, Continuous and Target features names in different variables. This will be helpful in later steps.
# categorical_features = [feature for feature in features if 'cat' in feature]
categorical_features = []
continuous_features = [feature for feature in train_df.columns]# if 'cat' in feature]
LABEL_COLUMN = 'flux'
## Converting Data into Tensors
> When building a TF.Learn model, the input data is specified by means of an Input Builder function. This builder function will not be called until it is later passed to TF.Learn methods such as fit and evaluate. The purpose of this function is to construct the input data, which is represented in the form of Tensors or SparseTensors.
> Note that input_fn will be called while constructing the TensorFlow graph, not while running the graph. What it is returning is a representation of the input data as the fundamental unit of TensorFlow computations, a Tensor (or SparseTensor).
[More detail][2] on input_fn.
[2]: https://www.tensorflow.org/versions/r0.11/tutorials/input_fn/index.html#building-input-functions-with-tf-contrib-learn
# Converting Data into Tensors
def input_fn(df, training = True):
# Creates a dictionary mapping from each continuous feature column name (k) to
# the values of that column stored in a constant Tensor.
continuous_cols = {k: tf.constant(df[k].values)
for k in continuous_features}
# Creates a dictionary mapping from each categorical feature column name (k)
# to the values of that column stored in a tf.SparseTensor.
# categorical_cols = {k: tf.SparseTensor(
# indices=[[i, 0] for i in range(df[k].size)],
# values=df[k].values,
# shape=[df[k].size, 1])
# for k in categorical_features}
# Merges the two dictionaries into one.
feature_cols = continuous_cols
# feature_cols = dict(list(continuous_cols.items()) + list(categorical_cols.items()))
if training:
# Converts the label column into a constant Tensor.
label = tf.constant(df[LABEL_COLUMN].values)
# Returns the feature columns and the label.
return feature_cols, label
# Returns the feature columns
return feature_cols
def train_input_fn():
return input_fn(train_df, training=True)
def eval_input_fn():
return input_fn(evaluate_df, training=True)
def test_input_fn():
return input_fn(test_df, training=True)
## Selecting and Engineering Features for the Model
We use tf.learn's concept of [FeatureColumn][FeatureColumn] which help in transforming raw data into suitable input features.
These engineered features will be used when we construct our model.
[FeatureColumn]: https://www.tensorflow.org/versions/r0.11/tutorials/linear/overview.html#feature-columns-and-transformations
engineered_features = []
for continuous_feature in continuous_features:
engineered_features.append(
tf.contrib.layers.real_valued_column(continuous_feature))
# for categorical_feature in categorical_features:
# sparse_column = tf.contrib.layers.sparse_column_with_hash_bucket(
# categorical_feature, hash_bucket_size=1000)
# engineered_features.append(tf.contrib.layers.embedding_column(sparse_id_column=sparse_column, dimension=16,
# combiner="sum"))
## Defining The Regression Model
Following is the simple DNNRegressor model. More detail about hidden_units, etc can be found [here][123].
**model_dir** is used to save and restore our model. This is because once we have trained the model we don't want to train it again, if we only want to predict on new data-set.
[123]: https://www.tensorflow.org/versions/r0.9/api_docs/python/contrib.learn.html#DNNRegressor
nHidden1 = train_df.shape[1]
nHidden2 = train_df.shape[1]
nHidden3 = train_df.shape[1]
regressor = tf.contrib.learn.DNNRegressor(activation_fn=tf.nn.elu, dropout=0.5, optimizer=tf.train.AdamOptimizer,
feature_columns=engineered_features, hidden_units=[nHidden1, nHidden2, nHidden3], model_dir=MODEL_DIR)
## Training and Evaluating Our Model
add progress bar through python `logging`
import logging
logging.getLogger().setLevel(logging.INFO)
# Training Our Model
nFitSteps = 100000
start = time()
wrap = regressor.fit(input_fn=train_input_fn, steps=nFitSteps)
print('TF Regressor took {} seconds'.format(time()-start))
# Evaluating Our Model
print('Evaluating ...')
results = regressor.evaluate(input_fn=eval_input_fn, steps=1)
for key in sorted(results):
print("{}: {}".format(key, results[key]))
print("Val Acc: {:.3f}".format((1-results['loss'])*100))
# Evaluating Our Model
print('Evaluating ...')
results = regressor.evaluate(input_fn=test_input_fn, steps=1)
for key in sorted(results):
print("{}: {}".format(key, results[key]))
print("Val Acc: {:.3f}".format((1-results['loss'])*100))
**Track Scalable Growth**
Shrunk data set to 23559 Training samples and 7853 Val/Test samples
| n_iters | time (s) | val acc | multicore | gpu |
|------------------------------------------------|
| 100 | 5.869 | 6.332 | yes | no |
| 200 | 6.380 | 13.178 | yes | no |
| 500 | 8.656 | 54.220 | yes | no |
| 1000 | 12.170 | 66.596 | yes | no |
| 2000 | 19.891 | 62.996 | yes | no |
| 5000 | 43.589 | 76.586 | yes | no |
| 10000 | 80.581 | 66.872 | yes | no |
| 20000 | 162.435 | 78.927 | yes | no |
| 50000 | 535.584 | 75.493 | yes | no |
| 100000 | 1062.656 | 73.162 | yes | no |
nItersList = [100,200,500,1000,2000,5000,10000,20000,50000,100000]
rtimesList = [5.869, 6.380, 8.656, 12.170, 19.891, 43.589, 80.581, 162.435, 535.584, 1062.656]
valAccList = [6.332, 13.178, 54.220, 66.596, 62.996, 76.586, 66.872, 78.927, 75.493, 73.162]
plt.loglog(nItersList, rtimesList,'o-');
plt.twinx()
plt.semilogx(nItersList, valAccList,'o-', color='orange');
Predicting output for test data
Most of the time prediction script would be separate from training script (we need not to train on same data again) but I am providing both in same script here; as I am not sure if we can create multiple notebook and somehow share data between them in Kaggle.
saveDir = 'tfSaveModels'
regressor.export_savedmodel(saveDir, regressor)
saveDir = 'tfSaveModels'
reg_args = {'feature_columns': fc, 'hidden_units': hu_array, ...}
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
pickle.dump(reg_args, open('reg_args.pkl', 'wb'))
reg_args = pickle.load(open('reg_args.pkl', 'rb'))
# On another machine and so my model dir path changed:
reg_args['model_dir'] = NEW_MODEL_DIR
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
predicted_scores = list(regressor.predict_scores(input_fn=test_input_fn))
# x = list(predicted_output)