In [1]:
import numpy as np
import sklearn
import matplotlib.pyplot as plt
import seaborn as sns;sns.set()
from sklearn.feature_extraction.text import TfidfVectorizer
import re
answers = []
questions = []
flag = 0
######################################
question = []
answer = []
file = open('xiaomi_raw.txt', 'r')
for line in file.readlines():
    if line.startswith('用户问题'):
        questions.append(line)
        if answer:
            answers.append(answer)
        answer = []
        ##############
        answer.append(line)
        #############
    elif len(line) < 6:
        continue
    else:
        answer.append(line)
answers.append(answer)
file.close()
##########################################
######################################
question = []
answer = []
file1 = open('xiaomi_raw_2.txt', 'r')
for line in file1.readlines():
    if line.startswith('用户问题'):
        questions.append(line)
        if answer:
            answers.append(answer)
        answer = []
        ##############
        answer.append(line)
        #############
    elif len(line) < 6:
        continue
    else:
        answer.append(line)
answers.append(answer)
file1.close()
##########################################
######################################
question = []
answer = []
file2 = open('xiaomi_raw_3.txt', 'r')
for line in file2.readlines():
    if line.startswith('用户问题'):
        questions.append(line)
        if answer:
            answers.append(answer)
        answer = []
        ##############
        answer.append(line)
        #############
    elif len(line) < 6:
        continue
    else:
        answer.append(line)
answers.append(answer)
file2.close()
##########################################
answers = [''.join(ans) for ans in answers]

In [2]:
def get_accuracy(vectorizer, questions_vec, answers_vec):
    """
    输入fit过的vectorizer,向量化的questions和answers,帮你返回正确率
    """
    y = np.linspace(0, answers_vec.shape[0]-1, answers_vec.shape[0], dtype=np.int32)
    y_predict = np.array([predict_answer(question_vec)[0] for question_vec in questions_vec])
    return sum(y == y_predict) / answers_vec.shape[0], y, y_predict
def get_accuracy_with_threshold(vectorizer, questions_vec, answers_vec, threshold):
    """
    输入一定的threshold,只有问题答案分值超过threshold才进行回答
    
    Input:
        threshold
    
    Return:
        correct_number
        total_answer_number
        accuracy
    """
    correct = 0
    total = 0
    for i, question_vec in enumerate(questions_vec):
        answer_idx, answer_scores = predict_answer(question_vec)
        max_answer_score = np.max(answer_scores)
        if max_answer_score < threshold:
#             print(i, max_answer_score, 'max_answer_score < threshold', correct, total)
            pass
        else:
            if answer_idx == i:
                correct += 1    
            total += 1
            print(i, np.max(answer_scores))
    return correct, total, correct/total
def predict_answer(question_vec):
    """
    Get answer choice from a single question_vec
    
    Example:
        predict_answer(vectorizer.transform(jieba.cut('你好什么是黑洞啊')))
    
    Return:
        最高分的回答,全部回答评分
    
    """
    answer_scores = np.array([question_vec.multiply(answer_vec).sum() for answer_vec in answers_vec])
    return np.argmax(answer_scores), answer_scores
def get_answer(question):
    """
    提问题得到回答
    
    Input:
        中文问题
    
    Return:
        中文回答,全部回答评分
        
    Example:
        
    """
    question = filter_sentence(question)
    question_vec = vectorizer.transform([question])
    answer_score, answer_scores = predict_answer(question_vec)
#     print(answer_score)
    answer = answers[answer_score].replace(' ', '')
    # print(1)
    return answer, answer_scores
def check_accuracy_top_k(questions_vec, answers_vec, k, threshold=None):
    """
    Input:
        questions_vec: 
        Matrix representation of questions, with each question a row.
        
        answers_vec:
        Matrix representation of questions, with each answer a row.

        k:
        Top k answers are considered correct.
    
    Return:
        Accuracy of such metric.
    """
    correct = 0
#     total = questions_vec.shape[0]
    total = 0
    for i, question_vec in enumerate(questions_vec):
        answer_scores = np.array([question_vec.multiply(answer_vec).sum() for answer_vec in answers_vec])
        if threshold:
            if max(answer_scores) < threshold:
                continue
        predict_ones = answer_scores.argsort()[-k:][::-1]
        if i in predict_ones:
            correct += 1
        else:
            print(questions[i])
            for x in predict_ones:
                print(answers[x], '<< wrong', answer_scores[x])
            print(answers[i], '<< right', answer_scores[i])
        total += 1
        
    return correct/total
def filter_sentence(sentence):
    clean = r'小米6'
    new_sentence = re.sub(clean, '小米手机', sentence)
    clean = r'小米电视4'
    new_sentence = re.sub(clean, '小米电视', new_sentence)
    return new_sentence

In [3]:
questions = [filter_sentence(question) for question in questions]
answers = [filter_sentence(answer) for answer in answers]

In [4]:
questions[0] = '小米手环的系统需求和硬件需求是什么 \
                今天下雨了吗 \
                小米电脑是什么配置啊 \
                小米电视多少钱好不好用啊 \
                小米扫地机器人怎么样 \
                小米手机不发热了怎么办是不是坏了 \
                小米电视弯曲了变成了曲面屏怎么办 \
                小米雷军 are you ok \
                小米手环会不会对人的身体有害啊 \
                小米手环会不会影响今天的天气 \
                小米电视会不会导致小孩沉迷看电视 \
                小米手机下一代产品价格多少 \
                小米手机会不会爆炸 \
                小米电视发热大吗 \
                小米手环的需求是什么 \
                小米手环的要求是什么 \
                小米手环有什么要求'
answers[0] = questions[0] + answers[0]

In [5]:
import jieba
# questions = [' '.join(jieba.cut(question)) for question in questions]
# answers = [' '.join(jieba.cut(answer)) for answer in answers]
chinese_stopwords = '按,按照,俺,俺们,阿,别,别人,别处,别是,别的,别管,别说,不,不仅,不但,不光,不单,不只,不外乎,不如,不妨,不尽,不尽然,不得,不怕,不惟,不成,不拘,不料,不是,不比,不然,不特,不独,不管,不至于,不若,不论,不过,不问,比方,比如,比及,比,本身,本着,本地,本人,本,巴巴,巴,并,并且,非彼,彼时,彼此,便于,把,边,鄙人,罢了,被,般的,此间,此次,此时,此外,此处,此地,此,才,才能,朝,朝着,从,从此,从而,除非,除此之外,除开,除外,除了,除,诚然,诚如,出来,出于,曾,趁着,趁,处在,乘,冲,等等,等到,等,第,当着,当然,当地,当,多,多么,多少,对,对于,对待,对方,对比,得,得了,打,打从,的,的确,的话,但,但凡,但是,大家,大,地,待,都,到,叮咚,而言,而是,而已,而外,而后,而况,而且,而,尔尔,尔后,尔,二来,非独,非特,非徒,非但,否则,反过来说,反过来,反而,反之,分别,凡是,凡,个,个别,固然,故,故此,故而,果然,果真,各,各个,各位,各种,各自,关于具体地说,归齐,归,根据,管,赶,跟,过,该,给,光是,或者,或曰,或是,或则,或,何,何以,何况,何处,何时,还要,还有,还是,还,后者,很,换言之,换句话说,好,后,和,即,即令,即使,即便,即如,即或,即若,继而,继后,继之,既然,既是,既往,既,尽管如此,尽管,尽,就要,就算,就是说,就是了,就是,就,据,据此,接着,经,经过,结果,及,及其,及至,加以,加之,例如,介于,几时,几,截至,极了,简言之,竟而,紧接着,距,较之,较,进而,鉴于,基于,具体说来,兼之,借傥然,今,叫,将,可,可以,可是,可见,开始,开外,况且,靠,看,来说,来自,来着,来,两者,临,类如,论,赖以,连,连同,离,莫若,莫如,莫不然,假使,假如,假若,某,某个,某些,某某,漫说,没奈何,每当,每,慢说,冒,哪个,哪些,哪儿,哪天,哪年,哪怕,哪样,哪边,哪里,那里,那边,那般,那样,那时,那儿,那会儿,那些,那么样,那么些,那么,那个,那,乃,乃至,乃至于,宁肯,宁愿,宁可,宁,能,能否,你,你们,您,拿,难道说,内,哪,凭借,凭,旁人,譬如,譬喻,且,且不说,且说,其,其一,其中,其二,其他,其余,其它,其次,前后,前此,前者,起见,起,全部,全体,恰恰相反,岂但,却,去,若非,若果,若是,若夫,若,另,另一方面,另外,另悉,如若,如此,如果,如是,如同,如其,如何,如下,如上所述,如上,如,然则,然后,然而,任,任何,任凭,仍,仍旧,人家,人们,人,让,甚至于,甚至,甚而,甚或,甚么,甚且,什么,什么样,上,上下,虽说,虽然,虽则,虽,孰知,孰料,始而,所,所以,所在,所幸,所有,是,是以,是的,设使,设或,设若,谁,谁人,谁料,谁知,随着,随时,随后,随,顺着,顺,受到,使得,使,似的,尚且,庶几,庶乎,时候,省得,说来,首先,倘,倘使,倘或,倘然,倘若,同,同时,他,他人,他们们,她们,她,它们,它,替代,替,通过,腾,这里,这边,这般,这次,这样,这时,这就是说,这儿,这会儿,这些,这么点儿,这么样,这么些,这么,这个,这一来,这,正是,正巧,正如,正值,万一,为,为了,为什么,为何,为止,为此,为着,无论,无宁,无,我们,我,往,望,惟其,唯有,下,向着,向使,向,先不先,相对而言,许多,像,小,些,一,一些,一何,一切,一则,一方面,一旦,一来,一样,一般,一转眼,,由此可见,由此,由是,由于,由,用来,因而,因着,因此,因了,因为,因,要是,要么,要不然,要不是,要不,要,与,与其,与其说,与否,与此同时,以,以上,以为,以便,以免,以及,以故,以期,以来,以至,以至于,以致,己,已,已矣,有,有些,有关,有及,有时,有的,沿,沿着,于,于是,于是乎,云云,云尔,依照,依据,依,余外,也罢,也好,也,又及,又,抑或,犹自,犹且,用,越是,只当,只怕,只是,只有,只消,只要,只限,再,再其次,再则,再有,再者,再者说,再说,自身,自打,自己,自家,自后,自各儿,自从,自个儿,自,怎样,怎奈,怎么样,怎么办,怎么,怎,至若,至今,至于,至,纵然,纵使,纵令,纵,之,之一,之所以,之类,着呢,着,眨眼,总而言之,总的说来,总的来说,总的来看,总之,在于,在下,在,诸,诸位,诸如,咱们,咱,作为,只,最,照着,照,直到,综上所述,贼死,逐步,遵照,遵循,针对,致,者,则甚,则,咳,哇,哈,哈哈,哉,哎,哎呀,哎哟,哗,哟,哦,哩,矣哉,矣乎,矣,焉,毋宁,欤,嘿嘿,嘿,嘻,嘛,嘘,嘎登,嘎,嗳,嗯,嗬,嗡嗡,嗡,喽,喔唷,喏,喂,啷当,啪达,啦,啥,啐,啊,唉,哼唷,哼,咧,咦,咚,咋,呼哧,呸,呵呵,呵,呢,呜呼,呜,呗,呕,呃,呀,吱,吧哒,吧,吗,吓,兮,儿,亦,了,乎'
chinese_stopwords = chinese_stopwords.split(',')
# vocab = set((' '.join([answer for answer in answers] + [question for question in questions])).split())
vectorizer = TfidfVectorizer(ngram_range=(1,2), stop_words=chinese_stopwords, analyzer='char')
train_corpus = questions + answers
train_corpus = [i.replace('\n',' ') for i in train_corpus]
vectorizer.fit(train_corpus)
questions_vec = vectorizer.transform(questions)
answers_vec = vectorizer.transform(answers)

In [6]:
check_accuracy_top_k(questions_vec, answers_vec, 1)


Out[6]:
1.0

In [11]:
get_answer('小米手环怎么样')


Out[11]:
('小米手环的系统需求和硬件需求是什么今天下雨了吗小米电脑是什么配置啊小米电视多少钱好不好用啊小米扫地机器人怎么样小米手机不发热了怎么办是不是坏了小米电视弯曲了变成了曲面屏怎么办小米雷军areyouok小米手环会不会对人的身体有害啊小米手环会不会影响今天的天气小米电视会不会导致小孩沉迷看电视小米手机下一代产品价格多少小米手机会不会爆炸小米电视发热大吗小米手环的需求是什么小米手环的要求是什么小米手环有什么要求用户问题:小米手环的系统需求和硬件需求是什么?\n终端作答:小米手环的适配在系统和硬件方面的需求如下:\n系统需求:Android以上的系统\n硬件需求:蓝牙\n内部信息:\n目前在Android以下系统安装小米手环APP,手机端会提示“解析包错误”。\n',
 array([ 0.21764106,  0.11495328,  0.03818896,  0.15328379,  0.02422698,
         0.08857397,  0.05891229,  0.0152881 ,  0.13921302,  0.12880175,
         0.0902283 ,  0.17187719,  0.02176961,  0.04768401,  0.07682548,
         0.13912768,  0.08155093,  0.02875804,  0.05073737,  0.110244  ,
         0.09568517,  0.10485521,  0.09201736,  0.07920983,  0.04340597,
         0.03149957,  0.05926627,  0.0756998 ,  0.05081014,  0.05427847,
         0.02514826,  0.02768161,  0.02987732,  0.0089738 ,  0.00820951,
         0.01232841,  0.02184404,  0.00925385,  0.09428211,  0.06960849,
         0.02606759,  0.02260243,  0.02568772,  0.01102487,  0.0086545 ,
         0.00775109,  0.03424929,  0.02242969,  0.02439054,  0.00900903,
         0.0112808 ,  0.00921631,  0.01790276,  0.0235016 ,  0.03601443,
         0.00794863,  0.02495665,  0.02097905,  0.0513333 ,  0.04666661,
         0.02908988,  0.00987897,  0.04338788,  0.06085222,  0.05408628,
         0.04059187,  0.0428978 ,  0.04471819,  0.04960187,  0.05579691,
         0.05932651,  0.05749477,  0.05908832,  0.05577349,  0.05102767,
         0.04366669,  0.03066679]))

In [12]:
len(vectorizer.get_feature_names())


Out[12]:
5033

In [ ]: