package com.profile.main

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import com.profile.tools.{DateTools, JdbcTools, LogTools, SparkTools}
import com.dhd.comment.Constant
import com.profile.comment.Comments

/**
 * Test class: solve the Spark Top-N problem (group, sort, take the top N) with a DataFrame and with an RDD
 * @author
 * date 2017-09-27 14:55
 */
object Test {
  def main(args: Array[String]): Unit = {
    val sc = SparkTools.getSparkContext
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(
      (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
      (1, "cat67", 28.5), (1, "cat4", 26.8), (1, "cat13", 12.6), (1, "cat23", 5.3),
      (2, "cat56", 39.6), (2, "cat40", 29.7), (2, "cat187", 27.9), (2, "cat68", 9.8),
      (3, "cat8", 35.6))).toDF("Hour", "Category", "TotalValue")
    df.show
    /*
    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   0|   cat26|      30.9|
    |   0|   cat13|      22.1|
    |   0|   cat95|      19.6|
    |   0|  cat105|       1.3|
    |   1|   cat67|      28.5|
    |   1|    cat4|      26.8|
    |   1|   cat13|      12.6|
    |   1|   cat23|       5.3|
    |   2|   cat56|      39.6|
    |   2|   cat40|      29.7|
    |   2|  cat187|      27.9|
    |   2|   cat68|       9.8|
    |   3|    cat8|      35.6|
    +----+--------+----------+
    */

    // Solve the Spark Top-N problem with a DataFrame: group, sort, take the top N
    /*
    val w = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)
    // take the top 1
    val dfTop1 = df.withColumn("rn", rowNumber.over(w)).where($"rn" === 1).drop("rn")
    // Note: the function is rowNumber() in Spark 1.x and row_number() in Spark 2.x
    // take the top 3
    val dfTop3 = df.withColumn("rn", rowNumber.over(w)).where($"rn" <= 3).drop("rn")
    dfTop1.show
    */
    /*
    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   1|   cat67|      28.5|
    |   3|    cat8|      35.6|
    |   2|   cat56|      39.6|
    |   0|   cat26|      30.9|
    +----+--------+----------+
    */
    // dfTop3.show
    /*
    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   1|   cat67|      28.5|
    |   1|    cat4|      26.8|
    |   1|   cat13|      12.6|
    |   3|    cat8|      35.6|
    |   2|   cat56|      39.6|
    |   2|   cat40|      29.7|
    |   2|  cat187|      27.9|
    |   0|   cat26|      30.9|
    |   0|   cat13|      22.1|
    |   0|   cat95|      19.6|
    +----+--------+----------+
    */
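    // Hedged sketch (not in the original): on Spark 2.x the same window-based Top-N would use
    // row_number() from org.apache.spark.sql.functions instead of rowNumber, as the note above
    // points out. It is left commented out because this project's Spark version is not stated;
    // the names w2, dfTop1_2x and dfTop3_2x are illustrative only.
    /*
    val w2 = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)
    val dfTop1_2x = df.withColumn("rn", row_number().over(w2)).where($"rn" === 1).drop("rn")
    val dfTop3_2x = df.withColumn("rn", row_number().over(w2)).where($"rn" <= 3).drop("rn")
    dfTop3_2x.show
    */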
    // Solve the Spark Top-N problem with an RDD: group, sort, take the top N
    val rdd1 = sc.parallelize(Seq(
      (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
      (1, "cat67", 28.5), (1, "cat4", 26.8), (1, "cat13", 12.6), (1, "cat23", 5.3),
      (2, "cat56", 39.6), (2, "cat40", 29.7), (2, "cat187", 27.9), (2, "cat68", 9.8),
      (3, "cat8", 35.6)))

    // group the (Category, TotalValue) pairs by Hour
    val rdd2 = rdd1.map(x => (x._1, (x._2, x._3))).groupByKey()
    /*
    rdd2.collect
    res9: Array[(Int, Iterable[(String, Double)])] = Array((0,CompactBuffer((cat26,30.9), (cat13,22.1), (cat95,19.6), (cat105,1.3))), (1,CompactBuffer((cat67,28.5), (cat4,26.8), (cat13,12.6), (cat23,5.3))), (2,CompactBuffer((cat56,39.6), (cat40,29.7), (cat187,27.9), (cat68,9.8))), (3,CompactBuffer((cat8,35.6))))
    */

    val N_value = 3 // take the top 3 per Hour (the collected results below were produced with N = 3)

    // sort each group ascending by value and drop everything except the last N elements
    val rdd3 = rdd2.map(x => {
      val i2 = x._2.toBuffer
      val i2_2 = i2.sortBy(_._2)
      if (i2_2.length > N_value) i2_2.remove(0, i2_2.length - N_value)
      (x._1, i2_2.toIterable)
    })
    /*
    rdd3.collect
    res8: Array[(Int, Iterable[(String, Double)])] = Array((0,ArrayBuffer((cat95,19.6), (cat13,22.1), (cat26,30.9))), (1,ArrayBuffer((cat13,12.6), (cat4,26.8), (cat67,28.5))), (2,ArrayBuffer((cat187,27.9), (cat40,29.7), (cat56,39.6))), (3,ArrayBuffer((cat8,35.6))))
    */

    // flatten (Hour, Iterable[(Category, TotalValue)]) back into (Hour, Category, TotalValue) rows
    val rdd4 = rdd3.flatMap(x => {
      val y = x._2
      for (w <- y) yield (x._1, w._1, w._2)
    })
    rdd4.collect
    /*
    res3: Array[(Int, String, Double)] = Array((0,cat95,19.6), (0,cat13,22.1), (0,cat26,30.9), (1,cat13,12.6), (1,cat4,26.8), (1,cat67,28.5), (2,cat187,27.9), (2,cat40,29.7), (2,cat56,39.6), (3,cat8,35.6))
    */

    rdd4.toDF("Hour", "Category", "TotalValue").show
    /*
    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   0|   cat95|      19.6|
    |   0|   cat13|      22.1|
    |   0|   cat26|      30.9|
    |   2|  cat187|      27.9|
    |   2|   cat40|      29.7|
    |   2|   cat56|      39.6|
    |   1|   cat13|      12.6|
    |   1|    cat4|      26.8|
    |   1|   cat67|      28.5|
    |   3|    cat8|      35.6|
    +----+--------+----------+
    */
  }
}
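// A minimal sketch, not part of the original: groupByKey pulls every value of a key into memory
// on a single executor, which can be costly for large or skewed groups. aggregateByKey can keep
// only the N largest elements per key while aggregating. The object and method names below
// (TopNWithAggregateByKey, topNPerKey) are illustrative, not from this project.
object TopNWithAggregateByKey {

  import org.apache.spark.rdd.RDD

  def topNPerKey(pairs: RDD[(Int, (String, Double))], n: Int): RDD[(Int, List[(String, Double)])] = {
    pairs.aggregateByKey(List.empty[(String, Double)])(
      // within a partition: add the element, keep only the n largest by TotalValue
      (acc, v) => (v :: acc).sortBy(-_._2).take(n),
      // across partitions: merge two partial top-n lists and trim again
      (a, b) => (a ++ b).sortBy(-_._2).take(n)
    )
  }

  // Illustrative usage against the rdd1 built in Test.main:
  //   topNPerKey(rdd1.map(x => (x._1, (x._2, x._3))), 3).collect
}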