In [0]:
# clear out any data folder left over from a previous run
!rm -r data

In [139]:
!mkdir data && wget https://github.com/shawngraham/crane-experiments/raw/master/one-shot-classification/aerial.zip && unzip aerial.zip -d data/


--2019-03-19 19:23:18--  https://github.com/shawngraham/crane-experiments/raw/master/one-shot-classification/aerial.zip
Resolving github.com (github.com)... 140.82.118.4, 140.82.118.3
Connecting to github.com (github.com)|140.82.118.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/shawngraham/crane-experiments/master/one-shot-classification/aerial.zip [following]
--2019-03-19 19:23:18--  https://raw.githubusercontent.com/shawngraham/crane-experiments/master/one-shot-classification/aerial.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10834671 (10M) [application/zip]
Saving to: ‘aerial.zip’

aerial.zip          100%[===================>]  10.33M  --.-KB/s    in 0.06s   

2019-03-19 19:23:18 (160 MB/s) - ‘aerial.zip’ saved [10834671/10834671]

Archive:  aerial.zip
   creating: data/earthworks-pics/
  inflating: data/earthworks-pics/.DS_Store  
   creating: data/__MACOSX/
   creating: data/__MACOSX/earthworks-pics/
  inflating: data/__MACOSX/earthworks-pics/._.DS_Store  
   creating: data/earthworks-pics/first/
  inflating: data/earthworks-pics/first/640f016j.jpg  
  inflating: data/earthworks-pics/first/640f017j.jpg  
  inflating: data/earthworks-pics/first/640f23j.jpg  
  inflating: data/earthworks-pics/first/640f24j.jpg  
  inflating: data/earthworks-pics/first/640f25j.jpg  
  inflating: data/earthworks-pics/first/640f29j.jpg  
  inflating: data/earthworks-pics/first/640f30j.jpg  
   creating: data/earthworks-pics/second/
  inflating: data/earthworks-pics/second/640f31j.jpg  
  inflating: data/earthworks-pics/second/640f32j.jpg  
  inflating: data/earthworks-pics/second/640f33j.jpg  
  inflating: data/earthworks-pics/second/640f34j.jpg  
  inflating: data/earthworks-pics/second/640f35j.jpg  
  inflating: data/earthworks-pics/second/640f36j.jpg  
  inflating: data/earthworks-pics/second/640f37j.jpg  
  inflating: data/earthworks-pics/second/640f38j.jpg  
  inflating: data/earthworks-pics/second/640f40j.jpg  
  inflating: data/earthworks-pics/second/640f41j.jpg  
   creating: data/earthworks-pics/third/
  inflating: data/earthworks-pics/third/640f043j.jpg  
  inflating: data/earthworks-pics/third/640f044j.jpg  
  inflating: data/earthworks-pics/third/640f63j.jpg  
  inflating: data/earthworks-pics/third/640f64j.jpg  
  inflating: data/earthworks-pics/third/640f65j.jpg  
  inflating: data/earthworks-pics/third/640f66j.jpg  
  inflating: data/earthworks-pics/third/640f67j.jpg  
  inflating: data/earthworks-pics/third/640f74j.jpg  
  inflating: data/earthworks-pics/third/640f75j.jpg  
  inflating: data/earthworks-pics/third/640f80j.jpg  
  inflating: data/earthworks-pics/third/640f86j.jpg  
  inflating: data/earthworks-pics/third/640f87j.jpg  
  inflating: data/earthworks-pics/third/640f97j.jpg  
  inflating: data/earthworks-pics/third/640fj10j.jpg  
  inflating: data/earthworks-pics/third/640fj12j.jpg  
   creating: data/native-settlement/
  inflating: data/native-settlement/.DS_Store  
   creating: data/__MACOSX/native-settlement/
  inflating: data/__MACOSX/native-settlement/._.DS_Store  
   creating: data/native-settlement/first/
  inflating: data/native-settlement/first/640f21j.jpg  
  inflating: data/native-settlement/first/640f22j.jpg  
  inflating: data/native-settlement/first/640f23j.jpg  
  inflating: data/native-settlement/first/640f24j.jpg  
  inflating: data/native-settlement/first/640f29j.jpg  
  inflating: data/native-settlement/first/640f30j.jpg  
  inflating: data/native-settlement/first/640f31j.jpg  
  inflating: data/native-settlement/first/640f35j.jpg  
  inflating: data/native-settlement/first/640f36j.jpg  
  inflating: data/native-settlement/first/640f43j.jpg  
  inflating: data/native-settlement/first/640f44j.jpg  
  inflating: data/native-settlement/first/640f45j.jpg  
  inflating: data/native-settlement/first/640f46j.jpg  
  inflating: data/native-settlement/first/640f47j.jpg  
  inflating: data/native-settlement/first/640f48j.jpg  
  inflating: data/native-settlement/first/640f49j.jpg  
   creating: data/native-settlement/second/
  inflating: data/native-settlement/second/640f50j.jpg  
  inflating: data/native-settlement/second/640f51j.jpg  
  inflating: data/native-settlement/second/640f52j.jpg  
  inflating: data/native-settlement/second/640f53j.jpg  
  inflating: data/native-settlement/second/640f54j.jpg  
  inflating: data/native-settlement/second/640f55j.jpg  
  inflating: data/native-settlement/second/640f56j.jpg  
   creating: data/native-settlement/third/
  inflating: data/native-settlement/third/640f57j.jpg  
  inflating: data/native-settlement/third/640f58j.jpg  
  inflating: data/native-settlement/third/640f59j.jpg  
  inflating: data/native-settlement/third/640f60j.jpg  
  inflating: data/native-settlement/third/640f61j.jpg  
  inflating: data/native-settlement/third/640f62j.jpg  
  inflating: data/native-settlement/third/640f63j.jpg  
  inflating: data/native-settlement/third/640f64j.jpg  
   creating: data/roman-fort-pics/
  inflating: data/roman-fort-pics/.DS_Store  
   creating: data/__MACOSX/roman-fort-pics/
  inflating: data/__MACOSX/roman-fort-pics/._.DS_Store  
   creating: data/roman-fort-pics/first/
  inflating: data/roman-fort-pics/first/640f21j.jpg  
  inflating: data/roman-fort-pics/first/640f22j.jpg  
  inflating: data/roman-fort-pics/first/640f23j.jpg  
  inflating: data/roman-fort-pics/first/640f24j.jpg  
  inflating: data/roman-fort-pics/first/640f29j.jpg  
  inflating: data/roman-fort-pics/first/640f30j.jpg  
  inflating: data/roman-fort-pics/first/640f31j.jpg  
  inflating: data/roman-fort-pics/first/640f35j.jpg  
  inflating: data/roman-fort-pics/first/640f36j.jpg  
  inflating: data/roman-fort-pics/first/640f43j.jpg  
   creating: data/roman-fort-pics/second/
  inflating: data/roman-fort-pics/second/640f44j.jpg  
  inflating: data/roman-fort-pics/second/640f45j.jpg  
  inflating: data/roman-fort-pics/second/640f46j.jpg  
  inflating: data/roman-fort-pics/second/640f47j.jpg  
  inflating: data/roman-fort-pics/second/640f48j.jpg  
  inflating: data/roman-fort-pics/second/640f49j.jpg  
  inflating: data/roman-fort-pics/second/640f50j.jpg  
  inflating: data/roman-fort-pics/second/640f51j.jpg  
  inflating: data/roman-fort-pics/second/640f52j.jpg  
  inflating: data/roman-fort-pics/second/640f53j.jpg  
  inflating: data/roman-fort-pics/second/640f54j.jpg  
  inflating: data/roman-fort-pics/second/640f55j.jpg  
  inflating: data/roman-fort-pics/second/640f56j.jpg  
  inflating: data/roman-fort-pics/second/640f57j.jpg  
  inflating: data/roman-fort-pics/second/640f58j.jpg  
   creating: data/roman-fort-pics/third/
  inflating: data/roman-fort-pics/third/640f59j.jpg  
  inflating: data/roman-fort-pics/third/640f60j.jpg  
  inflating: data/roman-fort-pics/third/640f61j.jpg  
  inflating: data/roman-fort-pics/third/640f62j.jpg  
  inflating: data/roman-fort-pics/third/640f63j.jpg  
  inflating: data/roman-fort-pics/third/640f64j.jpg  
  inflating: data/roman-fort-pics/third/640f65j.jpg  
  inflating: data/roman-fort-pics/third/640f66j.jpg  
  inflating: data/roman-fort-pics/third/640f67j.jpg  
  inflating: data/roman-fort-pics/third/640f68j.jpg  
  inflating: data/roman-fort-pics/third/640f69j.jpg  
  inflating: data/roman-fort-pics/third/640f70j.jpg  
  inflating: data/roman-fort-pics/third/640f71j.jpg  
  inflating: data/roman-fort-pics/third/640f72j.jpg  
  inflating: data/roman-fort-pics/third/640f84j.jpg  
  inflating: data/roman-fort-pics/third/640f85j.jpg  
  inflating: data/roman-fort-pics/third/640f90j.jpg  
  inflating: data/roman-fort-pics/third/640f91j.jpg  
  inflating: data/roman-fort-pics/third/640f92j.jpg  
  inflating: data/roman-fort-pics/third/640f93j.jpg  
  inflating: data/roman-fort-pics/third/640f94j.jpg  
  inflating: data/roman-fort-pics/third/640f95j.jpg  
  inflating: data/roman-fort-pics/third/640f96j.jpg  
  inflating: data/roman-fort-pics/third/640f98j.jpg  
  inflating: data/roman-fort-pics/third/640f99j.jpg  
  inflating: data/roman-fort-pics/third/640j24j.jpg  
  inflating: data/roman-fort-pics/third/640j40j.jpg  
  inflating: data/roman-fort-pics/third/640j41j.jpg  
  inflating: data/roman-fort-pics/third/640k12j.jpg  
  inflating: data/roman-fort-pics/third/640k13j.jpg  
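
The zip carries macOS packaging artefacts (the __MACOSX/ folders and .DS_Store files) that are not images. ImageFolder only treats subdirectories as classes, so they happen to be harmless here, but removing them keeps the data folders clean; a minimal cleanup sketch:

In [0]:
!rm -rf data/__MACOSX && find data -name '.DS_Store' -delete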

In [0]:
# remove duplicate downloads and stray archives left behind by earlier runs
rm aerial.zip.1 aerial.zip.2 aerial.zip.3 data.zip

In [115]:
ls


aerial.zip  earthworks-pics/  native-settlement/
data/       __MACOSX/         roman-fort-pics/

In [116]:
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))


Found GPU at: /device:GPU:0
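
The TensorFlow import above is used only to confirm that Colab has assigned a GPU. Since everything else in this notebook runs on PyTorch, the same check can be done without TensorFlow; a minimal sketch:

In [0]:
import torch

# PyTorch-only equivalent of the GPU check above
if not torch.cuda.is_available():
    raise SystemError('GPU device not found')
print('Found GPU: {}'.format(torch.cuda.get_device_name(0)))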

In [0]:
%matplotlib inline
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [0]:
def imshow(img, text=None, should_save=False):
    """Display a (C, H, W) tensor as an image, with an optional caption."""
    npimg = img.numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold',
            bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_plot(iteration, loss):
    """Plot the loss history against the iteration counter."""
    plt.plot(iteration, loss)
    plt.show()

In [0]:
class Config:
    training_dir = "data/roman-fort-pics/"    # pairs for training
    testing_dir = "data/native-settlement/"   # pairs for evaluation
    train_batch_size = 64
    train_number_epochs = 100

In [0]:
class SiameseNetworkDataset(Dataset):

    def __init__(self, imageFolderDataset, transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert

    def __getitem__(self, index):
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        # we need to make sure approx 50% of pairs are drawn from the same class
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            while True:
                # keep looping till a same-class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] == img1_tuple[1]:
                    break
        else:
            while True:
                # keep looping till a different-class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] != img1_tuple[1]:
                    break

        img0 = Image.open(img0_tuple[0]).convert("L")   # greyscale
        img1 = Image.open(img1_tuple[0]).convert("L")

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # label is 0.0 for a same-class pair, 1.0 for a different-class pair
        return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))

    def __len__(self):
        return len(self.imageFolderDataset.imgs)

In [0]:
folder_dataset = dset.ImageFolder(root=Config.training_dir)
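
ImageFolder infers class labels from the subdirectory names, so the first/, second/, and third/ folders under roman-fort-pics/ become the three classes the pair sampler draws from. A quick sanity check (a minimal sketch):

In [0]:
# inspect the classes ImageFolder inferred from the folder names
print(folder_dataset.classes)           # expected: ['first', 'second', 'third']
print(len(folder_dataset.imgs), 'images found')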

In [0]:
ls


aerial.zip  aerial.zip.1  data/  data.zip  sample_data/

In [0]:
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                        transform=transforms.Compose([transforms.Resize((100, 100)),
                                                                      transforms.ToTensor()]),
                                        should_invert=False)

In [143]:
vis_dataloader = DataLoader(siamese_dataset,
                            shuffle=True,
                            num_workers=8,
                            batch_size=8)
dataiter = iter(vis_dataloader)

example_batch = next(dataiter)
concatenated = torch.cat((example_batch[0], example_batch[1]), 0)
imshow(torchvision.utils.make_grid(concatenated))
# 1.0 marks a different-class pair, 0.0 a same-class pair
print(example_batch[2].numpy())


[[1.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [1.]]

In [0]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # three conv blocks; reflection padding keeps the 100x100 spatial size
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )

        # fully connected head mapping the flattened 8x100x100 features to a 5-d embedding
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5))

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # both inputs pass through the same weights; this sharing is what makes the network "siamese"
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
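
A quick shape check helps confirm the architecture: a 1-channel 100x100 input keeps its spatial size through the reflection-padded convolutions, flattens to 8*100*100 features, and comes out as a 5-dimensional embedding. A minimal sketch (the dummy tensor is illustrative only):

In [0]:
# hypothetical sanity check on CPU: identical inputs give identical embeddings
dummy = torch.randn(1, 1, 100, 100)
out1, out2 = SiameseNetwork()(dummy, dummy)
print(out1.shape)                       # torch.Size([1, 5])
print(torch.equal(out1, out2))          # True: the two branches share weights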

In [0]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive
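
In the notation of Hadsell, Chopra and LeCun (2006), with $Y = 0$ for a same-class pair, $Y = 1$ for a different-class pair, $D$ the Euclidean distance between the two embeddings, and margin $m$ (here 2.0), the loss computed above is

$$L = (1 - Y)\,D^2 + Y\,\max(0,\ m - D)^2,$$

averaged over the batch: same-class pairs are pulled together, while different-class pairs are pushed apart until their distance clears the margin.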

In [0]:
train_dataloader = DataLoader(siamese_dataset,
                              shuffle=True,
                              num_workers=8,
                              batch_size=Config.train_batch_size)

In [0]:
net = SiameseNetwork().cuda()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0005)

In [0]:
counter = []
loss_history = []
iteration_number = 0

In [148]:
for epoch in range(0, Config.train_number_epochs):
    for i, data in enumerate(train_dataloader, 0):
        img0, img1, label = data
        img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()
        optimizer.zero_grad()
        output1, output2 = net(img0, img1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()
        optimizer.step()
        # with this small dataset and batch_size=64 there is a single batch
        # per epoch, so this logs once per epoch
        if i % 10 == 0:
            print("Epoch number {}\n Current loss {}\n".format(epoch, loss_contrastive.item()))
            iteration_number += 10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.item())
show_plot(counter, loss_history)


Epoch number 0
 Current loss 2.466195583343506

Epoch number 1
 Current loss 51.4345703125

Epoch number 2
 Current loss 4.164605617523193

Epoch number 3
 Current loss 7.120481491088867

Epoch number 4
 Current loss 2.965501546859741

Epoch number 5
 Current loss 2.3786025047302246

Epoch number 6
 Current loss 2.877575397491455

Epoch number 7
 Current loss 3.062469005584717

Epoch number 8
 Current loss 1.967812418937683

Epoch number 9
 Current loss 2.2690696716308594

Epoch number 10
 Current loss 1.5769760608673096

Epoch number 11
 Current loss 2.0668704509735107

Epoch number 12
 Current loss 1.834489107131958

Epoch number 13
 Current loss 1.5606380701065063

Epoch number 14
 Current loss 2.2070980072021484

Epoch number 15
 Current loss 1.5467792749404907

Epoch number 16
 Current loss 1.4089213609695435

Epoch number 17
 Current loss 1.7084054946899414

Epoch number 18
 Current loss 1.5892726182937622

Epoch number 19
 Current loss 1.2021260261535645

Epoch number 20
 Current loss 1.9634219408035278

Epoch number 21
 Current loss 1.8621584177017212

Epoch number 22
 Current loss 1.735438585281372

Epoch number 23
 Current loss 1.9115698337554932

Epoch number 24
 Current loss 1.484889268875122

Epoch number 25
 Current loss 1.967524528503418

Epoch number 26
 Current loss 1.284722089767456

Epoch number 27
 Current loss 1.4302233457565308

Epoch number 28
 Current loss 1.9330520629882812

Epoch number 29
 Current loss 1.3384604454040527

Epoch number 30
 Current loss 1.1989768743515015

Epoch number 31
 Current loss 1.2199060916900635

Epoch number 32
 Current loss 2.6282706260681152

Epoch number 33
 Current loss 1.5939826965332031

Epoch number 34
 Current loss 1.4789856672286987

Epoch number 35
 Current loss 1.5218342542648315

Epoch number 36
 Current loss 1.3080195188522339

Epoch number 37
 Current loss 2.977379560470581

Epoch number 38
 Current loss 1.1785558462142944

Epoch number 39
 Current loss 1.3754465579986572

Epoch number 40
 Current loss 1.3098615407943726

Epoch number 41
 Current loss 1.5119662284851074

Epoch number 42
 Current loss 1.42640221118927

Epoch number 43
 Current loss 1.373023271560669

Epoch number 44
 Current loss 1.3345496654510498

Epoch number 45
 Current loss 1.4021955728530884

Epoch number 46
 Current loss 1.1908115148544312

Epoch number 47
 Current loss 1.35209059715271

Epoch number 48
 Current loss 1.2044925689697266

Epoch number 49
 Current loss 1.1204413175582886

Epoch number 50
 Current loss 1.226791262626648

Epoch number 51
 Current loss 1.307025671005249

Epoch number 52
 Current loss 1.2170758247375488

Epoch number 53
 Current loss 1.124749779701233

Epoch number 54
 Current loss 1.1314074993133545

Epoch number 55
 Current loss 1.1590628623962402

Epoch number 56
 Current loss 1.197157859802246

Epoch number 57
 Current loss 1.3361177444458008

Epoch number 58
 Current loss 1.1540579795837402

Epoch number 59
 Current loss 1.1770998239517212

Epoch number 60
 Current loss 1.5406588315963745

Epoch number 61
 Current loss 1.0822186470031738

Epoch number 62
 Current loss 1.1273560523986816

Epoch number 63
 Current loss 1.1571528911590576

Epoch number 64
 Current loss 1.3783022165298462

Epoch number 65
 Current loss 1.132799744606018

Epoch number 66
 Current loss 1.300498127937317

Epoch number 67
 Current loss 1.1353871822357178

Epoch number 68
 Current loss 1.1412829160690308

Epoch number 69
 Current loss 1.4469434022903442

Epoch number 70
 Current loss 1.2158079147338867

Epoch number 71
 Current loss 1.422668695449829

Epoch number 72
 Current loss 1.586799144744873

Epoch number 73
 Current loss 1.2416285276412964

Epoch number 74
 Current loss 1.201046109199524

Epoch number 75
 Current loss 1.2580440044403076

Epoch number 76
 Current loss 1.3888639211654663

Epoch number 77
 Current loss 1.30446457862854

Epoch number 78
 Current loss 1.2717534303665161

Epoch number 79
 Current loss 1.1334199905395508

Epoch number 80
 Current loss 1.1643669605255127

Epoch number 81
 Current loss 1.2850315570831299

Epoch number 82
 Current loss 1.056455373764038

Epoch number 83
 Current loss 1.1474298238754272

Epoch number 84
 Current loss 1.2469476461410522

Epoch number 85
 Current loss 1.2564928531646729

Epoch number 86
 Current loss 1.3109350204467773

Epoch number 87
 Current loss 1.1160469055175781

Epoch number 88
 Current loss 1.146398663520813

Epoch number 89
 Current loss 1.1059321165084839

Epoch number 90
 Current loss 1.1521704196929932

Epoch number 91
 Current loss 1.1055035591125488

Epoch number 92
 Current loss 1.1548124551773071

Epoch number 93
 Current loss 1.1196362972259521

Epoch number 94
 Current loss 1.0926971435546875

Epoch number 95
 Current loss 1.1050325632095337

Epoch number 96
 Current loss 1.1961538791656494

Epoch number 97
 Current loss 1.1686053276062012

Epoch number 98
 Current loss 1.105148196220398

Epoch number 99
 Current loss 1.1375362873077393
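
Nothing in the notebook persists the trained weights, so on Colab they disappear when the runtime resets. A hedged sketch of saving and reloading them (the filename siamese_net.pth is an assumption, not from the original):

In [0]:
# save the trained weights (filename is illustrative)
torch.save(net.state_dict(), 'siamese_net.pth')

# later, in a fresh runtime, rebuild the architecture and reload:
# net = SiameseNetwork().cuda()
# net.load_state_dict(torch.load('siamese_net.pth'))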


In [149]:
folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                        transform=transforms.Compose([transforms.Resize((100, 100)),
                                                                      transforms.ToTensor()]),
                                        should_invert=False)

test_dataloader = DataLoader(siamese_dataset, num_workers=6, batch_size=1, shuffle=True)
dataiter = iter(test_dataloader)
x0, _, _ = next(dataiter)

# hold x0 fixed and compare it against ten other test images; a smaller
# dissimilarity score means the network considers the pair more alike
for i in range(10):
    _, x1, label2 = next(dataiter)
    concatenated = torch.cat((x0, x1), 0)

    output1, output2 = net(Variable(x0).cuda(), Variable(x1).cuda())
    euclidean_distance = F.pairwise_distance(output1, output2)
    imshow(torchvision.utils.make_grid(concatenated), 'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))
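
Two caveats for the evaluation loop above: the network is still in training mode, so BatchNorm uses per-batch statistics (noisy at batch_size=1), and gradients are tracked needlessly. A minimal sketch of one comparison with those fixed (Variable is no longer needed in modern PyTorch):

In [0]:
net.eval()                       # switch BatchNorm to its running statistics
with torch.no_grad():            # skip gradient tracking during inference
    output1, output2 = net(x0.cuda(), x1.cuda())
    distance = F.pairwise_distance(output1, output2)
print('Dissimilarity: {:.2f}'.format(distance.item()))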


