scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类

scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类数据集 fetch_20newsgroups

#-*- coding: UTF-8 -*-

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.grid_search import GridSearchCV
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn import metrics

获取待分类的文本数据源
categories = [‘comp.graphics‘, ‘comp.os.ms-windows.misc‘,‘comp.sys.ibm.pc.hardware‘,‘comp.sys.mac.hardware‘,‘comp.windows.x‘];
newsgroup_data = fetch_20newsgroups(subset = ‘train‘,categories = categories)
X,Y=np.array(newsgroup_data.data),np.array(newsgroup_data.target)
Xtrain,Ytrain,Xtest,Ytest =X[0:2400],Y[0:2400],X[2400:],Y[2400:]

#Pipeline主要用于将三个需要串行的模块串在一起,后一个模型处理前一个的结果‘‘‘
#vect主要用于去音调、转小写、去停顿词->tdidf主要用于计词频->clf分类模型‘‘‘
pipeline_obj = Pipeline([(‘vect‘,CountVectorizer()),(‘tfidf‘,TfidfTransformer()),(‘clf‘,SGDClassifier()),])
print "pipeline:",‘\n‘, [name for name, _ in pipeline_obj.steps],‘\n‘

#定义需要遍历的所有候选参数的字典,key_name需要用__分隔模型名和模型内部的参数名‘‘‘
parameters = {
    ‘vect__max_df‘: (0.5, 0.75),‘vect__max_features‘: (None, 5000, 10000),
    ‘tfidf__use_idf‘: (True, False),‘tfidf__norm‘: (‘l1‘, ‘l2‘),
    ‘clf__alpha‘: (0.00001, 0.000001), ‘clf__n_iter‘: (10, 50) }
print "parameters:",‘\n‘,parameters,‘\n‘

#GridSearchCV用于寻找vectorizer词频统计, tfidftransformer特征变换和SGD classifier分类模型的最优参数
grid_search = GridSearchCV( pipeline_obj, parameters, n_jobs = 1,verbose=1 )
print ‘grid_search‘,‘\n‘,grid_search,‘\n‘ #输出所有参数名及参数候选值
grid_search.fit(Xtrain,Ytrain),‘\n‘#遍历执行候选参数,寻找最优参数

best_parameters = dict(grid_search.best_estimator_.get_params())#get实例中的最优参数
for param_name in sorted(parameters.keys()):
    print("\t%s: %r" % (param_name, best_parameters[param_name])),‘\n‘#输出最有参数结果
pipeline_obj.set_params(clf__alpha = 1e-05,clf__n_iter = 50,tfidf__use_idf = True,vect__max_df = 0.5,vect__max_features = None)
#将pipeline_obj实例中的参数重写为最优结果‘‘‘
print pipeline_obj.named_steps

#用最优参数训练模型‘‘‘
pipeline_obj.fit(Xtrain,Ytrain)
pred = pipeline_obj.predict(Xtrain)
print ‘\n‘,metrics.classification_report(Ytrain,pred)
pred = pipeline_obj.predict(Xtest)
print ‘\n‘,metrics.classification_report(Ytest,pred)

执行结果:总共有96个参数排列组合候选组,每组跑3次模型进行交叉验证,共计跑模型96*3=288次。

调参前VS调参后:

#参考

#http://blog.csdn.net/mmc2015/article/details/46991465
# http://blog.csdn.net/abcjennifer/article/details/23884761
# http://scikit-learn.org/stable/modules/pipeline.html
# http://blog.csdn.net/yuanyu5237/article/details/44278759

时间: 2024-12-09 22:54:33

scikit learn 模块 调参 pipeline+girdsearch 数据举例:文档分类的相关文章

sahrepoint 上次数据到文档库

sharepoint学习笔记汇总 http://blog.csdn.net/qq873113580/article/details/20390149 protected void Button1_Click(object sender, EventArgs e) { using (SPSite site = new SPSite("http://zhangyi:90")) { using (SPWeb web = site.OpenWeb()) { web.AllowUnsafeUpd

configparser模块——用于生成和修改常见配置文档

配置文档格式 1 [DEFAULT] 2 ServerAliveInterval = 45 3 Compression = yes 4 CompressionLevel = 9 5 ForwardX11 = yes 6 7 [bitbucket.org] 8 User = hg 配置文档文件格式 解析配置文件:查询 1 #-*- coding:utf-8 -*- 2 #解析配置文件 3 import configparser 4 config = configparser.ConfigParse

编写powerdesigner字段数据设计文档

package com.winway.wcloud.protal.gym; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.List; import com.mysql.jdbc.PreparedStatement; publi

Dom4j解析语音数据XML文档(注意ArrayList多次添加对象,会导致覆盖之前的对象)

今天做的一个用dom4j解析声音文本的xml文档时,我用ArrayList来存储每一个Item的信息,要注意ArrayList多次添加对象,会导致覆盖之前的对象:解决方案是在最后将对象添加入ArrayLis时先new 一个对象,然后将之前那个对象的属性set到新的对象中,之后在加入到 ArrayList,就不会出错了. package parseXML; import org.dom4j.Attribute;import org.dom4j.Document;import org.dom4j.E

利用swagger模块开发flask的api接口帮助文档

swagger官网称其为世界最流行的api工具.用过的都说好.我已经深有体会. 附上官网编辑页面.只需要拷贝相应的文件就可以实现效果 swagger在线编辑器 下面主要讲解一下在python的flask框架下,如何使用这款屌炸天的应用. 1.安装flasgger 项目地址https://github.com/rochacbruno/flasgger pip install flasgger 2.写一个简单的web例子. 以下是简写代码 #coding:utf8 import sys reload

随机森林在乳腺癌数据上的调参

这篇文章中,使用基于方差和偏差的调参方法,在乳腺癌数据上进行一次随机森林的调参.乳腺癌数据是sklearn自带的分类数据之一. 方差和偏差 案例中,往往使用真实数据,为什么我们要使用sklearn自带的数据呢?因为真实数据在随机森林下的调参过程,往往非常缓慢.真实数据量大,维度高,在使用随机森林之前需要一系列的处理,因此不太适合用来做直播中的案例演示.原本,我为大家准备了kaggle上下载的辨别手写数字的数据,有4W多条记录700多个左右的特征,随机森林在这个辨别手写数字的数据上有非常好的表现,

支持向量机高斯核调参小结

在支持向量机(以下简称SVM)的核函数中,高斯核(以下简称RBF)是最常用的,从理论上讲, RBF一定不比线性核函数差,但是在实际应用中,却面临着几个重要的超参数的调优问题.如果调的不好,可能比线性核函数还要差.所以我们实际应用中,能用线性核函数得到较好效果的都会选择线性核函数.如果线性核不好,我们就需要使用RBF,在享受RBF对非线性数据的良好分类效果前,我们需要对主要的超参数进行选取.本文我们就对scikit-learn中 SVM RBF的调参做一个小结. 1. SVM RBF 主要超参数概

sklearn-GBDT 调参

1. scikit-learn GBDT类库概述 在sacikit-learn中,GradientBoostingClassifier为GBDT的分类类, 而GradientBoostingRegressor为GBDT的回归类.两者的参数类型完全相同,当然有些参数比如损失函数loss的可选择项并不相同.这些参数中,我们把重要参数分为两类,第一类是Boosting框架的重要参数,第二类是弱学习器即CART回归树的重要参数. 下面我们就从这两个方面来介绍这些参数的使用. 2. GBDT类库boost

python模块--BeautifulSoup <HTML/XML文档搜索模块>

之前解析字符串都是上正则,导致后来解析HTML/XML也习惯上正则,可是毕竟正则太底层的东西,对于这种有规律的文档,它不是一个好的选择. 后来发现了HTMLParser,感觉比正则好多了,正想深入学习一下,却发现了这个. BeautifulSoup 一比较然后我把以前代码里面的解析HTML/XML的正则全删了,改成BS来解析,所以在此推荐这个HTML/XML文档解析模块,当然它也可以用来修改文档. BeautifulSoup中文文档 至于示例和详细说明便不说了,文档写得不错,而且还是中文的.