sklearn构建管道

sklearn支持使用管道(Pipeline)连接多个sklearn中的模型类实例,但要求过程中的模型类对象带transform方法的且最后一个需要是分类器,回归器或者同样是带transform方法的模型类对象.

transform方法的类对象叫做转换器,可以使用sklearn.preprocessing.FunctionTransformer自定义.


In [1]:
import numpy as np
from sklearn.preprocessing import FunctionTransformer

In [2]:
transformer = FunctionTransformer(np.log1p)
X = np.array([[0, 1], [2, 3]])
transformer.transform(X)


Out[2]:
array([[ 0.        ,  0.69314718],
       [ 1.09861229,  1.38629436]])

管道通常用在将向量化(vectorizer) => 转换器(transformer) => 分类器(classifier) 过程封装为一个连贯的过程.

例:以fetch_20newsgroups数据为例做贝叶斯分类器模型


In [4]:
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import fetch_20newsgroups

In [5]:
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
twenty_train = fetch_20newsgroups(subset='train',categories=categories, shuffle=True, random_state=42)


Downloading dataset from http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz (14 MB)

In [6]:
text_clf = Pipeline([('vect', CountVectorizer()),# 分词并向量化
                    ('tfidf', TfidfTransformer()), # tfidf算法提取关键字
                    ('clf', MultinomialNB())]) # 分类器

In [7]:
text_clf.fit(twenty_train.data, twenty_train.target)


Out[7]:
Pipeline(steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...inear_tf=False, use_idf=True)), ('clf', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))])

评估性能


In [8]:
import numpy as np
twenty_test = fetch_20newsgroups(subset='test',categories=categories, shuffle=True, random_state=42)
docs_test = twenty_test.data
predicted = text_clf.predict(docs_test)
np.mean(predicted == twenty_test.target)


Out[8]:
0.83488681757656458