Import modules


In [1]:
from sklearn.datasets import load_boston
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import scale
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error
from sklearn.cross_validation import KFold
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
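
The cross_validation imports above target an older scikit-learn; in 0.18+ those utilities moved to model_selection (and KFold's call signature changed). A sketch of the newer equivalents, for reference only:

# Equivalent imports on scikit-learn >= 0.18:
from sklearn.model_selection import train_test_split, KFold
# the newer KFold is constructed as KFold(n_splits=10, shuffle=True)
# and yields (train_index, val_index) pairs via .split(data)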

Load data

For this exercise, we will be using a dataset of housing prices in Boston during the 1970s. Python's super-awesome sklearn package already has the data we need to get started. Below is the command to load the data. The data is stored as a dictionary.

The 'DESCR' entry is a description of the data, and the command for printing it is below. Note all the features we have to work with. From the dictionary, we need the data and the target variable (in this case, housing price). Store these as variables named "data" and "price", respectively. Once you have these, print their shapes to check that everything matches the DESCR.


In [2]:
boston = load_boston()
print boston.DESCR


Boston House Prices dataset

Notes
------
Data Set Characteristics:  

    :Number of Instances: 506 

    :Number of Attributes: 13 numeric/categorical predictive
    
    :Median Value (attribute 14) is usually the target

    :Attribute Information (in order):
        - CRIM     per capita crime rate by town
        - ZN       proportion of residential land zoned for lots over 25,000 sq.ft.
        - INDUS    proportion of non-retail business acres per town
        - CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)
        - NOX      nitric oxides concentration (parts per 10 million)
        - RM       average number of rooms per dwelling
        - AGE      proportion of owner-occupied units built prior to 1940
        - DIS      weighted distances to five Boston employment centres
        - RAD      index of accessibility to radial highways
        - TAX      full-value property-tax rate per $10,000
        - PTRATIO  pupil-teacher ratio by town
        - B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town
        - LSTAT    % lower status of the population
        - MEDV     Median value of owner-occupied homes in $1000's

    :Missing Attribute Values: None

    :Creator: Harrison, D. and Rubinfeld, D.L.

This is a copy of UCI ML housing dataset.
http://archive.ics.uci.edu/ml/datasets/Housing


This dataset was taken from the StatLib library which is maintained at Carnegie Mellon University.

The Boston house-price data of Harrison, D. and Rubinfeld, D.L. 'Hedonic
prices and the demand for clean air', J. Environ. Economics & Management,
vol.5, 81-102, 1978.   Used in Belsley, Kuh & Welsch, 'Regression diagnostics
...', Wiley, 1980.   N.B. Various transformations are used in the table on
pages 244-261 of the latter.

The Boston house-price data has been used in many machine learning papers that address regression
problems.   
     
**References**

   - Belsley, Kuh & Welsch, 'Regression diagnostics: Identifying Influential Data and Sources of Collinearity', Wiley, 1980. 244-261.
   - Quinlan,R. (1993). Combining Instance-Based and Model-Based Learning. In Proceedings on the Tenth International Conference of Machine Learning, 236-243, University of Massachusetts, Amherst. Morgan Kaufmann.
   - many more! (see http://archive.ics.uci.edu/ml/datasets/Housing)


In [3]:
data = boston.data
price = boston.target

In [4]:
print data.shape
print price.shape


(506, 13)
(506,)

Train-Test split

Now, using sklearn's train_test_split (see the sklearn docs for more; I've already imported it for you), let's make a random train-test split with the test size equal to 30% of our data (i.e. set the test_size parameter to 0.3). For consistency, let's also set the random_state parameter to 11.

Name the variables data_train, price_train for the training data and data_test, price_test for the test data. As a sanity check, let's also print the shapes of these variables.


In [6]:
data_train, data_test, price_train, price_test = train_test_split(data, price, test_size=0.30, random_state=11)

In [7]:
print data_train.shape
print data_test.shape
print price_train.shape
print price_test.shape


(354, 13)
(152, 13)
(354,)
(152,)

Scale our data

Before we get too far ahead, let's scale our data: for each column (feature), subtract that column's min and divide by the difference between its max and min, i.e. x_scaled = (x - min) / (max - min).

Here's where things can get tricky. Remember, our test data is unseen, yet we still need to scale it. We cannot scale it using its own min/max, because unseen data might not be available to us en masse. Instead, we use the training min/max to scale the test data.

Be sure to check which axis you use to take the mins/maxes!

Let's add a "_stand" suffix to our train/test variable names for the standardized values.


In [8]:
mins = np.min(data_train, axis = 0)
maxes = np.max(data_train, axis = 0)
diff = maxes - mins

In [9]:
diff


Out[9]:
array([  8.89698800e+01,   9.50000000e+01,   2.72800000e+01,
         1.00000000e+00,   4.86000000e-01,   4.86200000e+00,
         9.38000000e+01,   1.09969000e+01,   2.30000000e+01,
         5.24000000e+02,   9.40000000e+00,   3.96580000e+02,
         3.62400000e+01])

In [10]:
data_train_stand = (data_train - mins) / diff
data_test_stand = (data_test - mins) / diff

In [11]:
minPrice = np.min(price_train)
maxPrice = np.max(price_train)
diffPrice = maxPrice - minPrice

In [12]:
price_train_stand = (price_train - minPrice) / diffPrice
price_test_stand = (price_test - minPrice) / diffPrice
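
For reference, sklearn's MinMaxScaler performs this same min/max scaling of the features. A minimal sketch of the equivalent (MinMaxScaler isn't imported above, and the _mm names are just for this illustration):

from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()                               # learns each column's min/max
data_train_mm = scaler.fit_transform(data_train)      # fit on the training data only
data_test_mm = scaler.transform(data_test)            # reuse the training min/max on the test data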

K-Fold CV

Now, here's where things might get really messy. Let's implement 10-fold cross-validation for K-NN across a range of K values (given below - 9 total). We'll keep the number of folds constant at 10 (not to be confused with the K in K-NN).

Let's measure accuracy with an RMSE (root-mean-squared error) value, with K-NN using its default Euclidean distance. Save the errors for each fold at each K value (10 folds x 9 K values = 90 values) as you loop through.

Take a look at sklearn's K-fold CV. Also, sklearn has its own K-NN implementation. There is also an implementation of mean squared error, though you'll have to take the root yourself. I've imported these for you already. :)
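
Since mean_squared_error gives the MSE, the RMSE inside each fold is just its square root. A minimal sketch of that computation (the rmse helper below is only for illustration; the loop later calls np.sqrt directly):

def rmse(y_true, y_pred):
    # root-mean-squared error: square root of the average squared residual
    return np.sqrt(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))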


In [12]:
kValues = [1, 2, 3, 4, 5, 10, 20, 40, 80]

In [13]:
folds = KFold(len(data_train_stand), n_folds = 10, shuffle = True)  # 10 shuffled folds over the training rows

In [14]:
for train_index, val_index in folds:
    print train_index


[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  36  37
  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55
  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73
  74  75  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92
  93  94  95  96  97  98  99 100 101 102 103 105 106 107 108 109 110 111
 112 113 114 115 116 117 118 119 120 122 123 124 125 127 128 129 130 131
 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
 150 151 153 154 155 156 157 158 159 160 162 164 165 167 168 169 170 171
 172 173 174 175 176 177 178 179 180 182 183 184 185 186 187 188 189 190
 191 192 193 194 196 197 199 200 201 202 203 205 206 208 209 210 212 214
 216 217 218 219 220 221 222 224 225 226 227 228 229 230 231 232 234 237
 239 240 241 242 244 245 246 247 248 249 251 252 253 254 255 256 257 258
 259 260 261 262 263 264 265 266 267 268 270 271 272 274 275 276 277 278
 279 280 281 282 283 285 288 289 290 291 292 293 294 295 296 297 298 299
 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
 318 319 320 321 322 323 324 325 327 328 329 330 331 332 333 334 335 338
 339 341 342 343 344 346 347 348 349 351 352 353]
[  0   2   3   4   6   7   8   9  10  11  12  13  16  17  18  20  21  22
  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41
  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59
  60  61  62  63  64  65  66  68  70  71  72  73  74  75  76  77  78  79
  80  81  82  83  84  86  87  88  89  90  91  93  94  95  96  97  98  99
 100 101 102 103 104 105 107 108 109 110 111 112 113 114 115 116 117 118
 120 121 123 124 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 142 143 144 145 146 148 151 152 153 154 155 156 157 158 159 160
 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
 197 198 199 200 201 202 204 205 206 207 208 209 210 211 212 213 214 215
 216 217 218 220 222 223 224 226 227 228 229 231 232 233 235 236 237 238
 239 240 241 242 243 244 245 249 250 251 252 253 254 255 257 258 259 260
 261 262 263 266 267 268 269 270 271 272 273 274 275 276 279 280 281 282
 283 284 285 286 287 288 289 290 291 292 293 294 295 297 298 299 300 301
 302 303 304 305 307 308 309 310 311 312 313 314 315 316 318 319 320 321
 322 323 324 326 327 329 330 331 332 333 334 335 336 337 338 339 340 341
 342 343 344 345 346 347 348 349 350 351 352 353]
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  25  26  27  29  30  31  32  33  34  35  36  37
  38  39  40  41  42  43  44  46  47  48  49  50  51  52  53  55  56  57
  58  59  60  61  62  63  64  65  67  69  71  72  73  74  76  77  78  79
  81  82  83  84  85  87  89  90  91  92  94  95  96  97  98  99 100 101
 102 103 104 105 106 107 108 109 111 112 114 115 116 117 118 119 120 121
 122 124 125 126 127 129 130 131 133 134 135 136 138 139 140 141 142 143
 144 145 147 148 149 150 151 152 154 155 156 157 158 159 160 161 162 163
 164 166 167 168 169 170 171 172 173 174 175 176 178 179 180 181 182 183
 184 185 186 187 188 190 191 192 193 194 195 196 197 198 199 200 201 202
 203 204 205 206 207 208 210 211 212 213 214 215 216 217 218 219 220 221
 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
 240 243 245 246 247 248 249 250 251 252 253 255 256 257 258 259 260 261
 262 263 264 265 266 267 268 269 271 272 273 274 275 277 278 279 280 281
 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
 300 301 302 303 304 306 307 308 311 312 313 314 316 317 318 319 320 321
 322 323 324 325 326 327 328 329 330 331 332 333 335 336 337 338 339 340
 341 342 343 345 346 347 348 349 350 351 352 353]
[  0   1   2   3   4   5   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  28  29  30  31  32  33  34  35  37  38
  39  40  41  44  45  46  47  48  49  50  51  52  53  54  56  57  58  59
  61  63  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81
  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99
 100 101 102 103 104 105 106 107 108 110 111 112 113 114 115 116 117 118
 119 120 121 122 123 124 125 126 128 129 130 131 132 133 135 136 137 138
 140 141 142 143 144 146 147 148 149 150 151 152 153 154 155 156 157 159
 161 162 163 164 165 166 167 168 169 170 171 172 173 174 176 177 178 180
 181 183 185 186 187 188 189 190 191 192 193 194 195 196 197 198 201 202
 203 204 205 206 207 209 210 211 212 213 214 215 217 218 219 221 222 223
 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
 243 244 245 246 247 248 249 250 251 252 253 254 255 256 258 259 260 262
 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
 299 300 301 302 304 305 306 308 309 310 311 313 314 315 316 317 318 319
 321 322 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
 340 341 342 343 344 345 347 348 349 350 351 353]
[  0   1   2   3   4   5   6   8  10  11  12  13  14  15  16  17  18  19
  22  23  24  25  26  27  28  29  30  32  33  34  35  36  37  38  39  40
  41  42  43  44  45  46  47  48  50  51  52  53  54  55  56  57  59  60
  61  62  63  64  65  66  67  68  69  70  71  72  74  75  76  77  79  80
  81  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99
 100 101 102 103 104 105 106 107 108 109 110 111 112 113 115 116 117 118
 119 121 122 123 124 125 126 127 128 129 130 131 132 134 135 136 137 138
 139 141 142 143 144 145 146 147 148 149 150 152 153 154 155 156 157 158
 159 160 161 163 164 165 166 167 168 169 170 171 172 174 175 176 177 178
 179 180 181 182 183 184 188 189 191 192 194 195 196 197 198 199 200 201
 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 218 219 220
 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256
 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
 275 276 277 278 281 282 283 284 286 287 288 289 290 292 294 295 296 297
 298 299 300 301 302 303 305 306 307 309 310 311 312 314 315 316 317 318
 319 320 321 323 324 325 326 327 328 329 331 332 333 334 335 336 337 338
 339 340 341 342 343 344 345 346 347 349 350 351 352]
[  0   1   2   3   5   6   7   8   9  10  11  13  14  15  16  18  19  20
  21  22  23  24  25  26  27  28  29  31  32  33  34  35  36  37  38  39
  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57
  58  59  60  61  62  63  64  65  66  67  68  69  70  71  73  75  76  78
  79  80  81  82  83  84  85  86  87  88  89  91  92  93  94  95  96  97
  98 100 103 104 105 106 107 109 110 111 112 113 114 115 116 118 119 120
 121 122 123 124 125 126 127 128 129 130 131 132 133 134 136 137 138 139
 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
 158 159 160 161 162 163 164 165 166 167 168 170 171 172 173 174 175 176
 177 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
 214 215 216 217 218 219 220 221 223 224 225 226 227 228 230 233 234 235
 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 252 253 254
 255 256 257 258 259 261 263 264 265 269 270 271 272 273 274 275 276 277
 278 279 280 281 283 284 285 286 287 288 289 290 291 292 293 294 295 296
 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 315
 316 317 318 319 320 322 323 325 326 328 330 331 332 334 336 337 338 339
 340 341 342 344 345 346 347 348 349 350 351 352 353]
[  1   3   4   5   6   7   8   9  11  12  13  14  15  17  18  19  20  21
  22  23  24  25  27  28  30  31  32  33  34  35  36  37  38  39  41  42
  43  44  45  46  47  48  49  51  52  53  54  55  56  57  58  59  60  61
  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  80
  81  82  84  85  86  87  88  89  90  91  92  93  94  97  99 100 101 102
 103 104 105 106 107 108 109 110 111 112 113 114 115 117 118 119 120 121
 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
 140 141 144 145 146 147 148 149 150 151 152 153 154 155 157 158 160 161
 162 163 164 165 166 167 168 169 170 171 173 175 176 177 178 179 180 181
 182 183 184 185 186 187 188 189 190 191 193 194 195 196 197 198 199 200
 201 202 203 204 207 208 209 211 212 213 214 215 216 217 218 219 220 221
 222 223 224 225 227 228 229 230 231 232 233 234 235 236 237 238 239 240
 241 242 243 244 245 246 247 248 249 250 251 252 253 254 256 257 258 259
 260 261 262 263 264 265 266 267 268 269 270 273 275 276 277 278 279 280
 281 282 284 285 286 287 289 290 291 293 294 296 297 298 300 301 302 303
 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
 322 323 324 325 326 327 328 329 330 332 333 334 335 336 337 338 339 340
 341 342 343 344 345 346 347 348 349 350 351 352 353]
[  0   1   2   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  33  34  35  36  37
  38  39  40  41  42  43  44  45  48  49  50  54  55  56  58  60  61  62
  63  64  65  66  67  68  69  70  72  73  74  75  76  77  78  79  80  81
  82  83  84  85  86  87  88  89  90  92  93  94  95  96  97  98  99 100
 101 102 104 105 106 108 109 110 111 113 114 116 117 118 119 120 121 122
 123 124 125 126 127 128 129 130 131 132 133 134 135 137 138 139 140 142
 143 144 145 146 147 149 150 151 152 153 154 155 156 157 158 159 160 161
 162 163 165 166 167 168 169 172 173 174 175 176 177 178 179 180 181 182
 183 184 185 186 187 188 189 190 192 193 194 195 196 197 198 199 200 201
 202 203 204 205 206 207 208 209 210 211 212 213 215 216 217 218 219 220
 221 222 223 224 225 226 229 230 231 232 233 234 235 236 237 238 239 240
 241 242 243 244 246 247 248 249 250 251 252 253 254 255 256 257 258 259
 260 261 262 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
 279 280 282 283 284 285 286 287 288 291 292 293 295 296 297 299 300 301
 303 304 305 306 307 308 309 310 311 312 313 314 315 317 318 319 320 321
 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
 340 341 342 343 344 345 346 348 349 350 351 352 353]
[  0   1   2   3   4   5   6   7   8   9  10  11  12  14  15  16  17  18
  19  20  21  23  24  25  26  27  28  29  30  31  32  34  35  36  38  40
  41  42  43  45  46  47  49  50  51  52  53  54  55  57  58  59  60  62
  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80
  81  82  83  84  85  86  88  90  91  92  93  95  96  98  99 100 101 102
 103 104 105 106 107 108 109 110 112 113 114 115 116 117 118 119 120 121
 122 123 124 125 126 127 128 132 133 134 135 136 137 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 155 156 157 158 159 160 161 162
 163 164 165 166 169 170 171 172 173 174 175 176 177 178 179 181 182 183
 184 185 186 187 188 189 190 191 192 193 195 197 198 199 200 202 203 204
 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
 241 242 243 244 245 246 247 248 250 251 252 254 255 256 257 258 260 261
 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
 298 299 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316
 317 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
 336 337 339 340 343 344 345 346 347 348 350 352 353]
[  0   1   2   3   4   5   6   7   9  10  12  13  14  15  16  17  19  20
  21  22  23  24  26  27  28  29  30  31  32  33  34  35  36  37  39  40
  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59
  60  61  62  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78
  79  80  82  83  85  86  87  88  89  90  91  92  93  94  95  96  97  98
  99 101 102 103 104 106 107 108 109 110 111 112 113 114 115 116 117 119
 120 121 122 123 125 126 127 128 129 130 131 132 133 134 135 136 137 138
 139 140 141 142 143 145 146 147 148 149 150 151 152 153 154 156 158 159
 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 177 178
 179 180 181 182 184 185 186 187 189 190 191 192 193 194 195 196 198 199
 200 201 203 204 205 206 207 208 209 210 211 213 214 215 216 217 219 220
 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 238 241
 242 243 244 245 246 247 248 249 250 251 253 254 255 256 257 259 260 261
 262 263 264 265 266 267 268 269 270 271 272 273 274 276 277 278 279 280
 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 298 299
 300 302 303 304 305 306 307 308 309 310 312 313 314 315 316 317 318 320
 321 322 323 324 325 326 327 328 329 330 331 333 334 335 336 337 338 340
 341 342 343 344 345 346 347 348 349 350 351 352 353]

In [15]:
scores = {}
for k in kValues:
    currentScores = []
    for train_index, val_index in folds:
        # split the scaled training set into this fold's train/validation pieces
        current_train_data, current_val_data = data_train_stand[train_index], data_train_stand[val_index]
        current_train_price, current_val_price = price_train_stand[train_index], price_train_stand[val_index]
        # fit K-NN with k neighbors on the fold's training piece
        neigh = KNeighborsRegressor(n_neighbors = k)
        neigh.fit(current_train_data, current_train_price)
        # predict on the held-out fold and record the RMSE
        guesses = neigh.predict(current_val_data)
        rmse = np.sqrt(mean_squared_error(current_val_price, guesses))
        currentScores.append(rmse)
    scores[k] = currentScores
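
With all 90 fold errors saved, here's a quick way to pull out the K with the lowest mean RMSE across folds (just an illustrative snippet; the plots below look at the full curve):

mean_rmse = {k: np.mean(v) for k, v in scores.items()}   # average the 10 fold RMSEs per K
best_k = min(mean_rmse, key=mean_rmse.get)               # K with the smallest mean error
print best_k, mean_rmse[best_k]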

Plot Results

Plot the mean validation RMSE across all folds as a function of K. What do you see?


In [16]:
keys = sorted(scores.keys())
means = []
stdevs = []
for each in keys:
    current = scores[each]
    means.append(np.mean(current))    # mean RMSE across the 10 folds for this K
    stdevs.append(np.std(current))    # spread of the fold RMSEs for this K

In [17]:
figure = plt.figure()
plt.plot(keys, means, 'bo:')
plt.xlabel('K')
plt.ylabel('Normalized RMSE')
plt.title('Error as a function of # of neighbors')


Out[17]:
<matplotlib.text.Text at 0x10d7d5e10>

In [18]:
figure = plt.figure()
plt.plot(keys, means, 'bo:')
plt.xlabel('K')
plt.ylabel('Normalized RMSE')
plt.title('Error as a function of # of neighbors')
plt.xlim([0, 10]);
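
We computed stdevs above but haven't plotted them; a sketch of the same plot with error bars showing the spread of the fold RMSEs at each K:

figure = plt.figure()
plt.errorbar(keys, means, yerr=stdevs, fmt='bo:')   # error bars = std dev across the 10 folds
plt.xlabel('K')
plt.ylabel('Normalized RMSE')
plt.title('Error as a function of # of neighbors')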