使用基于Apache Spark的随机森林方法预测贷款风险

使用基于Apache Spark的随机森林方法预测贷款风险

原文:Predicting Loan Credit Risk using Apache Spark Machine Learning Random Forests 
作者:Carol McDonald,MapR解决方案架构师 
翻译:KK4SBB 
责编:周建丁([email protected].NET

在本文中,我将向大家介绍如何使用Apache SparkSpark.ml库中的随机森林算法来对银行信用贷款的风险做分类预测。Spark的spark.ml库基于DataFrame,它提供了大量的接口,帮助用户创建和调优机器学习工作流。结合dataframe使用spark.ml,能够实现模型的智能优化,从而提升模型效果。

分类算法

分类算法是一类监督式机器学习算法,它根据已知标签的样本(如已经明确交易是否存在欺诈)来预测其它样本所属的类别(如是否属于欺诈性的交易)。分类问题需要一个已经标记过的数据集和预先设计好的特征,然后基于这些信息来学习给新样本打标签。所谓的特征即是一些“是与否”的问题。标签就是这些问题的答案。在下面这个例子里,如果某个动物的行走姿态、游泳姿势和叫声都像鸭子,那么就给它打上“鸭子”的标签。

我们来看一个银行信贷的信用风险例子:

  • 我们需要预测什么?

    • 某个人是否会按时还款
    • 这就是标签:此人的信用度
  • 你用来预测的“是与否”问题或者属性是什么? 
    • 申请人的基本信息和社会身份信息:职业,年龄,存款储蓄,婚姻状态等等……
    • 这些就是特征,用来构建一个分类模型,你从中提取出对分类有帮助的特征信息。

决策树模型

决策树是一种基于输入特征来预测类别或是标签的分类模型。决策树的工作原理是这样的,它在每个节点都需要计算特征在该节点的表达式值,然后基于运算结果选择一个分支通往下一个节点。下图展示了一种用来预测信用风险的决策树模型。每个决策问题就是模型的一个节点,“是”或者“否”的答案是通往子节点的分支。

  • 问题1:账户余额是否大于200元?

    • 问题2:当前就职时间是否超过1年? 
      • 不可信赖

随机森林模型

融合学习算法结合了多个机器学习的算法,从而得到了效果更好的模型。随机森林是分类和回归问题中一类常用的融合学习方法。此算法基于训练数据的不同子集构建多棵决策树,组合成一个新的模型。预测结果是所有决策树输出的组合,这样能够减少波动,并且提高预测的准确度。对于随机森林分类模型,每棵树的预测结果都视为一张投票。获得投票数最多的类别就是预测的类别。

基于Spark机器学习工具来分析信用风险问题

我们使用德国人信用度数据集,它按照一系列特征属性将人分为信用风险好和坏两类。我们可以获得每个银行贷款申请者的以下信息:

存放德国人信用数据的csv文件格式如下:

1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1
1,1,9,4,0,2799,1,3,2,3,1,2,1,36,3,1,2,3,2,1,1
1,2,12,2,9,841,2,4,2,2,1,4,1,23,3,1,1,2,1,1,1

在这个背景下,我们会构建一个由决策树组成的随机森林模型来预测是否守信用的标签/类别,基于以下特征:

  • 标签 -> 守信用或者不守信用(1或者0)
  • 特征 -> {存款余额,信用历史,贷款目的等等}

软件

本教程将使用Spark 1.6.1

按照教程指示,登录MapR沙箱,用户名为user01,密码为mapr。将样本数据文件复制到你的沙箱主目录下/user/user01 using scp。(注意,你可能需要先更新Spark的版本)打开spark shell:

$spark-shell --master local[1]

加载并解析csv数据文件

首先,我们需要引入机器学习相关的包。

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import sqlContext.implicits._
import sqlContext._
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
import org.apache.spark.ml.{ Pipeline, PipelineStage }

我们用一个Scala的case类来定义Credit的属性,对应于csv文件中的一行。

    // define the Credit Schema
    case class Credit(
        creditability: Double,
        balance: Double, duration: Double, history: Double, purpose: Double, amount: Double,
        savings: Double, employment: Double, instPercent: Double, sexMarried: Double, guarantors: Double,
        residenceDuration: Double, assets: Double, age: Double, concCredit: Double, apartment: Double,
        credits: Double, occupation: Double, dependents: Double, hasPhone: Double, foreign: Double
      )

下面的函数解析一行数据文件,将值存入Credit类中。类别的索引值减去了1,因此起始索引值为0.

    // function to create a  Credit class from an Array of Double
    def parseCredit(line: Array[Double]): Credit = {
        Credit(
          line(0),
          line(1) - 1, line(2), line(3), line(4) , line(5),
          line(6) - 1, line(7) - 1, line(8), line(9) - 1, line(10) - 1,
          line(11) - 1, line(12) - 1, line(13), line(14) - 1, line(15) - 1,
          line(16) - 1, line(17) - 1, line(18) - 1, line(19) - 1, line(20) - 1
        )
      }
    // function to transform an RDD of Strings into an RDD of Double
      def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
        rdd.map(_.split(",")).map(_.map(_.toDouble))
      }

接下去,我们导入germancredit.csv文件中的数据,存为一个String类型的RDD。然后我们对RDD做map操作,将RDD中的每个字符串经过ParseRDDR函数的映射,转换为一个Double类型的数组。紧接着是另一个map操作,使用ParseCredit函数,将每个Double类型的RDD转换为Credit对象。toDF()函数将Array[[Credit]]类型的RDD转为一个Credit类的Dataframe。

    // load the data into a  RDD
    val creditDF= parseRDD(sc.textFile("germancredit.csv")).map(parseCredit).toDF().cache()
    creditDF.registerTempTable("credit")

DataFrame的printSchema()函数将各个字段含义以树状的形式打印到控制台输出。

    // Return the schema of this DataFrame
    creditDF.printSchema

    root
     |-- creditability: double (nullable = false)
     |-- balance: double (nullable = false)
     |-- duration: double (nullable = false)
     |-- history: double (nullable = false)
     |-- purpose: double (nullable = false)
     |-- amount: double (nullable = false)
     |-- savings: double (nullable = false)
     |-- employment: double (nullable = false)
     |-- instPercent: double (nullable = false)
     |-- sexMarried: double (nullable = false)
     |-- guarantors: double (nullable = false)
     |-- residenceDuration: double (nullable = false)
     |-- assets: double (nullable = false)
     |-- age: double (nullable = false)
     |-- concCredit: double (nullable = false)
     |-- apartment: double (nullable = false)
     |-- credits: double (nullable = false)
     |-- occupation: double (nullable = false)
     |-- dependents: double (nullable = false)
     |-- hasPhone: double (nullable = false)
     |-- foreign: double (nullable = false)

    // Display the top 20 rows of DataFrame
    creditDF.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
    |          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|     9.0|    4.0|    0.0|2799.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
    |          1.0|    1.0|    12.0|    2.0|    9.0| 841.0|    1.0|       3.0|        2.0|       1.0|       0.0|              3.0|   0.0|23.0|       2.0|      0.0|    0.0|       1.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|    12.0|    4.0|    0.0|2122.0|    0.0|       2.0|        3.0|       2.0|       0.0|              1.0|   0.0|39.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
    |          1.0|    0.0|    12.0|    4.0|    0.0|2171.0|    0.0|       2.0|        4.0|       2.0|       0.0|              3.0|   1.0|38.0|       0.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
    |          1.0|    0.0|    10.0|    4.0|    0.0|2241.0|    0.0|       1.0|        1.0|       2.0|       0.0|              2.0|   0.0|48.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
    |          1.0|    0.0|     8.0|    4.0|    0.0|3398.0|    0.0|       3.0|        1.0|       2.0|       0.0|              3.0|   0.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
    |          1.0|    0.0|     6.0|    4.0|    0.0|1361.0|    0.0|       1.0|        2.0|       2.0|       0.0|              3.0|   0.0|40.0|       2.0|      1.0|    0.0|       1.0|       1.0|     0.0|    1.0|
    |          1.0|    3.0|    18.0|    4.0|    3.0|1098.0|    0.0|       0.0|        4.0|       1.0|       0.0|              3.0|   2.0|65.0|       2.0|      1.0|    1.0|       0.0|       0.0|     0.0|    0.0|
    |          1.0|    1.0|    24.0|    2.0|    3.0|3758.0|    2.0|       0.0|        1.0|       1.0|       0.0|              3.0|   3.0|23.0|       2.0|      0.0|    0.0|       0.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|    11.0|    4.0|    0.0|3905.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
    |          1.0|    0.0|    30.0|    4.0|    1.0|6187.0|    1.0|       3.0|        1.0|       3.0|       0.0|              3.0|   2.0|24.0|       2.0|      0.0|    1.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|     6.0|    4.0|    3.0|1957.0|    0.0|       3.0|        1.0|       1.0|       0.0|              3.0|   2.0|31.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    1.0|    48.0|    3.0|   10.0|7582.0|    1.0|       0.0|        2.0|       2.0|       0.0|              3.0|   3.0|31.0|       2.0|      1.0|    0.0|       3.0|       0.0|     1.0|    0.0|
    |          1.0|    0.0|    18.0|    2.0|    3.0|1936.0|    4.0|       3.0|        2.0|       3.0|       0.0|              3.0|   2.0|23.0|       2.0|      0.0|    1.0|       1.0|       0.0|     0.0|    0.0|
    |          1.0|    0.0|     6.0|    2.0|    3.0|2647.0|    2.0|       2.0|        2.0|       2.0|       0.0|              2.0|   0.0|44.0|       2.0|      0.0|    0.0|       2.0|       1.0|     0.0|    0.0|
    |          1.0|    0.0|    11.0|    4.0|    0.0|3939.0|    0.0|       2.0|        1.0|       2.0|       0.0|              1.0|   0.0|40.0|       2.0|      1.0|    1.0|       1.0|       1.0|     0.0|    0.0|
    |          1.0|    1.0|    18.0|    2.0|    3.0|3213.0|    2.0|       1.0|        1.0|       3.0|       0.0|              2.0|   0.0|25.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    1.0|    36.0|    4.0|    3.0|2337.0|    0.0|       4.0|        4.0|       2.0|       0.0|              3.0|   0.0|36.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
    |          1.0|    3.0|    11.0|    4.0|    0.0|7228.0|    0.0|       2.0|        1.0|       2.0|       0.0|              3.0|   1.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    0.0|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+  

dataframe初始化之后,你可以用SQL命令查询数据了。下面是一些使用Scala DataFrame接口查询数据的例子:

计算数值型数据的统计信息,包括计数、均值、标准差、最小值和最大值。

    //  computes statistics for balance
      creditDF.describe("balance").show

    +-------+-----------------+
    |summary|          balance|
    +-------+-----------------+
    |  count|             1000|
    |   mean|            1.577|
    | stddev|1.257637727110893|
    |    min|              0.0|
    |    max|              3.0|
    +-------+-----------------+

    // compute the avg balance by creditability (the label)
     creditDF.groupBy("creditability").avg("balance").show

    +-------------+------------------+
    |creditability|      avg(balance)|
    +-------------+------------------+
    |          1.0|1.8657142857142857|
    |          0.0|0.9033333333333333|
    +-------------+------------------+

你可以用某个表名将DataFrame注册为一张临时表,然后用SQLContext提供的sql方法执行SQL命令。下面是几个用sqlContext查询的例子:

     sqlContext.sql("SELECT creditability, avg(balance) as avgbalance, avg(amount) as avgamt, avg(duration) as avgdur  FROM credit GROUP BY creditability ").show

    +-------------+------------------+------------------+------------------+
    |creditability|        avgbalance|            avgamt|            avgdur|
    +-------------+------------------+------------------+------------------+
    |          1.0|1.8657142857142857| 2985.442857142857|19.207142857142856|
    |          0.0|0.9033333333333333|3938.1266666666666|             24.86|
    +-------------+------------------+------------------+------------------+

提取特征

为了构建一个分类模型,你首先需要提取对分类最有帮助的特征。在德国人信用度的数据集里,每条样本用两个类别来标记——1(可信)和0(不可信)。

每个样本的特征包括以下的字段:

  • 标签 -> 是否可信:0或者1
  • 特征 -> {“存款”,“期限”,“历史记录”,“目的”,“数额”,“储蓄”,“是否在职”,“婚姻”,“担保人”,“居住时间”,“资产”,“年龄”,“历史信用”,“居住公寓”,“贷款”,“职业”,“监护人”,“是否有电话”,“外籍”}

定义特征数组

图片来自:学习Spark

为了在机器学习算法中使用这些特征,这些特征经过了变换,存入特征向量中,即一组表示各个维度特征值的数值向量。

下图中,用VectorAssembler方法将每个维度的特征都做变换,返回一个新的dataframe。

    //define the feature columns to put in the feature vector
    val featureCols = Array("balance", "duration", "history", "purpose", "amount",
        "savings", "employment", "instPercent", "sexMarried",  "guarantors",
        "residenceDuration", "assets",  "age", "concCredit", "apartment",
        "credits",  "occupation", "dependents",  "hasPhone", "foreign" )
    //set the input and output column names
      val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    //return a dataframe with all of the  feature columns in  a vector column
    val df2 = assembler.transform( creditDF)
    // the transform method produced a new column: features.
    df2.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
    |          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|

接着,我们使用StringIndexer方法返回一个Dataframe,增加了信用度这一列作为标签。

    //  Create a label column with the StringIndexer
    val labelIndexer = new StringIndexer().setInputCol("creditability").setOutputCol("label")
    val df3 = labelIndexer.fit(df2).transform(df2)
    // the  transform method produced a new column: label.
    df3.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
    |          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|  0.0|

下图中,数据集被分为训练数据和测试数据两个部分,70%的数据用来训练模型,30%的数据用来测试模型。

    //  split the dataframe into training and test data
    val splitSeed = 5043
    val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)

训练模型

接着,我们按照下列参数训练一个随机森林分类器:

  • maxDepth:每棵树的最大深度。增加树的深度可以提高模型的效果,但是会延长训练时间。
  • maxBins:连续特征离散化时选用的最大分桶个数,并且决定每个节点如何分裂。
  • impurity:计算信息增益的指标
  • auto:在每个节点分裂时是否自动选择参与的特征个数
  • seed:随机数生成种子

模型的训练过程就是将输入特征和这些特征对应的样本标签相关联的过程。

    // create the classifier,  set parameters for training
    val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
    //  use the random forest classifier  to train (fit) the model
    val model = classifier.fit(trainingData) 

    // print out the random forest trees
    model.toDebugString
    res20: String =
    res5: String =
    "RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees
      Tree 0 (weight 1.0):
        If (feature 0 <= 1.0)
         If (feature 10 <= 0.0)
          If (feature 3 <= 6.0)
           Predict: 0.0
          Else (feature 3 > 6.0)
           Predict: 0.0
         Else (feature 10 > 0.0)
          If (feature 12 <= 63.0)
           Predict: 0.0
          Else (feature 12 > 63.0)
           Predict: 0.0
        Else (feature 0 > 1.0)
         If (feature 13 <= 1.0)
          If (feature 3 <= 3.0)
           Predict: 0.0
          Else (feature 3 > 3.0)
           Predict: 1.0
         Else (feature 13 > 1.0)
          If (feature 7 <= 1.0)
           Predict: 0.0
          Else (feature 7 > 1.0)
           Predict: 0.0
      Tree 1 (weight 1.0):
        If (feature 2 <= 1.0)
         If (feature 15 <= 0.0)
          If (feature 11 <= 0.0)
           Predict: 0.0
          Else (feature 11 > 0.0)
           Predict: 1.0
         Else (feature 15 > 0.0)
          If (feature 11 <= 0.0)
           Predict: 0.0
          Else (feature 11 > 0.0)
           Predict: 1.0
        Else (feature 2 > 1.0)
         If (feature 12 <= 31.0)
          If (feature 5 <= 0.0)
           Predict: 0.0
          Else (feature 5 > 0.0)
           Predict: 0.0
         Else (feature 12 > 31.0)
          If (feature 3 <= 4.0)
           Predict: 0.0
          Else (feature 3 > 4.0)
           Predict: 0.0
      Tree 2 (weight 1.0):
        If (feature 8 <= 1.0)
         If (feature 6 <= 2.0)
          If (feature 4 <= 10875.0)
           Predict: 0.0
          Else (feature 4 > 10875.0)
           Predict: 1.0
         Else (feature 6 > 2.0)
          If (feature 1 <= 36.0)
           Predict: 0.0
          Else (feature 1 > 36.0)
           Predict: 1.0
        Else (feature 8 > 1.0)
         If (feature 5 <= 0.0)
          If (feature 4 <= 4113.0)
           Predict: 0.0
          Else (feature 4 > 4113.0)
           Predict: 1.0
         Else (feature 5 > 0.0)
          If (feature 11 <= 2.0)
           Predict: 0.0
          Else (feature 11 > 2.0)
           Predict: 0.0
      Tree 3 ...

测试模型

接下来,我们对测试数据进行预测。

    // run the  model on test features to get predictions
    val predictions = model.transform(testData)
    //As you can see, the previous model transform produced a new columns: rawPrediction, probablity and prediction.
    predictions.show

    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|       rawPrediction|         probability|prediction|
    +-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
    |          0.0|    0.0|    12.0|    0.0|    5.0|1108.0|    0.0|       3.0|        4.0|       2.0|       0.0|              2.0|   0.0|28.0|       2.0|      1.0|    1.0|       2.0|       0.0|     0.0|    0.0|(20,[1,3,4,6,7,8,...|  1.0|[14.1964586927573...|[0.70982293463786...|       0.0|

然后,我们用BinaryClassificationEvaluator评估预测的效果,它将预测结果与样本的实际标签相比较,返回一个准确度指标(ROC曲线所覆盖的面积)。本例子中,AUC达到78%。

    // create an Evaluator for binary classification, which expects two input columns: rawPrediction and label.
    val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")
    // Evaluates predictions and returns a scalar metric areaUnderROC(larger is better).
    val accuracy = evaluator.evaluate(predictions)
    accuracy: Double = 0.7824906081835722

使用机器学习管道

我们接着用管道来训练模型,可能会取得更好的效果。管道采取了一种简单的方式来比较各种不同组合的参数的效果,这个方法称为网格搜索法(grid search),你先设置好待测试的参数,MLLib就会自动完成这些参数的不同组合。管道搭建了一条工作流,一次性完成了整个模型的调优,而不是独立对每个参数进行调优。

下面我们就用ParamGridBuilder工具来构建参数网格。

    // We use a ParamGridBuilder to construct a grid of parameters to search over
    val paramGrid = new ParamGridBuilder()
      .addGrid(classifier.maxBins, Array(25, 28, 31))
      .addGrid(classifier.maxDepth, Array(4, 6, 8))
      .addGrid(classifier.impurity, Array("entropy", "gini"))
      .build()

创建并完成一条管道。一条管道由一系列stage组成,每个stage相当于一个Estimator或是Transformer。

    val steps: Array[PipelineStage] = Array(classifier)
    val pipeline = new Pipeline().setStages(steps)

我们用CrossValidator类来完成模型筛选。CrossValidator类使用一个Estimator类,一组ParamMaps类和一个Evaluator类。注意,使用CrossValidator类的开销很大。

    // Evaluate model on test instances and compute test error
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(10)

管道在参数网格上不断地爬行,自动完成了模型优化的过程:对于每个ParamMap类,CrossValidator训练得到一个Estimator,然后用Evaluator来评价结果,然后用最好的ParamMap和整个数据集来训练最优的Estimator。

    // When fit is called, the stages are executed in order.
    // Fit will run cross-validation,  and choose the best set of parameters
    //The fitted model from a Pipeline is an PipelineModel, which consists of fitted models and transformers

    val pipelineFittedModel = cv.fit(trainingData)

现在,我们可以用管道训练得到的最优模型进行预测,将预测结果与标签做比较。预测结果取得了82%的准确率,相比之前78%的准确率有提高。

    //  call tranform to make predictions on test data. The fitted model will use the best model found
    val predictions = pipelineFittedModel.transform(testData)
    val accuracy = evaluator.evaluate(predictions)
    Double = 0.8204386232104784
    val rm2 = new RegressionMetrics(
      predictions.select("prediction", "label").rdd.map(x =>
      (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double])))
    println("MSE: " + rm2.meanSquaredError)
    println("MAE: " + rm2.meanAbsoluteError)
    println("RMSE Squared: " + rm2.rootMeanSquaredError)
    println("R Squared: " + rm2.r2)
    println("Explained Variance: " + rm2.explainedVariance + "\n")

    MSE: 0.2575250836120402
    MAE: 0.25752508361204013
    RMSE Squared: 0.5074692932700856
    R Squared: -0.1687988628287138
    Explained Variance: 0.15466269952237702

更多相关知识

在本文中,我们演示了如何用Apache Spark的机器学习随机森林算法和机器学习管道来解决分类问题。如果你有任何的疑惑,请在评论区留言。



CCAI 2016中国人工智能大会将于8月26-27日在京举行,AAAI主席,多位院士,MIT、微软、大疆、百度、滴滴专家领衔全球技术领袖和产业先锋打造国内人工智能前沿平台,6+重磅大主题报告,4大专题论坛,1000+高质量参会嘉宾,探讨人机交互、机器学习、模式识别及产业实战。门票限时六折优惠中。

时间: 2024-10-21 23:41:13

使用基于Apache Spark的随机森林方法预测贷款风险的相关文章

基于Apache Spark机器学习的客户流失预测

流失预测是个重要的业务,通过预测哪些客户可能取消对服务的订阅来最大限度地减少客户流失.虽然最初在电信行业使用,但它已经成为银行,互联网服务提供商,保险公司和其他垂直行业的通用业务. 预测过程是大规模数据的驱动,并且经常结合使用先进的机器学习技术.在本篇文章中,我们将看到通常使用的哪些类型客户数据,对数据进行一些初步分析,并生成流失预测模型 - 所有这些都是通过Spark及其机器学习框架来完成的. 使用数据科学更好地理解和预测客户行为是一个迭代过程,其中涉及: 1.发现和模型创建: 分析历史数据.

机器学习入门-随机森林温度预测的案例

在这个案例中: 1. datetime.datetime.strptime(data, '%Y-%m-%d') # 由字符串格式转换为日期格式 2. pd.get_dummies(features)  # 将数据中的文字标签转换为one-hot编码形式,增加了特征的列数 3. rf.feature_importances 探究了随机森林样本特征的重要性,对其进行排序后条形图 代码: 第一步:数据读取,通过.describe() 查看数据是否存在缺失值的情况 第二步:对年月日特征进行字符串串接,使

sklearn 随机森林方法

Notes The default values for the parameters controlling the size of the trees (e.g. max_depth, min_samples_leaf, etc.) lead to fully grown and unpruned trees which can potentially be very large on some data sets. To reduce memory consumption, the com

基于随机森林的煤与瓦斯突出预测方法研究

1引言 煤炭在我国一次能源中的主导地位短期内不会发生根本性改变.随着煤炭产量的增长,近年来我国煤矿生产事故频繁发生,安全形势非常严峻.煤矿事故已经成为社会各界关注的焦点.而煤与瓦斯突出是煤矿生产过程中的一种严重自然灾害.长期以来,煤与瓦斯突出事故严重制约着我国煤矿生产和煤炭企业经济效益的提高,给煤矿安全生产和井下作业人员的生命财产安全带来了极大威胁.因此,正确预测矿井煤与瓦斯突出的规模,对于煤炭企业安全生产具有重要的现实意义. 目前关于煤与瓦斯突出的预测方法主要有:单项指标法.瓦斯地质统计法.D

Spark随机森林实现学习

前言 最近阅读了spark mllib(版本:spark 1.3)中Random Forest的实现,发现在分布式的数据结构上实现迭代算法时,有些地方与单机环境不一样.单机上一些直观的操作(递归),在分布式数据上,必须进行优化,否则I/O(网络,磁盘)会消耗大量时间.本文整理spark随机森林实现中的相关技巧,方便后面回顾. ? 随机森林算法概要 随机森林算法的详细实现和细节,可以参考论文Breiman 2001.这里简单说说大体思路,方便理解代码. 随机森林是一个组装(ensemble mod

机器学习实战之 第七章 集成方法(随机森林和 AdaBoost)

第7章 集成方法 ensemble method 集成方法: ensemble method(元算法: meta algorithm) 概述 概念:是对其他算法进行组合的一种形式. 通俗来说: 当做重要决定时,大家可能都会考虑吸取多个专家而不只是一个人的意见. 机器学习处理问题时又何尝不是如此? 这就是集成方法背后的思想. 集成方法: 投票选举(bagging: 自举汇聚法 bootstrap aggregating): 是基于数据随机重抽样分类器构造的方法 再学习(boosting): 是基于

【Spark MLlib速成宝典】模型篇06随机森林【Random Forests】(Python版)

目录 随机森林原理 随机森林代码(Spark Python) 随机森林原理 待续... 返回目录 随机森林代码(Spark Python) 代码里数据:https://pan.baidu.com/s/1jHWKG4I 密码:acq1 # -*-coding=utf-8 -*- from pyspark import SparkConf, SparkContext sc = SparkContext('local') from pyspark.mllib.tree import RandomFor

Apache Spark 2.2.0 中文文档 - SparkR (R on Spark) | ApacheCN

SparkR (R on Spark) 概述 SparkDataFrame 启动: SparkSession 从 RStudio 来启动 创建 SparkDataFrames 从本地的 data frames 来创建 SparkDataFrames 从 Data Sources(数据源)创建 SparkDataFrame 从 Hive tables 来创建 SparkDataFrame SparkDataFrame 操作 Selecting rows(行), columns(列) Groupin

随机森林入门攻略(内含R、Python代码)

随机森林入门攻略(内含R.Python代码) 简介 近年来,随机森林模型在界内的关注度与受欢迎程度有着显著的提升,这多半归功于它可以快速地被应用到几乎任何的数据科学问题中去,从而使人们能够高效快捷地获得第一组基准测试结果.在各种各样的问题中,随机森林一次又一次地展示出令人难以置信的强大,而与此同时它又是如此的方便实用. 需要大家注意的是,在上文中特别提到的是第一组测试结果,而非所有的结果,这是因为随机森林方法固然也有自己的局限性.在这篇文章中,我们将向你介绍运用随机森林构建预测模型时最令人感兴趣