XGBoost User Guide
<div id="article_content" class="article_content clearfix csdn-tracking-statistics" data-pid="blog" data-mod="popu_307" data-dsm="post">
<div class="htmledit_views">
<h2 style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;line-height:1;color:rgb(0,0,0);"><a name="t0"></a><span style="font-size:16px;">1. Import the required packages</span></h2><pre onclick="hljs.copyCode(event)"><code class="language-python hljs"># Import the required packages
import xgboost as xgb

# Used to compute classification accuracy
from sklearn.metrics import accuracy_score</code></pre><span style="font-size:16px;">2. <span style="color:rgb(0,0,0);font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;"><strong>Reading the data</strong></span></span><br>XGBoost can load text data in libsvm format. The libsvm file format (sparse features) looks like this:<br><span style="color:#ff0000;"><strong>1</strong></span> <strong>101</strong>:1.2 <strong>102</strong>:0.03<br><span style="color:#cc0000;"><strong>0 </strong></span> <strong>1</strong>:2.1 <strong>10001</strong>:300
<strong>10002</strong>:400<br>...<br>Each line represents one sample. The <strong>"1" at the start of the first line is the sample's label</strong>, <strong>"101" and "102" are feature indices,</strong> <span style="color:#cc0000;"><strong>and "1.2" and "0.03" are the corresponding feature values.</strong></span><br><p>In binary classification, <strong>"1"</strong> denotes a positive sample and <strong>"0"</strong> a negative sample. A label may also be a probability in [0,1], interpreted as the probability that the sample is positive.</p>The example data below asks us to decide whether a mushroom species is poisonous based on a number of its attributes.<br>UCI data description: http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/<br><p>Each sample describes 22 attributes of a mushroom, such as shape and odor (<strong>the 22 raw features were processed into 126-dimensional features and stored in libsvm format</strong>), and states whether the mushroom is edible. 6513 samples are used for training and 1611 for testing.</p><p><strong><span style="color:#cc0000;">Note: the libsvm file format is explained at</span> <a href="https://www.cnblogs.com/codingmengmeng/p/6254325.html" rel="nofollow" target="_blank">https://www.cnblogs.com/codingmengmeng/p/6254325.html</a></strong></p>Data loaded by XGBoost is stored in a DMatrix object.<br>XGBoost defines its own <strong><span style="color:#cc0000;">data matrix class, DMatrix</span></strong>, which is optimized for storage and computation speed.<br><p>DMatrix documentation: http://xgboost.readthedocs.io/en/latest/python/python_api.html</p><p>Data download: <a href="http://download.csdn.net/download/u011630575/10266113" rel="nofollow" target="_blank">http://download.csdn.net/download/u011630575/10266113</a></p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs"># Read in the data. It ships in the demo directory of the xgboost installation;
# here it has been copied to the data directory next to this code.
# Note: newer XGBoost releases may require appending '?format=libsvm' to the file path.
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')</code></pre><p><span style="color:rgb(0,0,0);font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;font-size:14px;text-align:left;background-color:rgb(255,255,255);">Inspect the data</span></p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">dtrain.num_col()</code></pre><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">dtrain.num_row()</code></pre><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">dtest.num_row()</code></pre><span style="font-size:16px;">3. Setting the training parameters</span><br><p>max_depth: maximum depth of a tree. Default is 6; range: [1,∞]<br>eta: step-size shrinkage used in updates to prevent overfitting. After each boosting step the algorithm directly obtains the weights of the new features, and eta shrinks these weights to make the boosting process more conservative. Default is 0.3; range: [0,1]<br>silent: 0 prints runtime messages, 1 runs silently without printing them. Default is 0. (Recent XGBoost releases deprecate this parameter in favor of verbosity.)<br>objective: defines the learning task and the corresponding objective. "binary:logistic" is logistic regression for binary classification and outputs a probability.<br><br>All other parameters keep their default values.</p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs"># specify parameters via map
param = {'max_depth': 2, 'eta': 1, 'silent': 0, 'objective': 'binary:logistic'}
print(param)</code></pre><p><span style="font-size:16px;">4. Training the model</span></p><pre
onclick="hljs.copyCode(event)"><code class="language-python hljs"># Number of boosting iterations
num_round = 2

import time
starttime = time.perf_counter()  # high-resolution timer

bst = xgb.train(param, dtrain, num_round)  # dtrain is the training set

endtime = time.perf_counter()
print(endtime - starttime)  # elapsed training time in seconds</code></pre>XGBoost's predictions are probabilities. Mushroom classification here is a binary problem, so <strong><span style="color:#cc0000;">each output value is the probability that the sample belongs to the positive class (label 1).</span></strong><br><p><span style="color:#ff0000;">We need to convert these probabilities to 0 or 1.</span></p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">train_preds = bst.predict(dtrain)
train_predictions = [round(value) for value in train_preds]
y_train = dtrain.get_label()  # labels come from the first column of the input file
train_accuracy = accuracy_score(y_train, train_predictions)
print("Train Accuracy: %.2f%%" % (train_accuracy * 100.0))</code></pre><p><span style="font-size:16px;">5. Testing</span></p><p>Once the model is trained, it can be used to make predictions on the test data.<br></p><pre onclick="hljs.copyCode(event)"><code
class="language-python hljs"># Make predictions on the test set
preds = bst.predict(dtest)</code></pre>Check the model's accuracy on the test set<br><p>XGBoost outputs the probability that each sample belongs to the positive class, so these values need to be converted to 0 or 1.</p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">predictions = [round(value) for value in preds]</code></pre><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))</code></pre><h2 style="font-family:'Helvetica Neue', Helvetica, Arial, sans-serif;line-height:1;color:rgb(0,0,0);"><a name="t1"></a><span style="font-weight:normal;"><span style="font-size:16px;">6. Model visualization</span></span></h2><span
style="font-size:14px;">Call plot_tree from the XGBoost package to display the trees in the model.<br>Visualizing the model requires the graphviz package to be installed.<br>plot_tree() takes three arguments:<br>1. The model<br>2. The index of the tree, starting from 0<br>3. The layout direction; the default is vertical, and 'LR' is horizontal</span><br><pre onclick="hljs.copyCode(event)"><code class="language-python hljs">from matplotlib import pyplot
import graphviz
xgb.plot_tree(bst, num_trees=0, rankdir='LR')
pyplot.show()

# xgb.plot_tree(bst, num_trees=1, rankdir='LR')
# pyplot.show()
# xgb.to_graphviz(bst, num_trees=0)
# xgb.to_graphviz(bst, num_trees=1)</code></pre><p>7. Complete code</p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs"># coding:utf-8
import xgboost as xgb

# Used to compute classification accuracy
from sklearn.metrics import accuracy_score

# Read in the data. It ships in the demo directory of the xgboost installation;
# here it has been copied to the data directory next to this code.
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')

dtrain.num_col()

dtrain.num_row()

dtest.num_row()

# specify parameters via map
param = {'max_depth': 2, 'eta': 1, 'silent': 0, 'objective': 'binary:logistic'}
print(param)

# Number of boosting iterations
num_round = 2

import time

starttime = time.perf_counter()  # high-resolution timer

bst = xgb.train(param, dtrain, num_round)  # dtrain is the training set

endtime = time.perf_counter()
print(endtime - starttime)


train_preds = bst.predict(dtrain)
print("train_preds", train_preds)

train_predictions = [round(value) for value in train_preds]
print("train_predictions", train_predictions)

y_train = dtrain.get_label()
print("y_train", y_train)

train_accuracy = accuracy_score(y_train, train_predictions)
print("Train Accuracy: %.2f%%" % (train_accuracy * 100.0))


# make prediction
preds = bst.predict(dtest)
predictions = [round(value) for value in preds]

y_test = dtest.get_label()

test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))

# from matplotlib import pyplot
# import graphviz

import graphviz

# xgb.plot_tree(bst, num_trees=0, rankdir='LR')
# pyplot.show()

# xgb.plot_tree(bst, num_trees=1, rankdir='LR')
# pyplot.show()
# xgb.to_graphviz(bst, num_trees=0)
# xgb.to_graphviz(bst, num_trees=1)
</code></pre> </div>
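<p>As a rough illustration of the libsvm format described in section 2, the sketch below parses one line into a label and a sparse feature dictionary. The helper name parse_libsvm_line is hypothetical and is not part of XGBoost; in practice xgb.DMatrix does the loading, and this only shows how a line is laid out.</p>

```python
def parse_libsvm_line(line):
    """Split one libsvm-format line into (label, {feature_index: value})."""
    fields = line.split()
    label = float(fields[0])  # the first field is the label
    features = {}
    for item in fields[1:]:
        index, value = item.split(":")  # every other field is index:value
        features[int(index)] = float(value)
    return label, features


if __name__ == "__main__":
    # The first sample line from the format example in section 2
    label, features = parse_libsvm_line("1 101:1.2 102:0.03")
    print(label, features)
```

<p>Only the features that appear on a line are stored, which is why the format suits sparse data such as the 126-dimensional mushroom features.</p>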
</div>
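<p>The conversion from predicted probabilities to class labels in sections 4 and 5 can also be written with an explicit threshold. The sketch below assumes a cutoff of 0.5 and uses made-up probability values rather than real model output; the helper names to_labels and accuracy are hypothetical. Python 3's built-in round() rounds exact ties to the nearest even integer, so round(0.5) is 0, whereas an explicit comparison makes the treatment of 0.5 visible.</p>

```python
def to_labels(probabilities, threshold=0.5):
    """Map each probability to 1 if it reaches the threshold, else 0."""
    return [1 if p >= threshold else 0 for p in probabilities]


def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


if __name__ == "__main__":
    probs = [0.92, 0.08, 0.5, 0.31]  # made-up values, not real model output
    labels = to_labels(probs)
    print(labels)
    print("Accuracy: %.2f%%" % (accuracy([1, 0, 1, 1], labels) * 100.0))
```

<p>With a real model, the list returned by bst.predict(dtest) would take the place of probs and dtest.get_label() would supply the true labels.</p>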
Original article: https://www.cnblogs.com/guanzhicheng/p/9419582.html
Time: 2024-11-08 20:48:18