package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}

/**
 * One parsed input record.
 *
 * @param features dense vector built from the first four comma-separated fields of a line
 * @param label    the fifth field of the line, kept as a raw string
 */
case class data_schema(features: Vector, label: String)

/**
 * Trains a decision-tree regressor on a comma-separated text file, prints the
 * predictions for a held-out split, the RMSE, and the learned tree structure.
 */
object 决策树__回归模型 {
  val spark = SparkSession.builder().master("local").getOrCreate()
  import spark.implicits._

  /**
   * Program entry point.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    try {
      // Each line is expected to hold four numeric fields followed by a label.
      val data = spark.sparkContext
        .textFile("file:///home/soyo/桌面/spark编程测试数据/soyo2.txt")
        .map(_.split(","))
        .map { x =>
          data_schema(
            Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble),
            x(4)
          )
        }
        .toDF()

      // Map string labels to numeric indices so the regressor can consume them.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data)

      // Features with at most 4 distinct values are treated as categorical.
      val featuresIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(4)
        .fit(data)

      // Map numeric predictions back to the original label strings.
      // NOTE(review): a regressor can emit non-integer predictions; presumably
      // IndexToString truncates them to an index -- confirm this is intended.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      // No seed is passed, so the 70/30 split differs from run to run.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3))

      // Configure the decision-tree regression model.
      val dtRegressor = new DecisionTreeRegressor()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")

      // Assemble the machine-learning pipeline.
      val pipelineRegressor = new Pipeline()
        .setStages(Array(labelIndexer, featuresIndexer, dtRegressor, labelConverter))

      // Train the decision-tree regression model.
      val modelRegressor = pipelineRegressor.fit(trainData)

      // Run predictions on the held-out split.
      val prediction = modelRegressor.transform(testData)
      prediction.show(150)

      // Evaluate the model. setMetricName selects the metric; accepted values
      // are: rmse, mse, r2, mae.
      val evaluatorRegressor = new RegressionEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("rmse")
      val rmse = evaluatorRegressor.evaluate(prediction)
      println(s"均方根误差为: $rmse")

      // Stage index 2 is the fitted tree: it follows the two indexers in setStages above.
      val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
      val treeStructure = treeModelRegressor.toDebugString
      // The printed label says "classification model" although the model is a
      // regressor; the text is kept unchanged so it matches the recorded output.
      println(s"决策树分类模型的结构为: $treeStructure")
    } finally {
      // Release the local Spark context even when the job fails.
      spark.stop()
    }
  }
}
// Spark source, for reference: how evaluate() uses the value given to setMetricName("")
/**
 * Computes the configured regression metric over `dataset`.
 *
 * The prediction column must be of type Double or Float and the label column
 * must be numeric; both are cast to Double before the metric is computed.
 *
 * @param dataset dataset containing the prediction and label columns
 * @return the value of the metric selected by `metricName`
 */
@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
  val schema = dataset.schema
  // Validate the column types up front.
  SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
  SchemaUtils.checkNumericType(schema, $(labelCol))

  // Build (prediction, label) pairs of doubles from the two columns.
  val predictionAndLabels = dataset
    .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
    .rdd
    .map { case Row(prediction: Double, label: Double) => (prediction, label) }
  val metrics = new RegressionMetrics(predictionAndLabels)
  // Dispatch on the metric name; only these four names are handled here.
  val metric = $(metricName) match {
    case "rmse" => metrics.rootMeanSquaredError
    case "mse" => metrics.meanSquaredError
    case "r2" => metrics.r2
    case "mae" => metrics.meanAbsoluteError
  }
  metric
}
结果:
+-----------------+------+------------+-----------------+----------+--------------+
| features| label|indexedLabel| indexedFeatures|prediction|predictedLabel|
+-----------------+------+------------+-----------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|hadoop| 1.0|[4.6,3.1,1.5,0.2]| 1.0| hadoop|
|[4.6,3.4,1.4,0.3]|hadoop| 1.0|[4.6,3.4,1.4,0.3]| 1.0| hadoop|
|[4.7,3.2,1.3,0.2]|hadoop| 1.0|[4.7,3.2,1.3,0.2]| 1.0| hadoop|
|[4.8,3.0,1.4,0.1]|hadoop| 1.0|[4.8,3.0,1.4,0.1]| 1.0| hadoop|
|[5.1,3.3,1.7,0.5]|hadoop| 1.0|[5.1,3.3,1.7,0.5]| 1.0| hadoop|
|[5.1,3.7,1.5,0.4]|hadoop| 1.0|[5.1,3.7,1.5,0.4]| 1.0| hadoop|
|[5.4,3.9,1.3,0.4]|hadoop| 1.0|[5.4,3.9,1.3,0.4]| 1.0| hadoop|
|[5.5,2.3,4.0,1.3]| spark| 0.0|[5.5,2.3,4.0,1.3]| 0.0| spark|
|[5.5,3.5,1.3,0.2]|hadoop| 1.0|[5.5,3.5,1.3,0.2]| 1.0| hadoop|
|[5.6,2.7,4.2,1.3]| spark| 0.0|[5.6,2.7,4.2,1.3]| 0.0| spark|
|[5.6,3.0,4.1,1.3]| spark| 0.0|[5.6,3.0,4.1,1.3]| 0.0| spark|
|[5.6,3.0,4.5,1.5]| spark| 0.0|[5.6,3.0,4.5,1.5]| 0.0| spark|
|[5.7,2.6,3.5,1.0]| spark| 0.0|[5.7,2.6,3.5,1.0]| 0.0| spark|
|[5.7,4.4,1.5,0.4]|hadoop| 1.0|[5.7,4.4,1.5,0.4]| 1.0| hadoop|
|[5.8,2.7,3.9,1.2]| spark| 0.0|[5.8,2.7,3.9,1.2]| 0.0| spark|
|[5.8,2.7,4.1,1.0]| spark| 0.0|[5.8,2.7,4.1,1.0]| 0.0| spark|
|[5.8,2.8,5.1,2.4]| Scala| 2.0|[5.8,2.8,5.1,2.4]| 2.0| Scala|
|[5.8,4.0,1.2,0.2]|hadoop| 1.0|[5.8,4.0,1.2,0.2]| 1.0| hadoop|
|[5.9,3.0,4.2,1.5]| spark| 0.0|[5.9,3.0,4.2,1.5]| 0.0| spark|
|[5.9,3.0,5.1,1.8]| Scala| 2.0|[5.9,3.0,5.1,1.8]| 2.0| Scala|
|[5.9,3.2,4.8,1.8]| spark| 0.0|[5.9,3.2,4.8,1.8]| 2.0| Scala|
|[6.1,2.6,5.6,1.4]| Scala| 2.0|[6.1,2.6,5.6,1.4]| 2.0| Scala|
|[6.1,2.8,4.0,1.3]| spark| 0.0|[6.1,2.8,4.0,1.3]| 0.0| spark|
|[6.3,2.9,5.6,1.8]| Scala| 2.0|[6.3,2.9,5.6,1.8]| 2.0| Scala|
|[6.3,3.4,5.6,2.4]| Scala| 2.0|[6.3,3.4,5.6,2.4]| 2.0| Scala|
|[6.4,2.7,5.3,1.9]| Scala| 2.0|[6.4,2.7,5.3,1.9]| 2.0| Scala|
|[6.4,3.1,5.5,1.8]| Scala| 2.0|[6.4,3.1,5.5,1.8]| 2.0| Scala|
|[6.4,3.2,4.5,1.5]| spark| 0.0|[6.4,3.2,4.5,1.5]| 0.0| spark|
|[6.5,2.8,4.6,1.5]| spark| 0.0|[6.5,2.8,4.6,1.5]| 0.0| spark|
|[6.5,3.0,5.5,1.8]| Scala| 2.0|[6.5,3.0,5.5,1.8]| 2.0| Scala|
|[6.7,3.0,5.2,2.3]| Scala| 2.0|[6.7,3.0,5.2,2.3]| 2.0| Scala|
|[6.7,3.1,4.7,1.5]| spark| 0.0|[6.7,3.1,4.7,1.5]| 0.0| spark|
|[6.8,3.0,5.5,2.1]| Scala| 2.0|[6.8,3.0,5.5,2.1]| 2.0| Scala|
|[6.9,3.1,5.4,2.1]| Scala| 2.0|[6.9,3.1,5.4,2.1]| 2.0| Scala|
|[7.0,3.2,4.7,1.4]| spark| 0.0|[7.0,3.2,4.7,1.4]| 0.0| spark|
|[7.1,3.0,5.9,2.1]| Scala| 2.0|[7.1,3.0,5.9,2.1]| 2.0| Scala|
|[7.2,3.0,5.8,1.6]| Scala| 2.0|[7.2,3.0,5.8,1.6]| 0.0| spark|
|[7.2,3.2,6.0,1.8]| Scala| 2.0|[7.2,3.2,6.0,1.8]| 2.0| Scala|
|[7.2,3.6,6.1,2.5]| Scala| 2.0|[7.2,3.6,6.1,2.5]| 2.0| Scala|
|[7.4,2.8,6.1,1.9]| Scala| 2.0|[7.4,2.8,6.1,1.9]| 2.0| Scala|
|[7.7,2.6,6.9,2.3]| Scala| 2.0|[7.7,2.6,6.9,2.3]| 2.0| Scala|
|[7.7,2.8,6.7,2.0]| Scala| 2.0|[7.7,2.8,6.7,2.0]| 2.0| Scala|
+-----------------+------+------------+-----------------+----------+--------------+
均方根误差为: 0.43643578047198484
决策树分类模型的结构为: DecisionTreeRegressionModel (uid=dtr_6015411b1a3d) of depth 4 with 11 nodes
  If (feature 3 <= 1.7)
   If (feature 2 <= 1.9)
    Predict: 1.0
   Else (feature 2 > 1.9)
    If (feature 2 <= 4.9)
     If (feature 3 <= 1.6)
      Predict: 0.0
     Else (feature 3 > 1.6)
      Predict: 2.0
    Else (feature 2 > 4.9)
     If (feature 3 <= 1.5)
      Predict: 2.0
     Else (feature 3 > 1.5)
      Predict: 0.0
   Else (feature 3 > 1.7)
    Predict: 2.0