Background

The attention mechanism first gained traction in NLP: an extra, learned weighting step is added during training that raises the weights of features relevant to the target and lowers the weights of irrelevant ones. In 2018, CBAM brought this idea to convolutional networks with a module made of two parts: a one-dimensional channel attention component and a two-dimensional spatial attention component.
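To make those two components concrete, here is a minimal mxnet.gluon sketch of CBAM-style channel and spatial attention. The class names, the reduction ratio of 16 and the 7x7 spatial kernel are illustrative assumptions and do not come from the notebook code below, which only implements the channel half of the idea.

from mxnet import nd
from mxnet.gluon import nn

class ChannelAttention(nn.Block):
    # 1-D channel attention: pool away H and W, keep one weight per channel
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.mlp = nn.Sequential()
        self.mlp.add(nn.Dense(channels // reduction, activation="relu"),
                     nn.Dense(channels))

    def forward(self, x):
        avg = self.mlp(nd.mean(x, axis=(2, 3)))   # (N, C)
        mxp = self.mlp(nd.max(x, axis=(2, 3)))    # (N, C)
        w = nd.sigmoid(avg + mxp).reshape((0, 0, 1, 1))
        return x * w                              # rescale every channel

class SpatialAttention(nn.Block):
    # 2-D spatial attention: pool away the channel axis, keep one weight per location
    def __init__(self, kernel=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2D(1, kernel, padding=kernel // 2)

    def forward(self, x):
        avg = nd.mean(x, axis=1, keepdims=True)   # (N, 1, H, W)
        mxp = nd.max(x, axis=1, keepdims=True)    # (N, 1, H, W)
        w = nd.sigmoid(self.conv(nd.concat(avg, mxp, dim=1)))
        return x * w                              # rescale every spatial location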


In [1]:
from mxnet.gluon import data as gdata, nn
from mxnet import nd, gluon, autograd, init
import os, sys
import mxnet as mx



In [2]:
batch_size = 32
ctx = mx.gpu()

resize=(32,32)

if not os.path.exists('output'):
    os.makedirs('output')

save_prefix = "output/attention"

Use FashionMNIST as the dataset for this experiment


In [37]:
import numpy as np
def load_mnist(batch_size, resize = None):
    mnist_train = gdata.vision.FashionMNIST(train=True)
    mnist_test = gdata.vision.FashionMNIST(train=False)

    transformer_train, transformer_test = [],[]
    if resize:
        transformer_train.append(gdata.vision.transforms.Resize(resize))
        transformer_test.append(gdata.vision.transforms.Resize(resize))
    transformer_train.append(gdata.vision.transforms.RandomFlipLeftRight()) # without the random flip: train acc 92%, test acc 91%
    transformer_train += [gdata.vision.transforms.ToTensor()]
    transformer_train = gdata.vision.transforms.Compose(transformer_train)
    
    transformer_test += [gdata.vision.transforms.ToTensor()]
    transformer_test = gdata.vision.transforms.Compose(transformer_test)
    num_worker = 0 if sys.platform.startswith("win32") else 2
    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer_train),batch_size,shuffle=True,
                                  last_batch="rollover",num_workers=num_worker)
    test_iter = gdata.DataLoader(mnist_test.transform_first(transformer_test),batch_size,shuffle=False,
                                 last_batch="rollover",num_workers=num_worker)
    return train_iter,test_iter,len(mnist_train)

def get_class_names():
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return text_labels

train_iter,valid_iter,num_trainset = load_mnist(batch_size,resize)
num_classes = len(get_class_names())

print('#. trainset ',num_trainset)


#. trainset  60000

Define the network

Q: If conv_bn_relu() contains only a single Conv-BN block, training very easily runs into NaN.
Q: Two FC layers without a non-linear activation between them also tend to produce NaN during training.
Even after adding a relu activation to one of them, initialization with Normal(0.5) still tends to produce very large/small values and hence NaN; replacing Normal(0.5) with Xavier avoids NaN at least through the first epoch. However, if the number of convolution layers in conv_bn_relu() is then reduced to one, NaN still appears easily. Swapping the order of the BN and Activation layers was also tried, with no improvement.

  • Conclusion: in a network this shallow, BN may not actually act as a regularizer?
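Besides switching to Xavier initialization, one way to contain and diagnose these instabilities is to bound the gradients in the trainer and to stop as soon as the loss turns NaN. A minimal sketch under those assumptions: BASENET refers to the network defined in the next cell, the clip_gradient value of 10 matches the commented-out trainer near the end of this notebook but is otherwise arbitrary, and debug_net/debug_trainer/loss_is_nan are hypothetical names that are not used in the actual training loop below.

import math
from mxnet import gluon, init

# illustrative only: SGD with the same weight decay as the training cell further down,
# plus element-wise gradient clipping to keep one bad batch from blowing up the weights
debug_net = BASENET(num_classes)
debug_net.initialize(init=init.Xavier(), ctx=ctx)
debug_trainer = gluon.Trainer(debug_net.collect_params(), 'sgd',
                              {'learning_rate': 0.2, 'wd': 5e-4, 'clip_gradient': 10})

def loss_is_nan(loss):
    # check after loss.backward(): abort the epoch early once NaN shows up
    return math.isnan(loss.mean().asscalar())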

In [106]:
def conv_bn_relu(channel,kernel,stride,padding):
    net = nn.Sequential()
    net.add(
        nn.Conv2D(channel, kernel, strides=stride, padding=padding),       
        #nn.BatchNorm(),
        nn.Activation("relu"),
        
        nn.Conv2D(channel, kernel, strides=stride, padding=padding),
        #nn.BatchNorm(),
        nn.Activation("relu"),
    )
    return net

class BASENET(nn.Block):
    def __init__(self, num_classes):
        super(BASENET,self).__init__()
        self.stages = nn.Sequential()
        self.stages.add(
            conv_bn_relu(32,5,2,2),
            nn.MaxPool2D(2),
            conv_bn_relu(64,3,1,1),
            nn.MaxPool2D(2),
            conv_bn_relu(128,3,1,1),
            nn.MaxPool2D(2),
            nn.GlobalMaxPool2D(),
            #nn.Dense(256), # does BN in such a shallow network actually cause exploding gradients?
            #nn.Dense(256),
            #nn.Dropout(0.5),
            nn.Dense(128),
            nn.Dense(num_classes)
        )
        return
    def forward(self, X):
        Y = self.stages(X)

        return Y
    
    
class ATTNET(nn.Block): # channel attention module
    def __init__(self, num_classes):
        super(ATTNET,self).__init__()
        self.stageA, self.stageB, self.stageC = nn.Sequential(),nn.Sequential(),nn.Sequential()
        self.stageD, self.stageE = nn.Sequential(),nn.Sequential()
        self.stageA.add(
            conv_bn_relu(32,5,2,2),
            nn.MaxPool2D(2),
            conv_bn_relu(64,3,1,1),
            nn.MaxPool2D(2)
        )
        self.stageB.add(
            nn.GlobalMaxPool2D(),
            nn.Dense(32, activation="relu"),
            nn.Dense(64, activation="sigmoid") #mxnet 支持sigmoid?
        )
        self.stageC.add(
            nn.GlobalAvgPool2D(),
            nn.Dense(32, activation="relu"),
            nn.Dense(64, activation="sigmoid") #mxnet 支持sigmoid?
        )
        self.stageD.add(
            nn.Dense(32, activation="relu"),
            nn.Dense(64, activation="sigmoid") #mxnet 支持sigmoid?
        )
        self.stageE.add(
            conv_bn_relu(128,3,1,1),
            nn.MaxPool2D(2),
            nn.GlobalMaxPool2D(),
            nn.Dense(128),
            #nn.Dropout(0.5),
            #nn.Dense(128),
            nn.Dense(num_classes)
        )
        return 
    
    def forward(self,X):
        Ya = self.stageA(X)
        Yb = self.stageB(Ya) 
        Yc = self.stageC(Ya)
        Ybc = self.stageD(Yb+Yc)
        Ybc = nd.reshape(Ybc,(0,0,1,1)) #apply attention
        Y = self.stageE(Ya*Ybc)
        return Y
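A quick sanity check on the broadcasting in forward(): stageD outputs Ybc with shape (N, 64), reshape((0, 0, 1, 1)) turns it into (N, 64, 1, 1), and multiplying by Ya of shape (N, 64, H, W) scales every channel of Ya by one scalar. The snippet below (CPU, random data, purely illustrative and not part of the training run) just verifies the end-to-end shapes:

check_net = ATTNET(num_classes)
check_net.initialize(init=init.Xavier(), ctx=mx.cpu())
dummy = nd.random.uniform(shape=(4, 1, 32, 32))  # a FashionMNIST-sized batch
print(check_net(dummy).shape)                    # expected: (4, 10)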

Training


In [107]:
import time



def test_net(net, valid_iter, ctx):
    cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    cls_acc = mx.metric.Accuracy(name="test acc")
    loss_sum = []
    for batch in valid_iter:
        X,Y = batch
        out = X.as_in_context(ctx)
        out = net(out)
        out = out.as_in_context(mx.cpu())
        cls_acc.update(Y,out)
        loss = cls_loss(out, Y)
        loss_sum.append( loss.mean().asscalar() )
    print("\ttest loss {} {}".format( np.mean(loss_sum),cls_acc.get()))
    return cls_acc.get_name_value()[0][1], np.mean(loss_sum)


def train_net(net, train_iter, valid_iter, batch_size, trainer, ctx, num_epochs, lr_sch, save_prefix, train_log):
    cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    cls_acc = mx.metric.Accuracy(name="train acc")
    top_acc = 0
    iter_num = 0
    for epoch in range(num_epochs):
        train_loss, train_acc = [], 0
        t0 = time.time()
        trainer.set_learning_rate(lr_sch(epoch))
        for batch in train_iter:
            iter_num += 1
            X,Y = batch
            out = X.as_in_context(ctx)
            with autograd.record(True):
                out = net(out)
                out = out.as_in_context(mx.cpu())
                loss = cls_loss(out, Y)
            loss.backward()
            train_loss.append( loss.mean().asscalar() )
            #print(loss.mean().asscalar())
            trainer.step(batch_size)
            cls_acc.update(Y,out)
            nd.waitall()

        print("epoch {} lr {} {}sec".format(epoch,trainer.learning_rate, time.time() - t0))
        train_loss,train_acc = np.mean(train_loss), cls_acc.get_name_value()[0][1]
        print("\ttrain loss {} {}".format( np.mean(train_loss), cls_acc.get()))
        acc,test_loss = test_net(net, valid_iter, ctx)
        test_acc = acc
        train_log.append((train_loss, train_acc, test_loss, test_acc))
        
        if top_acc < acc:
            top_acc = acc
            print('\ttop valid acc {}'.format(acc))
            if isinstance(net, mx.gluon.nn.HybridSequential) or isinstance(net, mx.gluon.nn.HybridBlock):
                pf = '{}_{:.3f}.params'.format(save_prefix,top_acc)
                net.export(pf,epoch)
            else:
                net_path = '{}top_acc_{}_{:.3f}.params'.format(save_prefix,epoch,top_acc)
                net.save_parameters(net_path)
       

from mxnet import lr_scheduler
num_epochs = 30
logs = {0:[], 1:[]}
for ind,net in enumerate([ ATTNET(num_classes),BASENET(num_classes)]):
    net.initialize(init=init.Xavier(), ctx=ctx)
    #net.initialize(init=init.Normal(0.5),ctx=ctx)
    #trainer = gluon.Trainer(net.collect_params(), 'sgd',{'wd': 5e-4,"momentum":0.9,"clip_gradient":10})
    #trainer = gluon.Trainer(net.collect_params(), 'sgd',{'wd': 5e-4,"momentum":0.9})
    trainer = gluon.Trainer(net.collect_params(), 'sgd',{'wd': 5e-4})
    lr_sch = lr_scheduler.FactorScheduler(step=10, factor=0.1)
    lr_sch.base_lr = 0.2
    train_net(net, train_iter, valid_iter, batch_size, trainer, ctx, num_epochs, lr_sch, save_prefix,logs[ind])


epoch 0 lr 0.2 42.24735713005066sec
	train loss 0.7450331449508667 ('train acc', 0.7178666666666667)
	test loss 0.4256230890750885 ('test acc', 0.8471554487179487)
	top valid acc 0.8471554487179487
epoch 1 lr 0.2 41.316762924194336sec
	train loss 0.3948751389980316 ('train acc', 0.7859833333333334)
	test loss 0.3850254714488983 ('test acc', 0.8533346645367412)
	top valid acc 0.8533346645367412
epoch 2 lr 0.2 41.236796140670776sec
	train loss 0.3428177535533905 ('train acc', 0.8147055555555556)
	test loss 0.3435649275779724 ('test acc', 0.8734975961538461)
	top valid acc 0.8734975961538461
epoch 3 lr 0.2 40.847827196121216sec
	train loss 0.3159172236919403 ('train acc', 0.8319708333333333)
	test loss 0.32022878527641296 ('test acc', 0.8790934504792333)
	top valid acc 0.8790934504792333
epoch 4 lr 0.2 40.86844205856323sec
	train loss 0.29972660541534424 ('train acc', 0.8431166666666666)
	test loss 0.30280259251594543 ('test acc', 0.8848157051282052)
	top valid acc 0.8848157051282052
epoch 5 lr 0.2 40.601545572280884sec
	train loss 0.28619149327278137 ('train acc', 0.8513638888888889)
	test loss 0.31351205706596375 ('test acc', 0.8813897763578274)
epoch 6 lr 0.2 40.55769371986389sec
	train loss 0.27708205580711365 ('train acc', 0.8578380952380953)
	test loss 0.2833424508571625 ('test acc', 0.8936298076923077)
	top valid acc 0.8936298076923077
epoch 7 lr 0.2 40.38131356239319sec
	train loss 0.2682448923587799 ('train acc', 0.8632395833333333)
	test loss 0.29032373428344727 ('test acc', 0.8946685303514377)
	top valid acc 0.8946685303514377
epoch 8 lr 0.2 40.41727304458618sec
	train loss 0.2604883313179016 ('train acc', 0.8677425925925926)
	test loss 0.27074506878852844 ('test acc', 0.8987379807692307)
	top valid acc 0.8987379807692307
epoch 9 lr 0.2 40.338781118392944sec
	train loss 0.25503936409950256 ('train acc', 0.8714816666666667)
	test loss 0.30204832553863525 ('test acc', 0.8890774760383386)
epoch 10 lr 0.2 40.71027731895447sec
	train loss 0.25048500299453735 ('train acc', 0.8747121212121212)
	test loss 0.2697789669036865 ('test acc', 0.9018429487179487)
	top valid acc 0.9018429487179487
epoch 11 lr 0.020000000000000004 40.33118391036987sec
	train loss 0.1813114583492279 ('train acc', 0.8795027777777777)
	test loss 0.23757685720920563 ('test acc', 0.9129392971246006)
	top valid acc 0.9129392971246006
epoch 12 lr 0.020000000000000004 40.585888624191284sec
	train loss 0.16449116170406342 ('train acc', 0.8840769230769231)
	test loss 0.23618438839912415 ('test acc', 0.9151642628205128)
	top valid acc 0.9151642628205128
epoch 13 lr 0.020000000000000004 39.307457447052sec
	train loss 0.1560095250606537 ('train acc', 0.8882476190476191)
	test loss 0.2308773696422577 ('test acc', 0.9170327476038339)
	top valid acc 0.9170327476038339
epoch 14 lr 0.020000000000000004 39.110111474990845sec
	train loss 0.1486705094575882 ('train acc', 0.8920455555555555)
	test loss 0.23782801628112793 ('test acc', 0.9138621794871795)
epoch 15 lr 0.020000000000000004 39.543022871017456sec
	train loss 0.14166676998138428 ('train acc', 0.8955052083333334)
	test loss 0.23304162919521332 ('test acc', 0.919129392971246)
	top valid acc 0.919129392971246
epoch 16 lr 0.020000000000000004 38.876415729522705sec
	train loss 0.13537488877773285 ('train acc', 0.8987225490196078)
	test loss 0.2340775579214096 ('test acc', 0.9165665064102564)
epoch 17 lr 0.020000000000000004 39.20858907699585sec
	train loss 0.12944984436035156 ('train acc', 0.9017194444444444)
	test loss 0.24011649191379547 ('test acc', 0.9157348242811502)
epoch 18 lr 0.020000000000000004 39.35983228683472sec
	train loss 0.1234351247549057 ('train acc', 0.904530701754386)
	test loss 0.24655161798000336 ('test acc', 0.9131610576923077)
epoch 19 lr 0.020000000000000004 39.483999729156494sec
	train loss 0.11817505210638046 ('train acc', 0.9071675)
	test loss 0.2411317676305771 ('test acc', 0.915535143769968)
epoch 20 lr 0.020000000000000004 39.40156269073486sec
	train loss 0.11250897496938705 ('train acc', 0.9096611111111111)
	test loss 0.24351535737514496 ('test acc', 0.9165665064102564)
epoch 21 lr 0.0020000000000000005 39.33268332481384sec
	train loss 0.08985942602157593 ('train acc', 0.9123931818181819)
	test loss 0.2483416199684143 ('test acc', 0.917132587859425)
epoch 22 lr 0.0020000000000000005 39.063674211502075sec
	train loss 0.08528818935155869 ('train acc', 0.9149869565217391)
	test loss 0.25194066762924194 ('test acc', 0.9162660256410257)
epoch 23 lr 0.0020000000000000005 38.89340257644653sec
	train loss 0.08315233886241913 ('train acc', 0.9174076388888889)
	test loss 0.25579795241355896 ('test acc', 0.9157348242811502)
epoch 24 lr 0.0020000000000000005 39.37335443496704sec
	train loss 0.08137514442205429 ('train acc', 0.9196553333333334)
	test loss 0.25764039158821106 ('test acc', 0.9165665064102564)
epoch 25 lr 0.0020000000000000005 38.976540088653564sec
	train loss 0.07989657670259476 ('train acc', 0.92175)
	test loss 0.2572595179080963 ('test acc', 0.9163338658146964)
epoch 26 lr 0.0020000000000000005 38.98850154876709sec
	train loss 0.07841596007347107 ('train acc', 0.9237234567901235)
	test loss 0.26184287667274475 ('test acc', 0.914863782051282)
epoch 27 lr 0.0020000000000000005 40.31989574432373sec
	train loss 0.07694493234157562 ('train acc', 0.9255690476190476)
	test loss 0.26068857312202454 ('test acc', 0.9161341853035144)
epoch 28 lr 0.0020000000000000005 39.994038343429565sec
	train loss 0.07567667961120605 ('train acc', 0.9273114942528735)
	test loss 0.2658448815345764 ('test acc', 0.9157652243589743)
epoch 29 lr 0.0020000000000000005 38.89196562767029sec
	train loss 0.07427866011857986 ('train acc', 0.9289544444444444)
	test loss 0.26632413268089294 ('test acc', 0.9157348242811502)
epoch 0 lr 0.2 33.26179480552673sec
	train loss 0.7289904356002808 ('train acc', 0.7246)
	test loss 0.39758220314979553 ('test acc', 0.8500600961538461)
	top valid acc 0.8500600961538461
epoch 1 lr 0.2 33.34288430213928sec
	train loss 0.386726051568985 ('train acc', 0.7894416666666667)
	test loss 0.3234398365020752 ('test acc', 0.8766972843450479)
	top valid acc 0.8766972843450479
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-107-d40b26a494fc> in <module>()
     72     lr_sch = lr_scheduler.FactorScheduler(step=10, factor=0.1)
     73     lr_sch.base_lr = 0.2
---> 74     train_net(net, train_iter, valid_iter, batch_size, trainer, ctx, num_epochs, lr_sch, save_prefix,logs[ind])

<ipython-input-107-d40b26a494fc> in train_net(net, train_iter, valid_iter, batch_size, trainer, ctx, num_epochs, lr_sch, save_prefix, train_log)
     41             trainer.step(batch_size)
     42             cls_acc.update(Y,out)
---> 43             nd.waitall()
     44 
     45         print("epoch {} lr {} {}sec".format(epoch,trainer.learning_rate, time.time() - t0))

~\Anaconda3\lib\site-packages\mxnet\ndarray\ndarray.py in waitall()
    159     This function is used for benchmarking only.
    160     """
--> 161     check_call(_LIB.MXNDArrayWaitAll())
    162 
    163 

KeyboardInterrupt: 

In [99]:
%matplotlib inline
import matplotlib.pyplot as plt


base_loss = [L[0] for L in logs[1]]
base_acc = [L[1] for L in logs[1]]
att_loss = [L[0] for L in logs[0]]
att_acc = [L[1] for L in logs[0]]

# plot each curve against the number of epochs it actually recorded,
# so an interrupted run (like the BASENET run above) can still be plotted
plt.plot(range(len(base_loss)), base_loss, color='r', label='base_loss')
plt.plot(range(len(base_acc)), base_acc, color='c', label='base_acc')
plt.plot(range(len(att_loss)), att_loss, color='b', label='att_loss')
plt.plot(range(len(att_acc)), att_acc, color='g', label='att_acc')
plt.xlabel('epoch')
plt.legend()
plt.show()



In [ ]: