The attention mechanism originated in NLP: during training an extra learned weighting step is added that raises the weights of parameters relevant to the target and lowers the weights of irrelevant ones. In 2018, CBAM (Convolutional Block Attention Module) brought attention into convolutional networks, combining a one-dimensional channel-attention module with a two-dimensional spatial-attention module.
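The cells below implement only the channel-attention branch. For reference, here is a minimal sketch of what the 2D spatial-attention branch could look like in the same Gluon style; the class name SpatialAttention and the 7x7 kernel are illustrative assumptions (the kernel size follows the CBAM paper) and are not part of the notebook's model.

from mxnet import nd
from mxnet.gluon import nn

class SpatialAttention(nn.Block):
    # Sketch: pool over the channel axis with max and mean, then learn an HxW mask.
    def __init__(self, kernel=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2D(1, kernel, padding=kernel // 2)
    def forward(self, X):
        max_map = nd.max(X, axis=1, keepdims=True)   # (N, 1, H, W)
        avg_map = nd.mean(X, axis=1, keepdims=True)  # (N, 1, H, W)
        att = self.conv(nd.concat(max_map, avg_map, dim=1)).sigmoid()
        return X * att                               # broadcast over the channel axis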
In [1]:
from mxnet.gluon import data as gdata, nn
from mxnet import contrib, image, nd, gluon, autograd, init
import os, sys
import mxnet as mx
In [2]:
batch_size = 32
ctx = mx.gpu()
resize=(32,32)
if not os.path.exists('output'):
    os.makedirs('output')
save_prefix = "output/attention"
In [37]:
import numpy as np
def load_mnist(batch_size, resize=None):
    mnist_train = gdata.vision.FashionMNIST(train=True)
    mnist_test = gdata.vision.FashionMNIST(train=False)
    transformer_train, transformer_test = [], []
    if resize:
        transformer_train.append(gdata.vision.transforms.Resize(resize))
        transformer_test.append(gdata.vision.transforms.Resize(resize))
    transformer_train.append(gdata.vision.transforms.RandomFlipLeftRight())  # without random flip: train acc 92%, test acc 91%
    transformer_train += [gdata.vision.transforms.ToTensor()]
    transformer_train = gdata.vision.transforms.Compose(transformer_train)
    transformer_test += [gdata.vision.transforms.ToTensor()]
    transformer_test = gdata.vision.transforms.Compose(transformer_test)
    num_worker = 0 if sys.platform.startswith("win32") else 2
    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer_train), batch_size, shuffle=True,
                                  last_batch="rollover", num_workers=num_worker)
    test_iter = gdata.DataLoader(mnist_test.transform_first(transformer_test), batch_size, shuffle=False,
                                 last_batch="rollover", num_workers=num_worker)
    return train_iter, test_iter, len(mnist_train)
def get_class_names():
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return text_labels
train_iter,valid_iter,num_trainset = load_mnist(batch_size,resize)
num_classes = len(get_class_names())
print('#. trainset ',num_trainset)
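Before building the model it can help to confirm the tensor layout the loader produces; a minimal peek at one batch, assuming the 32x32 resize configured above:

# sketch: inspect one batch; ToTensor yields float32 NCHW scaled to [0, 1]
for X, Y in train_iter:
    print(X.shape, X.dtype, Y.shape)  # expected: (32, 1, 32, 32) float32 (32,)
    break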
In [106]:
def conv_bn_relu(channel, kernel, stride, padding):
    net = nn.Sequential()
    net.add(
        nn.Conv2D(channel, kernel, strides=stride, padding=padding),
        #nn.BatchNorm(),
        nn.Activation("relu"),
        nn.Conv2D(channel, kernel, strides=stride, padding=padding),
        #nn.BatchNorm(),
        nn.Activation("relu"),
    )
    return net
class BASENET(nn.Block):
    def __init__(self, num_classes):
        super(BASENET, self).__init__()
        self.stages = nn.Sequential()
        self.stages.add(
            conv_bn_relu(32, 5, 2, 2),
            nn.MaxPool2D(2),
            conv_bn_relu(64, 3, 1, 1),
            nn.MaxPool2D(2),
            conv_bn_relu(128, 3, 1, 1),
            nn.MaxPool2D(2),
            nn.GlobalMaxPool2D(),
            #nn.Dense(256),  # does BatchNorm in such a shallow net actually cause exploding gradients?
            #nn.Dense(256),
            #nn.Dropout(0.5),
            nn.Dense(128),
            nn.Dense(num_classes)
        )
        return
    def forward(self, X):
        Y = self.stages(X)
        return Y
class ATTNET(nn.Block):  # channel attention module
    def __init__(self, num_classes):
        super(ATTNET, self).__init__()
        self.stageA, self.stageB, self.stageC = nn.Sequential(), nn.Sequential(), nn.Sequential()
        self.stageD, self.stageE = nn.Sequential(), nn.Sequential()
        self.stageA.add(
            conv_bn_relu(32, 5, 2, 2),
            nn.MaxPool2D(2),
            conv_bn_relu(64, 3, 1, 1),
            nn.MaxPool2D(2)
        )
        self.stageB.add(
            nn.GlobalMaxPool2D(),
            nn.Dense(32, activation="relu"),
            nn.Dense(64, activation="sigmoid")  # does mxnet support sigmoid here? (Gluon Dense accepts activation="sigmoid")
        )
        self.stageC.add(
            nn.GlobalAvgPool2D(),
            nn.Dense(32, activation="relu"),
            nn.Dense(64, activation="sigmoid")
        )
        self.stageD.add(
            nn.Dense(32, activation="relu"),
            nn.Dense(64, activation="sigmoid")
        )
        self.stageE.add(
            conv_bn_relu(128, 3, 1, 1),
            nn.MaxPool2D(2),
            nn.GlobalMaxPool2D(),
            nn.Dense(128),
            #nn.Dropout(0.5),
            #nn.Dense(128),
            nn.Dense(num_classes)
        )
        return
    def forward(self, X):
        Ya = self.stageA(X)                  # shared feature maps, 64 channels
        Yb = self.stageB(Ya)                 # max-pooled channel descriptor
        Yc = self.stageC(Ya)                 # avg-pooled channel descriptor
        Ybc = self.stageD(Yb + Yc)           # fuse into per-channel weights
        Ybc = nd.reshape(Ybc, (0, 0, 1, 1))  # reshape to (N, 64, 1, 1) and apply attention
        Y = self.stageE(Ya * Ybc)
        return Y
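A quick way to check that the channel weights broadcast as intended is to push a dummy batch through a freshly initialized ATTNET and inspect the shapes; a minimal sketch, assuming the 32x32 input size used in this notebook (net_check and dummy are throwaway names for illustration):

# sketch: shape check for ATTNET on CPU with a dummy single-channel 32x32 batch
net_check = ATTNET(num_classes)
net_check.initialize(init=init.Xavier(), ctx=mx.cpu())
dummy = nd.random.uniform(shape=(2, 1, 32, 32))
print(net_check.stageA(dummy).shape)  # expected (2, 64, 2, 2): shared feature maps
print(net_check(dummy).shape)         # expected (2, 10): class scores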
In [107]:
import time
def test_net(net, valid_iter, ctx):
    cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    cls_acc = mx.metric.Accuracy(name="test acc")
    loss_sum = []
    for batch in valid_iter:
        X, Y = batch
        out = X.as_in_context(ctx)
        out = net(out)
        out = out.as_in_context(mx.cpu())
        cls_acc.update(Y, out)
        loss = cls_loss(out, Y)
        loss_sum.append(loss.mean().asscalar())
    print("\ttest loss {} {}".format(np.mean(loss_sum), cls_acc.get()))
    return cls_acc.get_name_value()[0][1], np.mean(loss_sum)
def train_net(net, train_iter, valid_iter, batch_size, trainer, ctx, num_epochs, lr_sch, save_prefix, train_log):
    cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    cls_acc = mx.metric.Accuracy(name="train acc")
    top_acc = 0
    iter_num = 0
    for epoch in range(num_epochs):
        train_loss, train_acc = [], 0
        cls_acc.reset()  # report per-epoch accuracy instead of a running total
        t0 = time.time()
        trainer.set_learning_rate(lr_sch(epoch))
        for batch in train_iter:
            iter_num += 1
            X, Y = batch
            out = X.as_in_context(ctx)
            with autograd.record(True):
                out = net(out)
                out = out.as_in_context(mx.cpu())
                loss = cls_loss(out, Y)
            loss.backward()
            train_loss.append(loss.mean().asscalar())
            #print(loss.mean().asscalar())
            trainer.step(batch_size)
            cls_acc.update(Y, out)
        nd.waitall()
        print("epoch {} lr {} {}sec".format(epoch, trainer.learning_rate, time.time() - t0))
        train_loss, train_acc = np.mean(train_loss), cls_acc.get_name_value()[0][1]
        print("\ttrain loss {} {}".format(np.mean(train_loss), cls_acc.get()))
        acc, test_loss = test_net(net, valid_iter, ctx)
        test_acc = acc
        train_log.append((train_loss, train_acc, test_loss, test_acc))
        if top_acc < acc:
            top_acc = acc
            print('\ttop valid acc {}'.format(acc))
            if isinstance(net, mx.gluon.nn.HybridSequential) or isinstance(net, mx.gluon.nn.HybridBlock):
                pf = '{}_{:.3f}.params'.format(save_prefix, top_acc)
                net.export(pf, epoch)
            else:
                net_path = '{}top_acc_{}_{:.3f}.params'.format(save_prefix, epoch, top_acc)
                net.save_parameters(net_path)
from mxnet import lr_scheduler
num_epochs = 30
logs = {0:[], 1:[]}
for ind, net in enumerate([ATTNET(num_classes), BASENET(num_classes)]):
    net.initialize(init=init.Xavier(), ctx=ctx)
    #net.initialize(init=init.Normal(0.5), ctx=ctx)
    #trainer = gluon.Trainer(net.collect_params(), 'sgd', {'wd': 5e-4, "momentum": 0.9, "clip_gradient": 10})
    #trainer = gluon.Trainer(net.collect_params(), 'sgd', {'wd': 5e-4, "momentum": 0.9})
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'wd': 5e-4})
    lr_sch = lr_scheduler.FactorScheduler(step=10, factor=0.1)
    lr_sch.base_lr = 0.2
    train_net(net, train_iter, valid_iter, batch_size, trainer, ctx, num_epochs, lr_sch, save_prefix, logs[ind])
In [99]:
%matplotlib inline
import matplotlib.pyplot as plt
base_loss = [L[0] for L in logs[1]]
base_acc = [L[1] for L in logs[1]]
att_loss = [L[0] for L in logs[0]]
att_acc = [L[1] for L in logs[0]]
epochs = [k for k in range(num_epochs)]
plt.plot(epochs, base_loss, color='r', label='base_loss')
plt.plot(epochs, base_acc, color='c', label='base_acc')
plt.plot(epochs, att_loss, color='b', label='att_loss')
plt.plot(epochs, att_acc, color='g', label='att_acc')
plt.xlabel('epoch')
plt.legend()
plt.show()
In [ ]: