Imports


In [1]:
import sys 
import os
sys.path.append(os.path.join(os.getcwd(), '../Code/'))
from LadickyDataset import *

In [2]:
import tensorflow as tf
from keras.models import  Model, load_model
from keras.applications.vgg16 import VGG16
from keras.layers import Input , Flatten, Dense, Reshape, Lambda
from keras.layers.convolutional import Conv2D


Using TensorFlow backend.

In [3]:
from math import ceil

In [4]:
from PIL import Image

Utility Functions


In [5]:
def show_image(npimg):
    return Image.fromarray(npimg.astype(np.uint8))

In [6]:
def show_normals(npnorms):
    return Image.fromarray(((npnorms+1)/2*255).astype(np.uint8))

Dataset


In [7]:
file = '../Data/LadickyDataset.mat'

In [8]:
testNdxs = [1,2,9,14,15,16,17,18,21,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,46,47,56,57,59,60,61,62,63,76,77,78,79,84,85,86,87,88,89,90,91,117,118,119,125,126,127,128,129,131,132,133,134,137,153,154,155,167,168,169,171,172,173,174,175,176,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,207,208,209,210,211,212,220,221,222,250,264,271,272,273,279,280,281,282,283,284,285,296,297,298,299,300,301,302,310,311,312,315,316,317,325,326,327,328,329,330,331,332,333,334,335,351,352,355,356,357,358,359,360,361,362,363,364,384,385,386,387,388,389,390,395,396,397,411,412,413,414,430,431,432,433,434,435,441,442,443,444,445,446,447,448,462,463,464,465,466,469,470,471,472,473,474,475,476,477,508,509,510,511,512,513,515,516,517,518,519,520,521,522,523,524,525,526,531,532,533,537,538,539,549,550,551,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,571,579,580,581,582,583,591,592,593,594,603,604,605,606,607,612,613,617,618,619,620,621,633,634,635,636,637,638,644,645,650,651,656,657,658,663,664,668,669,670,671,672,673,676,677,678,679,680,681,686,687,688,689,690,693,694,697,698,699,706,707,708,709,710,711,712,713,717,718,724,725,726,727,728,731,732,733,734,743,744,759,760,761,762,763,764,765,766,767,768,769,770,771,772,773,774,775,776,777,778,779,780,781,782,783,784,785,786,787,800,801,802,803,804,810,811,812,813,814,821,822,823,833,834,835,836,837,838,839,840,841,842,843,844,845,846,850,851,852,857,858,859,860,861,862,869,870,871,906,907,908,917,918,919,926,927,928,932,933,934,935,945,946,947,959,960,961,962,965,966,967,970,971,972,973,974,975,976,977,991,992,993,994,995,1001,1002,1003,1004,1010,1011,1012,1021,1022,1023,1032,1033,1034,1038,1039,1048,1049,1052,1053,1057,1058,1075,1076,1077,1078,1079,1080,1081,1082,1083,1084,1088,1089,1090,1091,1092,1093,1094,1095,1096,1098,1099,1100,1101,1102,1103,1104,1106,1107,1108,1109,1117,1118,1119,1123,1124,1125,1126,1127,1128,1129,1130,1131,1135,1136,1144,1145,1146,1147,1148,1149,1150,1151,1152,1153,1154,1155,1156,1157,1158,1162,1163,1164,1165,1166,1167,1170,1171,1174,1175,1176,1179,1180,1181,1182,1183,1184,1192,1193,1194,1195,1196,1201,1202,1203,1204,1205,1206,1207,1208,1209,1210,1211,1212,1216,1217,1218,1219,1220,1226,1227,1228,1229,1230,1233,1234,1235,1247,1248,1249,1250,1254,1255,1256,1257,1258,1259,1260,1261,1262,1263,1264,1265,1275,1276,1277,1278,1279,1280,1285,1286,1287,1288,1289,1290,1291,1292,1293,1294,1295,1297,1298,1299,1302,1303,1304,1305,1306,1307,1308,1314,1315,1329,1330,1331,1332,1335,1336,1337,1338,1339,1340,1347,1348,1349,1353,1354,1355,1356,1364,1365,1368,1369,1384,1385,1386,1387,1388,1389,1390,1391,1394,1395,1396,1397,1398,1399,1400,1401,1407,1408,1409,1410,1411,1412,1413,1414,1421,1422,1423,1424,1430,1431,1432,1433,1441,1442,1443,1444,1445,1446,1447,1448]

In [9]:
dataset = LadickyDataset(file, testNdxs)

In [10]:
dataset.size


Out[10]:
653

Loss Function


In [11]:
def mean_dot_product(y_true, y_pred):
    dot = tf.einsum('ijkl,ijkl->ijk', y_true, y_pred) # Dot product
    n = tf.cast(tf.count_nonzero(dot),tf.float32)
    mean = tf.reduce_sum(dot) / n
    return -1 * mean

Model Architecture


In [12]:
def vgg16_model():
    # create model
    input_tensor = Input(shape=(240, 320, 3)) 
    base_model = VGG16(input_tensor=input_tensor,weights='imagenet', include_top=False)
    x = base_model.output
    x = Flatten()(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(80*60*3, activation='relu', name='fc2')(x)
    x = Reshape((60,80,3))(x)
    x = Lambda(lambda x: tf.image.resize_bilinear(x , [240,320]) )(x)
    pred = Lambda(lambda x: tf.nn.l2_normalize(x, 3) )(x)
    model = Model(inputs=base_model.input, outputs=pred)
    # Compile model
    model.compile(loss= mean_dot_product, optimizer='sgd')
    return model

Variables


In [13]:
images = np.empty([len(testNdxs), dataset.batch_height, dataset.batch_width, 3], dtype=np.float32)
normals = np.empty([len(testNdxs), dataset.batch_height, dataset.batch_width, 3], dtype=np.float32)
preds = np.empty([len(testNdxs), dataset.batch_height, dataset.batch_width, 3], dtype=np.float32)

In [14]:
model = load_model('../Data/vgg16-ladicky-model.h5', custom_objects={'mean_dot_product': mean_dot_product, 'tf':tf})

Main Loop


In [15]:
index = 0
for i in testNdxs:
    print('Index: '+str(i))
    images[index], normals[index] = dataset.get_data(i)
    preds[index] = model.predict_on_batch(images[index].reshape((1,dataset.batch_height, dataset.batch_width, 3 )))
    index += 1


Index: 1
Index: 2
Index: 9
Index: 14
Index: 15
Index: 16
Index: 17
Index: 18
Index: 21
Index: 28
Index: 29
Index: 30
Index: 31
Index: 32
Index: 33
Index: 34
Index: 35
Index: 36
Index: 37
Index: 38
Index: 39
Index: 40
Index: 41
Index: 42
Index: 43
Index: 46
Index: 47
Index: 56
Index: 57
Index: 59
Index: 60
Index: 61
Index: 62
Index: 63
Index: 76
Index: 77
Index: 78
Index: 79
Index: 84
Index: 85
Index: 86
Index: 87
Index: 88
Index: 89
Index: 90
Index: 91
Index: 117
Index: 118
Index: 119
Index: 125
Index: 126
Index: 127
Index: 128
Index: 129
Index: 131
Index: 132
Index: 133
Index: 134
Index: 137
Index: 153
Index: 154
Index: 155
Index: 167
Index: 168
Index: 169
Index: 171
Index: 172
Index: 173
Index: 174
Index: 175
Index: 176
Index: 180
Index: 181
Index: 182
Index: 183
Index: 184
Index: 185
Index: 186
Index: 187
Index: 188
Index: 189
Index: 190
Index: 191
Index: 192
Index: 193
Index: 194
Index: 195
Index: 196
Index: 197
Index: 198
Index: 199
Index: 200
Index: 201
Index: 202
Index: 207
Index: 208
Index: 209
Index: 210
Index: 211
Index: 212
Index: 220
Index: 221
Index: 222
Index: 250
Index: 264
Index: 271
Index: 272
Index: 273
Index: 279
Index: 280
Index: 281
Index: 282
Index: 283
Index: 284
Index: 285
Index: 296
Index: 297
Index: 298
Index: 299
Index: 300
Index: 301
Index: 302
Index: 310
Index: 311
Index: 312
Index: 315
Index: 316
Index: 317
Index: 325
Index: 326
Index: 327
Index: 328
Index: 329
Index: 330
Index: 331
Index: 332
Index: 333
Index: 334
Index: 335
Index: 351
Index: 352
Index: 355
Index: 356
Index: 357
Index: 358
Index: 359
Index: 360
Index: 361
Index: 362
Index: 363
Index: 364
Index: 384
Index: 385
Index: 386
Index: 387
Index: 388
Index: 389
Index: 390
Index: 395
Index: 396
Index: 397
Index: 411
Index: 412
Index: 413
Index: 414
Index: 430
Index: 431
Index: 432
Index: 433
Index: 434
Index: 435
Index: 441
Index: 442
Index: 443
Index: 444
Index: 445
Index: 446
Index: 447
Index: 448
Index: 462
Index: 463
Index: 464
Index: 465
Index: 466
Index: 469
Index: 470
Index: 471
Index: 472
Index: 473
Index: 474
Index: 475
Index: 476
Index: 477
Index: 508
Index: 509
Index: 510
Index: 511
Index: 512
Index: 513
Index: 515
Index: 516
Index: 517
Index: 518
Index: 519
Index: 520
Index: 521
Index: 522
Index: 523
Index: 524
Index: 525
Index: 526
Index: 531
Index: 532
Index: 533
Index: 537
Index: 538
Index: 539
Index: 549
Index: 550
Index: 551
Index: 555
Index: 556
Index: 557
Index: 558
Index: 559
Index: 560
Index: 561
Index: 562
Index: 563
Index: 564
Index: 565
Index: 566
Index: 567
Index: 568
Index: 569
Index: 570
Index: 571
Index: 579
Index: 580
Index: 581
Index: 582
Index: 583
Index: 591
Index: 592
Index: 593
Index: 594
Index: 603
Index: 604
Index: 605
Index: 606
Index: 607
Index: 612
Index: 613
Index: 617
Index: 618
Index: 619
Index: 620
Index: 621
Index: 633
Index: 634
Index: 635
Index: 636
Index: 637
Index: 638
Index: 644
Index: 645
Index: 650
Index: 651
Index: 656
Index: 657
Index: 658
Index: 663
Index: 664
Index: 668
Index: 669
Index: 670
Index: 671
Index: 672
Index: 673
Index: 676
Index: 677
Index: 678
Index: 679
Index: 680
Index: 681
Index: 686
Index: 687
Index: 688
Index: 689
Index: 690
Index: 693
Index: 694
Index: 697
Index: 698
Index: 699
Index: 706
Index: 707
Index: 708
Index: 709
Index: 710
Index: 711
Index: 712
Index: 713
Index: 717
Index: 718
Index: 724
Index: 725
Index: 726
Index: 727
Index: 728
Index: 731
Index: 732
Index: 733
Index: 734
Index: 743
Index: 744
Index: 759
Index: 760
Index: 761
Index: 762
Index: 763
Index: 764
Index: 765
Index: 766
Index: 767
Index: 768
Index: 769
Index: 770
Index: 771
Index: 772
Index: 773
Index: 774
Index: 775
Index: 776
Index: 777
Index: 778
Index: 779
Index: 780
Index: 781
Index: 782
Index: 783
Index: 784
Index: 785
Index: 786
Index: 787
Index: 800
Index: 801
Index: 802
Index: 803
Index: 804
Index: 810
Index: 811
Index: 812
Index: 813
Index: 814
Index: 821
Index: 822
Index: 823
Index: 833
Index: 834
Index: 835
Index: 836
Index: 837
Index: 838
Index: 839
Index: 840
Index: 841
Index: 842
Index: 843
Index: 844
Index: 845
Index: 846
Index: 850
Index: 851
Index: 852
Index: 857
Index: 858
Index: 859
Index: 860
Index: 861
Index: 862
Index: 869
Index: 870
Index: 871
Index: 906
Index: 907
Index: 908
Index: 917
Index: 918
Index: 919
Index: 926
Index: 927
Index: 928
Index: 932
Index: 933
Index: 934
Index: 935
Index: 945
Index: 946
Index: 947
Index: 959
Index: 960
Index: 961
Index: 962
Index: 965
Index: 966
Index: 967
Index: 970
Index: 971
Index: 972
Index: 973
Index: 974
Index: 975
Index: 976
Index: 977
Index: 991
Index: 992
Index: 993
Index: 994
Index: 995
Index: 1001
Index: 1002
Index: 1003
Index: 1004
Index: 1010
Index: 1011
Index: 1012
Index: 1021
Index: 1022
Index: 1023
Index: 1032
Index: 1033
Index: 1034
Index: 1038
Index: 1039
Index: 1048
Index: 1049
Index: 1052
Index: 1053
Index: 1057
Index: 1058
Index: 1075
Index: 1076
Index: 1077
Index: 1078
Index: 1079
Index: 1080
Index: 1081
Index: 1082
Index: 1083
Index: 1084
Index: 1088
Index: 1089
Index: 1090
Index: 1091
Index: 1092
Index: 1093
Index: 1094
Index: 1095
Index: 1096
Index: 1098
Index: 1099
Index: 1100
Index: 1101
Index: 1102
Index: 1103
Index: 1104
Index: 1106
Index: 1107
Index: 1108
Index: 1109
Index: 1117
Index: 1118
Index: 1119
Index: 1123
Index: 1124
Index: 1125
Index: 1126
Index: 1127
Index: 1128
Index: 1129
Index: 1130
Index: 1131
Index: 1135
Index: 1136
Index: 1144
Index: 1145
Index: 1146
Index: 1147
Index: 1148
Index: 1149
Index: 1150
Index: 1151
Index: 1152
Index: 1153
Index: 1154
Index: 1155
Index: 1156
Index: 1157
Index: 1158
Index: 1162
Index: 1163
Index: 1164
Index: 1165
Index: 1166
Index: 1167
Index: 1170
Index: 1171
Index: 1174
Index: 1175
Index: 1176
Index: 1179
Index: 1180
Index: 1181
Index: 1182
Index: 1183
Index: 1184
Index: 1192
Index: 1193
Index: 1194
Index: 1195
Index: 1196
Index: 1201
Index: 1202
Index: 1203
Index: 1204
Index: 1205
Index: 1206
Index: 1207
Index: 1208
Index: 1209
Index: 1210
Index: 1211
Index: 1212
Index: 1216
Index: 1217
Index: 1218
Index: 1219
Index: 1220
Index: 1226
Index: 1227
Index: 1228
Index: 1229
Index: 1230
Index: 1233
Index: 1234
Index: 1235
Index: 1247
Index: 1248
Index: 1249
Index: 1250
Index: 1254
Index: 1255
Index: 1256
Index: 1257
Index: 1258
Index: 1259
Index: 1260
Index: 1261
Index: 1262
Index: 1263
Index: 1264
Index: 1265
Index: 1275
Index: 1276
Index: 1277
Index: 1278
Index: 1279
Index: 1280
Index: 1285
Index: 1286
Index: 1287
Index: 1288
Index: 1289
Index: 1290
Index: 1291
Index: 1292
Index: 1293
Index: 1294
Index: 1295
Index: 1297
Index: 1298
Index: 1299
Index: 1302
Index: 1303
Index: 1304
Index: 1305
Index: 1306
Index: 1307
Index: 1308
Index: 1314
Index: 1315
Index: 1329
Index: 1330
Index: 1331
Index: 1332
Index: 1335
Index: 1336
Index: 1337
Index: 1338
Index: 1339
Index: 1340
Index: 1347
Index: 1348
Index: 1349
Index: 1353
Index: 1354
Index: 1355
Index: 1356
Index: 1364
Index: 1365
Index: 1368
Index: 1369
Index: 1384
Index: 1385
Index: 1386
Index: 1387
Index: 1388
Index: 1389
Index: 1390
Index: 1391
Index: 1394
Index: 1395
Index: 1396
Index: 1397
Index: 1398
Index: 1399
Index: 1400
Index: 1401
Index: 1407
Index: 1408
Index: 1409
Index: 1410
Index: 1411
Index: 1412
Index: 1413
Index: 1414
Index: 1421
Index: 1422
Index: 1423
Index: 1424
Index: 1430
Index: 1431
Index: 1432
Index: 1433
Index: 1441
Index: 1442
Index: 1443
Index: 1444
Index: 1445
Index: 1446
Index: 1447
Index: 1448

In [17]:
for i in range(len(testNdxs)):
        img = show_image(images[i])
        norm = show_normals(normals[i])
        pred = show_normals(preds[i])
        out = Image.new('RGB', (img.size[0],3*img.size[1]))
        out.paste(img.copy())
        out.paste(norm.copy(), (0,norm.size[1]))
        out.paste(pred.copy(), (0,norm.size[1]+pred.size[1]))
        out.save('../Data/Output/'+ str(i)+'.png')

In [18]:
from scipy.io import savemat

In [19]:
savemat('../Data/pred.mat',{'Predictions': preds, 'Normals': normals})

Code


In [1]:
%%writefile ../Code/Experiments/Prediction.py
# Imports
import tensorflow as tf
from keras.models import load_model
import numpy as np
from PIL import Image
from scipy.io import savemat

# Utility functions
def show_image(npimg):
    return Image.fromarray(npimg.astype(np.uint8))
def show_normals(npnorms):
    return Image.fromarray(((npnorms+1)/2*255).astype(np.uint8))

# Loss function
def mean_dot_product(y_true, y_pred):
    dot = tf.einsum('ijkl,ijkl->ijk', y_true, y_pred) # Dot product
    n = tf.cast(tf.count_nonzero(dot),tf.float32)
    mean = tf.reduce_sum(dot) / n
    return -1 * mean

# Prediction
def Predict(ID, Dataset):
    
    # Load data set
    print('Loading the data set...')
    dataset = Dataset()
    
    # Load model
    print('Loading the model...')
    model = load_model('Experiments/Outputs/'+ ID + '.h5', custom_objects={'mean_dot_product': mean_dot_product, 'tf':tf})
    
    # Variables
    images = np.empty([dataset.size, dataset.batch_height, dataset.batch_width, 3], dtype=np.float32)
    normals = np.empty([dataset.size, dataset.batch_height, dataset.batch_width, 3], dtype=np.float32)
    preds = np.empty([dataset.size, dataset.batch_height, dataset.batch_width, 3], dtype=np.float32)
    
    # Prediction Loop
    print('Normal Estimation...')
    index = 0
    for i in dataset.validIndices:
        print('Index: '+str(i))
        images[index], normals[index] = dataset.get_data(i)
        preds[index] = model.predict_on_batch(images[index].reshape((1,dataset.batch_height, dataset.batch_width, 3 )))
        index += 1
    
    # Saving the result
    for i in range(dataset.size):
        img = show_image(images[i])
        norm = show_normals(normals[i])
        pred = show_normals(preds[i])
        out = Image.new('RGB', (img.size[0],3*img.size[1]))
        out.paste(img.copy())
        out.paste(norm.copy(), (0,norm.size[1]))
        out.paste(pred.copy(), (0,norm.size[1]+pred.size[1]))
        out.save('Experiments/Outputs/'+ID+'/'+str(i)+'.png')
        
    savemat('Experiments/Outputs/'+ ID + '.mat',{'Predictions': preds, 'Normals': normals})


Overwriting ../Code/Experiments/Prediction.py

In [ ]: