In [1]:
from theano.sandbox import cuda


WARNING (theano.sandbox.cuda): The cuda backend is deprecated and will be removed in the next release (v0.10).  Please switch to the gpuarray backend. You can get more information about how to switch at this URL:
 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29

Using gpu device 0: Tesla K80 (CNMeM is disabled, cuDNN 5103)

In [2]:
import utils; reload(utils)
from utils import *


Using Theano backend.

In [3]:
ratings = pd.read_csv('data/ml-latest-small/ratings.csv')

In [4]:
ratings


Out[4]:
userId movieId rating timestamp
0 1 31 2.5 1260759144
1 1 1029 3.0 1260759179
2 1 1061 3.0 1260759182
3 1 1129 2.0 1260759185
4 1 1172 4.0 1260759205
5 1 1263 2.0 1260759151
6 1 1287 2.0 1260759187
7 1 1293 2.0 1260759148
8 1 1339 3.5 1260759125
9 1 1343 2.0 1260759131
10 1 1371 2.5 1260759135
11 1 1405 1.0 1260759203
12 1 1953 4.0 1260759191
13 1 2105 4.0 1260759139
14 1 2150 3.0 1260759194
15 1 2193 2.0 1260759198
16 1 2294 2.0 1260759108
17 1 2455 2.5 1260759113
18 1 2968 1.0 1260759200
19 1 3671 3.0 1260759117
20 2 10 4.0 835355493
21 2 17 5.0 835355681
22 2 39 5.0 835355604
23 2 47 4.0 835355552
24 2 50 4.0 835355586
25 2 52 3.0 835356031
26 2 62 3.0 835355749
27 2 110 4.0 835355532
28 2 144 3.0 835356016
29 2 150 5.0 835355395
... ... ... ... ...
99974 671 4034 4.5 1064245493
99975 671 4306 5.0 1064245548
99976 671 4308 3.5 1065111985
99977 671 4880 4.0 1065111973
99978 671 4886 5.0 1064245488
99979 671 4896 5.0 1065111996
99980 671 4963 4.5 1065111855
99981 671 4973 4.5 1064245471
99982 671 4993 5.0 1064245483
99983 671 4995 4.0 1064891537
99984 671 5010 2.0 1066793004
99985 671 5218 2.0 1065111990
99986 671 5299 3.0 1065112004
99987 671 5349 4.0 1065111863
99988 671 5377 4.0 1064245557
99989 671 5445 4.5 1064891627
99990 671 5464 3.0 1064891549
99991 671 5669 4.0 1063502711
99992 671 5816 4.0 1065111963
99993 671 5902 3.5 1064245507
99994 671 5952 5.0 1063502716
99995 671 5989 4.0 1064890625
99996 671 5991 4.5 1064245387
99997 671 5995 4.0 1066793014
99998 671 6212 2.5 1065149436
99999 671 6268 2.5 1065579370
100000 671 6269 4.0 1065149201
100001 671 6365 4.0 1070940363
100002 671 6385 2.5 1070979663
100003 671 6565 3.5 1074784724

100004 rows × 4 columns


In [5]:
movie_names = pd.read_csv('data/ml-latest-small/movies.csv')

In [6]:
movie_names


Out[6]:
movieId title genres
0 1 Toy Story (1995) Adventure|Animation|Children|Comedy|Fantasy
1 2 Jumanji (1995) Adventure|Children|Fantasy
2 3 Grumpier Old Men (1995) Comedy|Romance
3 4 Waiting to Exhale (1995) Comedy|Drama|Romance
4 5 Father of the Bride Part II (1995) Comedy
5 6 Heat (1995) Action|Crime|Thriller
6 7 Sabrina (1995) Comedy|Romance
7 8 Tom and Huck (1995) Adventure|Children
8 9 Sudden Death (1995) Action
9 10 GoldenEye (1995) Action|Adventure|Thriller
10 11 American President, The (1995) Comedy|Drama|Romance
11 12 Dracula: Dead and Loving It (1995) Comedy|Horror
12 13 Balto (1995) Adventure|Animation|Children
13 14 Nixon (1995) Drama
14 15 Cutthroat Island (1995) Action|Adventure|Romance
15 16 Casino (1995) Crime|Drama
16 17 Sense and Sensibility (1995) Drama|Romance
17 18 Four Rooms (1995) Comedy
18 19 Ace Ventura: When Nature Calls (1995) Comedy
19 20 Money Train (1995) Action|Comedy|Crime|Drama|Thriller
20 21 Get Shorty (1995) Comedy|Crime|Thriller
21 22 Copycat (1995) Crime|Drama|Horror|Mystery|Thriller
22 23 Assassins (1995) Action|Crime|Thriller
23 24 Powder (1995) Drama|Sci-Fi
24 25 Leaving Las Vegas (1995) Drama|Romance
25 26 Othello (1995) Drama
26 27 Now and Then (1995) Children|Drama
27 28 Persuasion (1995) Drama|Romance
28 29 City of Lost Children, The (Cité des enfants p... Adventure|Drama|Fantasy|Mystery|Sci-Fi
29 30 Shanghai Triad (Yao a yao yao dao waipo qiao) ... Crime|Drama
... ... ... ...
9095 159690 Teenage Mutant Ninja Turtles: Out of the Shado... Action|Adventure|Comedy
9096 159755 Popstar: Never Stop Never Stopping (2016) Comedy
9097 159858 The Conjuring 2 (2016) Horror
9098 159972 Approaching the Unknown (2016) Drama|Sci-Fi|Thriller
9099 160080 Ghostbusters (2016) Action|Comedy|Horror|Sci-Fi
9100 160271 Central Intelligence (2016) Action|Comedy
9101 160438 Jason Bourne (2016) Action
9102 160440 The Maid's Room (2014) Thriller
9103 160563 The Legend of Tarzan (2016) Action|Adventure
9104 160565 The Purge: Election Year (2016) Action|Horror|Sci-Fi
9105 160567 Mike & Dave Need Wedding Dates (2016) Comedy
9106 160590 Survive and Advance (2013) (no genres listed)
9107 160656 Tallulah (2016) Drama
9108 160718 Piper (2016) Animation
9109 160954 Nerve (2016) Drama|Thriller
9110 161084 My Friend Rockefeller (2015) Documentary
9111 161155 Sunspring (2016) Sci-Fi
9112 161336 Author: The JT LeRoy Story (2016) Documentary
9113 161582 Hell or High Water (2016) Crime|Drama
9114 161594 Kingsglaive: Final Fantasy XV (2016) Action|Adventure|Animation|Drama|Fantasy|Sci-Fi
9115 161830 Body (2015) Drama|Horror|Thriller
9116 161918 Sharknado 4: The 4th Awakens (2016) Action|Adventure|Horror|Sci-Fi
9117 161944 The Last Brickmaker in America (2001) Drama
9118 162376 Stranger Things Drama
9119 162542 Rustom (2016) Romance|Thriller
9120 162672 Mohenjo Daro (2016) Adventure|Drama|Romance
9121 163056 Shin Godzilla (2016) Action|Adventure|Fantasy|Sci-Fi
9122 163949 The Beatles: Eight Days a Week - The Touring Y... Documentary
9123 164977 The Gay Desperado (1936) Comedy
9124 164979 Women of '69, Unboxed Documentary

9125 rows × 3 columns


In [7]:
# remap the user and movie ids to contiguous zero-based integers
users = ratings.userId.unique()
movies = ratings.movieId.unique()

In [8]:
userid2idx = {o:i for i,o in enumerate(users)}

In [9]:
userid2idx


Out[9]:
{1: 0,
 2: 1,
 3: 2,
 4: 3,
 5: 4,
 ...
 668: 667,
 669: 668,
 670: 669,
 671: 670}

In [10]:
movieid2idx = {o:i for i,o in enumerate(movies)}
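
As a quick sanity check (a sketch, not executed in the original run), the mapping can be inverted to get back from embedding indices to original ids and titles:

# invert the mapping: embedding index -> original movieId
idx2movieid = {i: o for o, i in movieid2idx.items()}
idx2movieid[0]  # 31, the first movieId seen in ratings
movie_names.set_index('movieId').loc[idx2movieid[0], 'title']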

In [11]:
ratings.movieId = ratings.movieId.apply(lambda x: movieid2idx[x])
ratings.userId = ratings.userId.apply(lambda x: userid2idx[x])
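
The same remapping can be written with pandas' vectorized Series.map instead of a Python-level lambda. An equivalent sketch, to be used in place of (not after) the apply() calls above, since the mapping is not idempotent:

# vectorized alternative to the apply() calls above
ratings.movieId = ratings.movieId.map(movieid2idx)
ratings.userId = ratings.userId.map(userid2idx)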

In [12]:
ratings


Out[12]:
userId movieId rating timestamp
0 0 0 2.5 1260759144
1 0 1 3.0 1260759179
2 0 2 3.0 1260759182
3 0 3 2.0 1260759185
4 0 4 4.0 1260759205
5 0 5 2.0 1260759151
6 0 6 2.0 1260759187
7 0 7 2.0 1260759148
8 0 8 3.5 1260759125
9 0 9 2.0 1260759131
10 0 10 2.5 1260759135
11 0 11 1.0 1260759203
12 0 12 4.0 1260759191
13 0 13 4.0 1260759139
14 0 14 3.0 1260759194
15 0 15 2.0 1260759198
16 0 16 2.0 1260759108
17 0 17 2.5 1260759113
18 0 18 1.0 1260759200
19 0 19 3.0 1260759117
20 1 20 4.0 835355493
21 1 21 5.0 835355681
22 1 22 5.0 835355604
23 1 23 4.0 835355552
24 1 24 4.0 835355586
25 1 25 3.0 835356031
26 1 26 3.0 835355749
27 1 27 4.0 835355532
28 1 28 3.0 835356016
29 1 29 5.0 835355395
... ... ... ... ...
99974 670 473 4.5 1064245493
99975 670 354 5.0 1064245548
99976 670 355 3.5 1065111985
99977 670 5577 4.0 1065111973
99978 670 477 5.0 1064245488
99979 670 478 5.0 1065111996
99980 670 358 4.5 1065111855
99981 670 479 4.5 1064245471
99982 670 480 5.0 1064245483
99983 670 359 4.0 1064891537
99984 670 1225 2.0 1066793004
99985 670 1240 2.0 1065111990
99986 670 361 3.0 1065112004
99987 670 126 4.0 1065111863
99988 670 1260 4.0 1064245557
99989 670 483 4.5 1064891627
99990 670 362 3.0 1064891549
99991 670 127 4.0 1063502711
99992 670 364 4.0 1065111963
99993 670 1299 3.5 1064245507
99994 670 412 5.0 1063502716
99995 670 486 4.0 1064890625
99996 670 1308 4.5 1064245387
99997 670 365 4.0 1066793014
99998 670 2930 2.5 1065149436
99999 670 7005 2.5 1065579370
100000 670 4771 4.0 1065149201
100001 670 1329 4.0 1070940363
100002 670 1331 2.5 1070979663
100003 670 2946 3.5 1074784724

100004 rows × 4 columns


In [13]:
ratings.shape


Out[13]:
(100004, 4)

In [14]:
# split the ratings into training and validation sets with a random 80/20 mask
msk = np.random.choice([True, False], size=100004, p=[0.8, 0.2])
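
To make the split reproducible across runs, seed NumPy first and derive the size from the frame rather than hard-coding it. A sketch; the original run did not seed:

np.random.seed(42)  # any fixed seed makes the mask, and thus the split, repeatable
msk = np.random.choice([True, False], size=len(ratings), p=[0.8, 0.2])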

In [15]:
sum(msk)


Out[15]:
80116

In [16]:
training_data = ratings[msk]
val_data = ratings[~msk]

In [17]:
val_data.shape


Out[17]:
(19888, 4)

In [18]:
??Embedding

Dot product
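
The first model is plain matrix factorization: each user and each movie gets a 50-dimensional embedding vector, and the predicted rating for a (user, movie) pair is just the dot product of the two, $\hat{r}_{um} = p_u \cdot q_m$. The cells below build exactly this out of two Embedding layers and a dot-mode merge.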


In [19]:
user_in = Input(shape=(1,), dtype='int64', name='user_in')

In [20]:
user_in


Out[20]:
user_in

In [21]:
type(user_in)


Out[21]:
theano.tensor.var.TensorVariable

In [22]:
Embedding(ratings.userId.nunique(), 50, input_length=1, W_regularizer=l2(1e-4))


Out[22]:
<keras.layers.embeddings.Embedding at 0x7fe214965610>

In [23]:
user_embedding = Embedding(training_data.userId.nunique(), 50, input_length=1, W_regularizer=l2(1e-4))(user_in)
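
Note a subtle hazard in the cell above: sizing the embedding by training_data.userId.nunique() only works because the random split happens to leave all 671 users in the training set (the model summary below shows 33,550 = 671 × 50 user weights). If any user had landed only in the validation set, its remapped id would fall outside the embedding and fail at lookup time, so sizing by ratings.userId.nunique(), as the later cells do, is safer.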

In [24]:
user_embedding


Out[24]:
Reshape{3}.0

In [25]:
type(user_embedding)


Out[25]:
theano.tensor.var.TensorVariable

In [26]:
type(Sequential())


Out[26]:
keras.models.Sequential

In [27]:
movie_in = Input(shape=(1,), dtype='int64', name='movie_in')

In [28]:
movie_embedding = Embedding(ratings.movieId.nunique(), 50, input_length=1, W_regularizer=l2(1e-4))(movie_in)

In [29]:
x = merge([user_embedding, movie_embedding], mode='dot')

In [30]:
x


Out[30]:
Reshape{3}.0

In [31]:
x = Flatten()(x)

In [32]:
x


Out[32]:
Reshape{2}.0

In [33]:
model = Model([user_in, movie_in], x)

In [34]:
model.compile(Adam(0.001), loss='mse')

In [35]:
model.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
user_in (InputLayer)             (None, 1)             0                                            
____________________________________________________________________________________________________
movie_in (InputLayer)            (None, 1)             0                                            
____________________________________________________________________________________________________
embedding_2 (Embedding)          (None, 1, 50)         33550       user_in[0][0]                    
____________________________________________________________________________________________________
embedding_3 (Embedding)          (None, 1, 50)         453300      movie_in[0][0]                   
____________________________________________________________________________________________________
merge_1 (Merge)                  (None, 1, 1)          0           embedding_2[0][0]                
                                                                   embedding_3[0][0]                
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1)             0           merge_1[0][0]                    
====================================================================================================
Total params: 486,850
Trainable params: 486,850
Non-trainable params: 0
____________________________________________________________________________________________________
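
The parameter count is easy to verify by hand: 671 users × 50 factors = 33,550 weights in the user embedding, 9,066 movies × 50 = 453,300 in the movie embedding, and 33,550 + 453,300 = 486,850 in total; the merge and flatten layers contribute no parameters.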

In [36]:
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, nb_epoch=1, verbose=2,
          validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/1
7s - loss: 9.9186 - val_loss: 4.2839
Out[36]:
<keras.callbacks.History at 0x7fe20e498d10>

In [37]:
model.optimizer.lr = 0.01
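
Assigning a plain Python float to model.optimizer.lr, as above, may not reach the Theano update graph that was already compiled on the first fit; in Keras 1 the learning rate is a backend variable, and updating it in place is the safer route. A minimal sketch, assuming the Keras 1.x backend API:

from keras import backend as K
K.set_value(model.optimizer.lr, 0.01)  # mutate the shared variable the compiled graph reads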

In [39]:
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, nb_epoch=1, verbose=2,
          validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/1
7s - loss: 3.1207 - val_loss: 2.8445
Out[39]:
<keras.callbacks.History at 0x7fe20d475bd0>

In [40]:
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, nb_epoch=2, verbose=2,
          validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/2
7s - loss: 2.4113 - val_loss: 2.6472
Epoch 2/2
7s - loss: 2.2535 - val_loss: 2.6020
Out[40]:
<keras.callbacks.History at 0x7fe20d475a10>

In [41]:
model.optimizer.lr = 0.001
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, nb_epoch=6, verbose=2,
          validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/6
7s - loss: 2.1946 - val_loss: 2.5850
Epoch 2/6
7s - loss: 2.1625 - val_loss: 2.5844
Epoch 3/6
7s - loss: 2.1375 - val_loss: 2.5812
Epoch 4/6
7s - loss: 2.1128 - val_loss: 2.5873
Epoch 5/6
7s - loss: 2.0879 - val_loss: 2.5890
Epoch 6/6
7s - loss: 2.0627 - val_loss: 2.5946
Out[41]:
<keras.callbacks.History at 0x7fe20d475950>

Bias


In [47]:
def embedding_input(name, n_in, n_out, reg):
    inp = Input(shape=(1,), dtype='int64', name=name)
    return inp, Embedding(n_in, n_out, input_length=1, W_regularizer=l2(reg))(inp)

In [48]:
user_in, u = embedding_input('user_in', ratings.userId.nunique(), 50, 1e-4)
movie_in, m = embedding_input('movie_in', ratings.movieId.nunique(), 50, 1e-4)

In [49]:
def create_bias(inp, n_in):
    x = Embedding(n_in, 1, input_length=1)(inp)
    return Flatten()(x)
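
A bias is just a width-1 embedding: one learned scalar per user or per movie, added onto the dot product so the model can express "this user rates everything high" or "this movie is widely liked" independently of the factors. Note that, unlike the factor embeddings, no regularizer is applied to the biases here.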

In [58]:
u_b = create_bias(user_in, ratings.userId.nunique())
m_b = create_bias(movie_in, ratings.movieId.nunique())

In [59]:
x = merge([u, m], mode='dot')

In [60]:
x = Flatten()(x)

In [61]:
x = merge([x, u_b], mode='sum')

In [62]:
x = merge([x, m_b], mode='sum')

In [65]:
model = Model([user_in, movie_in], x)

In [66]:
model.compile(Adam(0.001), loss='mse')

In [67]:
model.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
user_in (InputLayer)             (None, 1)             0                                            
____________________________________________________________________________________________________
movie_in (InputLayer)            (None, 1)             0                                            
____________________________________________________________________________________________________
embedding_6 (Embedding)          (None, 1, 50)         33550       user_in[0][0]                    
____________________________________________________________________________________________________
embedding_7 (Embedding)          (None, 1, 50)         453300      movie_in[0][0]                   
____________________________________________________________________________________________________
merge_6 (Merge)                  (None, 1, 1)          0           embedding_6[0][0]                
                                                                   embedding_7[0][0]                
____________________________________________________________________________________________________
embedding_10 (Embedding)         (None, 1, 1)          671         user_in[0][0]                    
____________________________________________________________________________________________________
flatten_8 (Flatten)              (None, 1)             0           merge_6[0][0]                    
____________________________________________________________________________________________________
flatten_6 (Flatten)              (None, 1)             0           embedding_10[0][0]               
____________________________________________________________________________________________________
embedding_11 (Embedding)         (None, 1, 1)          9066        movie_in[0][0]                   
____________________________________________________________________________________________________
merge_7 (Merge)                  (None, 1)             0           flatten_8[0][0]                  
                                                                   flatten_6[0][0]                  
____________________________________________________________________________________________________
flatten_7 (Flatten)              (None, 1)             0           embedding_11[0][0]               
____________________________________________________________________________________________________
merge_8 (Merge)                  (None, 1)             0           merge_7[0][0]                    
                                                                   flatten_7[0][0]                  
====================================================================================================
Total params: 496,587
Trainable params: 496,587
Non-trainable params: 0
____________________________________________________________________________________________________
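
The total again checks out: 486,850 factor weights from before, plus 671 user biases and 9,066 movie biases, gives 496,587.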

In [69]:
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, verbose=2, nb_epoch=1,
         validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/1
6s - loss: 8.9597 - val_loss: 3.5969
Out[69]:
<keras.callbacks.History at 0x7fe20c442750>

In [71]:
model.optimizer.lr = 0.01
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, verbose=2, nb_epoch=6,
         validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/6
6s - loss: 2.5970 - val_loss: 2.3228
Epoch 2/6
6s - loss: 1.9807 - val_loss: 2.1142
Epoch 3/6
6s - loss: 1.8173 - val_loss: 2.0222
Epoch 4/6
6s - loss: 1.7249 - val_loss: 1.9563
Epoch 5/6
6s - loss: 1.6464 - val_loss: 1.8829
Epoch 6/6
6s - loss: 1.5732 - val_loss: 1.8187
Out[71]:
<keras.callbacks.History at 0x7fe20bebc050>

In [72]:
model.optimizer.lr = 0.001
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, verbose=2, nb_epoch=6,
         validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/6
6s - loss: 1.5030 - val_loss: 1.7575
Epoch 2/6
6s - loss: 1.4348 - val_loss: 1.6957
Epoch 3/6
6s - loss: 1.3677 - val_loss: 1.6354
Epoch 4/6
6s - loss: 1.3026 - val_loss: 1.5786
Epoch 5/6
6s - loss: 1.2399 - val_loss: 1.5263
Epoch 6/6
6s - loss: 1.1800 - val_loss: 1.4764
Out[72]:
<keras.callbacks.History at 0x7fe20c442b50>

In [73]:
model.fit([training_data.userId, training_data.movieId], training_data.rating,
          batch_size=64, verbose=2, nb_epoch=5,
         validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/5
6s - loss: 1.1224 - val_loss: 1.4281
Epoch 2/5
6s - loss: 1.0685 - val_loss: 1.3873
Epoch 3/5
6s - loss: 1.0163 - val_loss: 1.3435
Epoch 4/5
6s - loss: 0.9677 - val_loss: 1.3056
Epoch 5/5
6s - loss: 0.9213 - val_loss: 1.2700
Out[73]:
<keras.callbacks.History at 0x7fe20d23ab50>

Neural net
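
Rather than forcing the interaction to be a dot product, this last model concatenates the two 50-dimensional embeddings into a single 100-dimensional vector and lets a small fully connected network, one hidden layer of 70 ReLU units with dropout on both sides, learn how to combine them into a rating.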


In [74]:
user_in, u = embedding_input('user_in', ratings.userId.nunique(), 50, 1e-4)
movie_in, m = embedding_input('movie_in', ratings.movieId.nunique(), 50, 1e-4)

In [75]:
x = merge([u, m], mode='concat')
x = Flatten()(x)
x = Dropout(0.3)(x)
x = Dense(70, activation='relu')(x)
x = Dropout(0.75)(x)
x = Dense(1)(x)
nn = Model([user_in, movie_in], x)
nn.compile(Adam(0.001), loss='mse')

In [76]:
nn.summary()


____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
user_in (InputLayer)             (None, 1)             0                                            
____________________________________________________________________________________________________
movie_in (InputLayer)            (None, 1)             0                                            
____________________________________________________________________________________________________
embedding_12 (Embedding)         (None, 1, 50)         33550       user_in[0][0]                    
____________________________________________________________________________________________________
embedding_13 (Embedding)         (None, 1, 50)         453300      movie_in[0][0]                   
____________________________________________________________________________________________________
merge_9 (Merge)                  (None, 1, 100)        0           embedding_12[0][0]               
                                                                   embedding_13[0][0]               
____________________________________________________________________________________________________
flatten_9 (Flatten)              (None, 100)           0           merge_9[0][0]                    
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 100)           0           flatten_9[0][0]                  
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 70)            7070        dropout_1[0][0]                  
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 70)            0           dense_1[0][0]                    
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 1)             71          dropout_2[0][0]                  
====================================================================================================
Total params: 493,991
Trainable params: 493,991
Non-trainable params: 0
____________________________________________________________________________________________________
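
Parameter check: the hidden layer maps the 100 concatenated features to 70 units (100 × 70 + 70 = 7,070), the output layer maps 70 to 1 (70 + 1 = 71), and 33,550 + 453,300 + 7,070 + 71 = 493,991.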

In [77]:
nn.fit([training_data.userId, training_data.movieId], training_data.rating,
       batch_size=64, verbose=2, nb_epoch=5,
       validation_data=([val_data.userId, val_data.movieId], val_data.rating))


Train on 80116 samples, validate on 19888 samples
Epoch 1/5
8s - loss: 2.4699 - val_loss: 0.9121
Epoch 2/5
8s - loss: 1.4852 - val_loss: 0.8783
Epoch 3/5
8s - loss: 1.2341 - val_loss: 0.8515
Epoch 4/5
8s - loss: 1.0430 - val_loss: 0.8546
Epoch 5/5
8s - loss: 0.9138 - val_loss: 0.8333
Out[77]:
<keras.callbacks.History at 0x7fe2064f8a10>
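
To use the trained network, feed it a user index and a movie index in the remapped, zero-based id space. A minimal sketch, not from the original run:

# predicted rating for user index 3 and movie index 6 (remapped ids)
nn.predict([np.array([3]), np.array([6])])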
