CNN Character SA



In [1]:
import keras 
from  os.path import join
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Dropout,Activation, Lambda,Input
from keras.layers import Embedding
from keras.layers import Convolution1D
from keras.datasets import imdb
from keras import backend as K
from keras.layers import Convolution1D, GlobalMaxPooling1D,Convolution2D
from keras.utils import np_utils
from keras.models import Model


Using TensorFlow backend.

数据预处理


In [2]:
file_names = ['stsa.fine.test','stsa.fine.train','stsa.fine.dev']
file_path = '/home/bruce/data/sentiment/citai_process'
def read_file(fname=''):
    with open(join(file_path,fname)) as fr:
        lines = fr.readlines()
    lines = [line.strip() for line in lines]
    lables = [int(line[0:1]) for line in lines]
    characters = [list(line[2:]) for line in lines]
    return characters,lables       
train_X,train_y = read_file(fname='stsa.fine.train')
test_X,test_y = read_file(fname='stsa.fine.test')
dev_X,dev_y = read_file(fname='stsa.fine.dev')
print(len(train_X))
print(len(test_X))
print(len(dev_X))
print(train_X[0:2])
print(train_y[0:2])


8544
2210
1101
[['a', ' ', 's', 't', 'i', 'r', 'r', 'i', 'n', 'g', ' ', ',', ' ', 'f', 'u', 'n', 'n', 'y', ' ', 'a', 'n', 'd', ' ', 'f', 'i', 'n', 'a', 'l', 'l', 'y', ' ', 't', 'r', 'a', 'n', 's', 'p', 'o', 'r', 't', ' ', 'r', 'e', '-', 'i', 'm', 'a', 'g', 'i', 'n', 'i', 'n', 'g', ' ', 'o', 'f', ' ', 'b', 'e', 'a', 'u', 't', 'y', ' ', 'a', 'n', 'd', ' ', 't', 'h', 'e', ' ', 'b', 'e', 'a', 's', 't', ' ', 'a', 'n', 'd', ' ', '1', '9', '3', '0', 's', ' ', 'h', 'o', 'r', 'r', 'o', 'r', ' ', 'f', 'i', 'l', 'm'], ['a', 'p', 'p', 'a', 'r', 'e', 'n', 't', 'l', 'y', ' ', 'r', 'e', 'a', 's', 's', 'e', 'm', 'b', 'l', 'e', ' ', 'f', 'r', 'o', 'm', ' ', 't', 'h', 'e', ' ', 'c', 'u', 't', 't', 'i', 'n', 'g', '-', 'r', 'o', 'o', 'm', ' ', 'f', 'l', 'o', 'o', 'r', ' ', 'o', 'f', ' ', 'a', 'n', 'y', ' ', 'g', 'i', 'v', 'e', ' ', 'd', 'a', 'y', 't', 'i', 'm', 'e', ' ', 's', 'o', 'a', 'p', ' ', '.']]
[4, 1]

In [ ]:

句子长度统计信息


In [3]:
def statics_list2(arrays=[]):
    lengths = [len(i) for i in arrays]
    lengths = sorted(lengths)
    length = len(lengths)
    print('length = ',len(lengths))
    print('max = ',lengths[-1])
    print('min =',lengths[0])
    print('average = ',sum(lengths)/length)
    print('top 50% = ',lengths[int(0.5*length)])
    print('top 80% = ',lengths[int(0.8*length)])
    print('top 90% = ',lengths[int(0.9*length)])
    print('top 95% = ',lengths[int(0.95*length)])
    
statics_list2(arrays=train_X)


length =  8544
max =  279
min = 4
average =  100.29693352059925
top 50% =  96
top 80% =  144
top 90% =  170
top 95% =  190

character to index


In [4]:
def token_to_index(datas=[]):
    word_index={}
    count=1
    for data in datas:
        for list_ in data:
            for w in list_:
                if w not in word_index:
                    word_index[w] = count
                    count = count + 1
    print('leng of word_index =',len(word_index))
    for i in range(len(datas)):
        datas[i] = [[ word_index[w] for w in line ] for line in datas[i]] 
    return datas,word_index

In [5]:
X,word_index = token_to_index(datas=[train_X,dev_X])
train_X,dev_X = X
print(len(word_index))
print(word_index)


leng of word_index = 97
97
{'\\': 58, 'æ': 85, "'": 30, '&': 78, 'd': 13, '0': 25, 'a': 1, 's': 3, '-': 18, '.': 28, 'q': 36, 'M': 38, 'B': 46, '*': 77, 'O': 82, 'k': 31, 'h': 21, '7': 60, 'è': 66, 'ó': 65, 'ã': 86, 'e': 17, '`': 49, 'Z': 70, 'R': 50, '6': 63, '1': 22, 'D': 37, 'Q': 90, 'm': 19, 'n': 7, 'v': 27, 'F': 57, 'â': 81, 'N': 72, 'C': 44, 'ô': 94, 'P': 53, 'S': 33, '4': 76, '%': 96, 'á': 79, '3': 24, 'U': 88, 'y': 12, 'ñ': 92, 'j': 32, 'u': 11, 'r': 6, 'ç': 89, '+': 93, 'í': 64, 'à': 97, 'b': 20, 'J': 69, 'g': 8, 'w': 29, 'E': 45, 'p': 15, 'T': 74, 'é': 34, '!': 51, ' ': 2, 'Y': 62, 'L': 41, '@': 40, 'x': 35, 'c': 26, 'H': 68, ',': 9, 'X': 91, 'ï': 84, 'K': 61, '?': 52, 'f': 10, 't': 4, ';': 55, 'ö': 95, 'o': 16, 'ü': 80, '$': 39, 'i': 5, '2': 48, '9': 23, '#': 87, '/': 59, 'l': 14, 'G': 56, '=': 83, 'A': 47, 'W': 71, 'z': 42, ':': 43, '8': 73, 'I': 54, '5': 75, 'V': 67}

In [6]:
print(train_X[0])


[1, 2, 3, 4, 5, 6, 6, 5, 7, 8, 2, 9, 2, 10, 11, 7, 7, 12, 2, 1, 7, 13, 2, 10, 5, 7, 1, 14, 14, 12, 2, 4, 6, 1, 7, 3, 15, 16, 6, 4, 2, 6, 17, 18, 5, 19, 1, 8, 5, 7, 5, 7, 8, 2, 16, 10, 2, 20, 17, 1, 11, 4, 12, 2, 1, 7, 13, 2, 4, 21, 17, 2, 20, 17, 1, 3, 4, 2, 1, 7, 13, 2, 22, 23, 24, 25, 3, 2, 21, 16, 6, 6, 16, 6, 2, 10, 5, 14, 19]

构建模型


In [7]:
max_len = 190
batch_size=32

max_features = 100
embedding_dims = 150
nb_filter = 150

nb_filter = 150
filter_length = 3
dense1_hindden = 150
nb_classes = 5

In [8]:
print('Build model...')
model = Sequential()
model.add(Embedding(input_dim=max_features,
                    output_dim = embedding_dims
                   ))
model.add(Convolution1D(nb_filter = nb_filter,
                        filter_length = filter_length,
                        border_mode = 'valid',
                        activation='relu',
                        subsample_length = 1
                       ))
model.add(GlobalMaxPooling1D())
model.add(Dense(dense1_hindden))
model.add(Dropout(0.2))
model.add(Activation('relu'))

model.add(Dense(nb_classes))
model.add(Activation('softmax'))
model.compile(loss = 'categorical_crossentropy',
              optimizer = 'adadelta',
              metrics=['accuracy']
             )
print('finish build')


Build model...
finish build

模型输入


In [9]:
print(type(train_y[0]))
train_y = np_utils.to_categorical(train_y, nb_classes)
dev_y = np_utils.to_categorical(dev_y, nb_classes)
train_X = sequence.pad_sequences(train_X, maxlen=max_len)
dev_X = sequence.pad_sequences(dev_X, maxlen=max_len)


<class 'int'>

In [10]:
def my_generator(X=None,y=None):
    i = 0
    max_i = int(len(X)/batch_size)
    while True:
        i = i % max_i
        x_batch = X[i*batch_size:(i+1)*batch_size]
        y_batch = y[i*batch_size:(i+1)*batch_size]
        yield (x_batch,y_batch)
        i = i + 1

训练模型


In [11]:
model.fit_generator(my_generator(train_X,train_y),samples_per_epoch = 32*267,nb_epoch=500,verbose=1,validation_data=(dev_X,dev_y))


Epoch 1/500
8544/8544 [==============================] - 21s - loss: 1.5698 - acc: 0.2735 - val_loss: 1.5605 - val_acc: 0.3252
Epoch 2/500
8544/8544 [==============================] - 21s - loss: 1.5466 - acc: 0.3001 - val_loss: 1.5273 - val_acc: 0.3106
Epoch 3/500
8544/8544 [==============================] - 21s - loss: 1.5146 - acc: 0.3278 - val_loss: 1.4938 - val_acc: 0.3424
Epoch 4/500
8544/8544 [==============================] - 21s - loss: 1.4836 - acc: 0.3409 - val_loss: 1.4703 - val_acc: 0.3224
Epoch 5/500
8544/8544 [==============================] - 21s - loss: 1.4527 - acc: 0.3604 - val_loss: 1.4530 - val_acc: 0.3361
Epoch 6/500
8544/8544 [==============================] - 21s - loss: 1.4283 - acc: 0.3786 - val_loss: 1.4415 - val_acc: 0.3497
Epoch 7/500
8544/8544 [==============================] - 21s - loss: 1.4041 - acc: 0.3859 - val_loss: 1.4373 - val_acc: 0.3379
Epoch 8/500
8544/8544 [==============================] - 21s - loss: 1.3865 - acc: 0.3950 - val_loss: 1.4265 - val_acc: 0.3615
Epoch 9/500
8544/8544 [==============================] - 21s - loss: 1.3718 - acc: 0.4051 - val_loss: 1.4258 - val_acc: 0.3470
Epoch 10/500
8544/8544 [==============================] - 21s - loss: 1.3537 - acc: 0.4092 - val_loss: 1.4240 - val_acc: 0.3479
Epoch 11/500
8544/8544 [==============================] - 21s - loss: 1.3410 - acc: 0.4221 - val_loss: 1.4238 - val_acc: 0.3433
Epoch 12/500
8544/8544 [==============================] - 21s - loss: 1.3288 - acc: 0.4263 - val_loss: 1.4223 - val_acc: 0.3470
Epoch 13/500
8544/8544 [==============================] - 21s - loss: 1.3163 - acc: 0.4352 - val_loss: 1.4372 - val_acc: 0.3397
Epoch 14/500
8544/8544 [==============================] - 21s - loss: 1.3062 - acc: 0.4380 - val_loss: 1.4211 - val_acc: 0.3542
Epoch 15/500
8544/8544 [==============================] - 21s - loss: 1.2937 - acc: 0.4412 - val_loss: 1.4382 - val_acc: 0.3497
Epoch 16/500
8544/8544 [==============================] - 21s - loss: 1.2810 - acc: 0.4443 - val_loss: 1.4381 - val_acc: 0.3606
Epoch 17/500
8544/8544 [==============================] - 21s - loss: 1.2735 - acc: 0.4505 - val_loss: 1.4505 - val_acc: 0.3379
Epoch 18/500
8544/8544 [==============================] - 21s - loss: 1.2662 - acc: 0.4531 - val_loss: 1.4512 - val_acc: 0.3533
Epoch 19/500
8544/8544 [==============================] - 21s - loss: 1.2584 - acc: 0.4600 - val_loss: 1.4677 - val_acc: 0.3433
Epoch 20/500
8544/8544 [==============================] - 21s - loss: 1.2460 - acc: 0.4613 - val_loss: 1.4624 - val_acc: 0.3542
Epoch 21/500
8544/8544 [==============================] - 21s - loss: 1.2376 - acc: 0.4688 - val_loss: 1.4650 - val_acc: 0.3515
Epoch 22/500
8544/8544 [==============================] - 21s - loss: 1.2259 - acc: 0.4780 - val_loss: 1.4726 - val_acc: 0.3451
Epoch 23/500
8544/8544 [==============================] - 21s - loss: 1.2218 - acc: 0.4717 - val_loss: 1.4956 - val_acc: 0.3479
Epoch 24/500
8544/8544 [==============================] - 21s - loss: 1.2136 - acc: 0.4824 - val_loss: 1.5139 - val_acc: 0.3460
Epoch 25/500
8544/8544 [==============================] - 21s - loss: 1.2018 - acc: 0.4840 - val_loss: 1.4969 - val_acc: 0.3433
Epoch 26/500
8544/8544 [==============================] - 21s - loss: 1.1980 - acc: 0.4856 - val_loss: 1.4946 - val_acc: 0.3524
Epoch 27/500
8544/8544 [==============================] - 21s - loss: 1.1907 - acc: 0.4933 - val_loss: 1.5005 - val_acc: 0.3470
Epoch 28/500
8544/8544 [==============================] - 21s - loss: 1.1821 - acc: 0.4933 - val_loss: 1.5172 - val_acc: 0.3460
Epoch 29/500
8544/8544 [==============================] - 21s - loss: 1.1783 - acc: 0.4973 - val_loss: 1.5286 - val_acc: 0.3424
Epoch 30/500
6016/8544 [====================>.........] - ETA: 6s - loss: 1.1780 - acc: 0.4942
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-11-92d00cb0bbc6> in <module>()
----> 1 model.fit_generator(my_generator(train_X,train_y),samples_per_epoch = 32*267,nb_epoch=500,verbose=1,validation_data=(dev_X,dev_y))

/home/bruce/anaconda3/lib/python3.5/site-packages/keras/models.py in fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose, callbacks, validation_data, nb_val_samples, class_weight, max_q_size, nb_worker, pickle_safe, **kwargs)
    872                                         max_q_size=max_q_size,
    873                                         nb_worker=nb_worker,
--> 874                                         pickle_safe=pickle_safe)
    875 
    876     def evaluate_generator(self, generator, val_samples, max_q_size=10, nb_worker=1, pickle_safe=False, **kwargs):

/home/bruce/anaconda3/lib/python3.5/site-packages/keras/engine/training.py in fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose, callbacks, validation_data, nb_val_samples, class_weight, max_q_size, nb_worker, pickle_safe)
   1441                     outs = self.train_on_batch(x, y,
   1442                                                sample_weight=sample_weight,
-> 1443                                                class_weight=class_weight)
   1444                 except:
   1445                     _stop.set()

/home/bruce/anaconda3/lib/python3.5/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1219             ins = x + y + sample_weights
   1220         self._make_train_function()
-> 1221         outputs = self.train_function(ins)
   1222         if len(outputs) == 1:
   1223             return outputs[0]

/home/bruce/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   1011             feed_dict[tensor] = value
   1012         session = get_session()
-> 1013         updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict)
   1014         return updated[:len(self.outputs)]
   1015 

/home/bruce/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    708     try:
    709       result = self._run(None, fetches, feed_dict, options_ptr,
--> 710                          run_metadata_ptr)
    711       if run_metadata:
    712         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/bruce/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    906     if final_fetches or final_targets:
    907       results = self._do_run(handle, final_targets, final_fetches,
--> 908                              feed_dict_string, options, run_metadata)
    909     else:
    910       results = []

/home/bruce/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
    956     if handle is None:
    957       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
--> 958                            target_list, options, run_metadata)
    959     else:
    960       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/bruce/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
    963   def _do_call(self, fn, *args):
    964     try:
--> 965       return fn(*args)
    966     except errors.OpError as e:
    967       message = compat.as_text(e.message)

/home/bruce/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
    945         return tf_session.TF_Run(session, options,
    946                                  feed_dict, fetch_list, target_list,
--> 947                                  status, run_metadata)
    948 
    949     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

预测


In [ ]:
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])