- updateStateByKey 解释:
1) 定义状态:可以是任意数据类型
2) 定义状态更新函数:用一个函数指定如何使用先前的状态,从输入流中的新值更新状态。
- updateStateByKey源码:
- Return a new “state” DStream where the state for each key is updated by applying
- the given function on the previous state of the key and the new values of the key.
- org.apache.spark.Partitioner is used to control the partitioning of each RDD.
- @param updateFunc State update function. If `this` function returns None, then
- corresponding state key-value pair will be eliminated.
- @param partitioner Partitioner for controlling the partitioning of each RDD in the new
- DStream.
- @param initialRDD initial state value of each key.
- @tparam S State type
/**
 * Returns a "state" DStream in which each key's state is recomputed from its previous state
 * and the new values seen for that key, starting from the states supplied in `initialRDD`.
 *
 * @param updateFunc  per-key update function; a `None` result produces no output element,
 *                    so that key's state is dropped
 * @param partitioner partitioner forwarded to the iterator-based overload
 * @param initialRDD  initial state of each key
 * @tparam S state type
 */
def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S],
    partitioner: Partitioner,
    initialRDD: RDD[(K, S)]
  ): DStream[(K, S)] = {
  // Lift the per-key function to the iterator-based shape expected by the overload below:
  // every (key, newValues, oldState) triple becomes at most one (key, newState) pair.
  val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
    iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
  }
  // Third argument is the overload's `rememberPartitioner` flag (see the Spark API docs).
  updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
}
- 代码实现
- StatefulNetworkWordCount
/**
 * Counts words read from a socket text stream and keeps a cumulative count per word
 * across batches by means of `updateStateByKey`.
 */
object StatefulNetworkWordCount {

  /**
   * Entry point.
   *
   * @param args `args(0)` is the hostname and `args(1)` the port of the text source
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")
      System.exit(1)
    }

    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)

    // New state for one word: this batch's occurrences plus the count carried over so far
    // (0 when the word has no previous state).
    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
      val currentCount = values.sum
      val previousCount = state.getOrElse(0)
      Some(currentCount + previousCount)
    }

    // Iterator-based form of updateFunc, as required by the updateStateByKey overload
    // that also takes a partitioner and an initial RDD.
    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
    }

    // A receiver-based stream permanently occupies one thread, so a plain "local" master
    // (one thread) leaves nothing to process the received data; use at least two threads.
    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount").setMaster("local[2]")

    // Create the context with a 1 second batch size
    val ssc = new StreamingContext(sparkConf, Seconds(1))
    // updateStateByKey needs a checkpoint directory; here it is the working directory.
    ssc.checkpoint(".")

    // Initial RDD input to updateStateByKey
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

    // Create a ReceiverInputDStream on target ip:port and count the
    // words in input stream of \n delimited text (e.g. generated by `nc`)
    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val wordDstream = words.map(x => (x, 1))

    // Update the cumulative count using updateStateByKey.
    // This gives a DStream made of state (the cumulative count of each word).
    val stateDstream = wordDstream.updateStateByKey[Int](
      newUpdateFunc,
      new HashPartitioner(ssc.sparkContext.defaultParallelism),
      true,
      initialRDD)

    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
- NetworkWordCount
- StatefulNetworkWordCount
import org.apache.spark.SparkConf
import org.apache.spark.HashPartitioner
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
/**
 * Word count over a socket text stream that accumulates totals across batches
 * with `updateStateByKey`.
 */
object NetworkWordCount {

  /**
   * Entry point.
   *
   * @param args `args(0)` is the hostname and `args(1)` the port of the text source
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println("Usage: NetworkWordCount <hostname> <port>")
      // Stop here: the code below indexes args(0) and args(1).
      System.exit(1)
    }

    val sparkConf = new SparkConf().setAppName("NetworkWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(10))
    // updateStateByKey requires checkpointing to be enabled; the working directory is used
    // here, matching the StatefulNetworkWordCount example. Point it at a reliable location
    // (e.g. HDFS) for real deployments.
    ssc.checkpoint(".")

    val addFunc = (currValues: Seq[Int], prevValueState: Option[Int]) => {
      // Occurrences of the word in the current batch
      val currentCount = currValues.sum
      // Total accumulated so far (0 if the word has not been seen before)
      val previousCount = prevValueState.getOrElse(0)
      // Return the new running total as an Option[Int]
      Some(currentCount + previousCount)
    }

    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val pairs = words.map(word => (word, 1))

    //val currWordCounts = pairs.reduceByKey(_ + _)
    val totalWordCounts = pairs.updateStateByKey[Int](addFunc)

    // Without an output operation and start(), the streaming job never runs.
    totalWordCounts.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
- WebPagePopularityValueCalculator
// Streaming job that keeps a running popularity value per web page (via updateStateByKey)
// and, for every batch, takes the ten pages with the highest value.
// NOTE(review): this excerpt appears to have lost lines (closing braces, some statements and
// call arguments) and stops in the middle of the final foreach — restore them from the
// original listing before compiling.
object WebPagePopularityValueCalculator {
// NOTE(review): checkpointDir and msgConsumerGroup are declared but never referenced in the
// visible code; presumably the missing checkpoint / createStream lines use them — confirm.
private val checkpointDir = "popularity-data-checkpoint"
private val msgConsumerGroup = "user-behavior-topic-message-consumer-group"
// Expects exactly two arguments: the ZooKeeper server list and the batch interval in seconds
// (see the usage text and the Array(...) pattern below).
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage:WebPagePopularityValueCalculator zkserver1:2181, zkserver2: 2181, zkserver3: 2181 consumeMsgDataTimeInterval (secs) ")
// NOTE(review): no exit call or closing brace for this usage check is visible here.
val Array(zkServers, processingInterval) = args
val conf = new SparkConf().setAppName("Web Page Popularity Value Calculator")
val ssc = new StreamingContext(conf, Seconds(processingInterval.toInt))
//using updateStateByKey asks for enabling checkpoint
// NOTE(review): the checkpoint-enabling call this comment refers to is not present here.
val kafkaStream = KafkaUtils.createStream(
// NOTE(review): the three arguments described by the next three comments are missing
// from this excerpt; only the topic map remains.
//Spark streaming context
//zookeeper quorum. e.g zkserver1:2181,zkserver2:2181,...
//kafka message consumer group ID
//Map of (topic_name -> numPartitions) to consume. Each partition is consumed in its own thread
Map("user-behavior-topic" -> 3))
// Keep only the second element of each stream record (presumably the message payload — confirm).
val msgDataRDD = kafkaStream.map(_._2)
//for debug use only
//println("Coming data in this interval...")
// e.g page37|5|1.5119122|-1
// Turn each '|'-separated line into a (pageID, popularity value) pair.
val popularityData = msgDataRDD.map { msgLine => {
val dataArr: Array[String] = msgLine.split("\\|")
val pageID = dataArr(0)
//calculate the popularity value
// value = 0.8 * field 1 + 0.8 * field 2 + 1 * field 3 (fields parsed as Float)
val popValue: Double = dataArr(1).toFloat * 0.8 + dataArr(2).toFloat * 0.8 + dataArr(3).toFloat * 1
(pageID, popValue)
// NOTE(review): the closing braces of this map block are not visible in this excerpt.
//sum the previous popularity value and current value
val updatePopularityValue = (iterator: Iterator[(String, Seq[Double], Option[Double])]) => {
iterator.flatMap(t => {
// Sum of the values that arrived for this page in the current batch
val newValue: Double = t._2.sum
// Previously accumulated value, or 0 when the page has no state yet
val stateValue: Double = t._3.getOrElse(0);
Some(newValue + stateValue)
}.map(sumedValue => (t._1, sumedValue)))
// Initial state: a single page, "page1", with value 0.00
val initialRDD = ssc.sparkContext.parallelize(List(("page1", 0.00)))
// Call the updateStateByKey primitive with the anonymous function defined above
// to update each page's popularity value.
val stateDStream = popularityData.updateStateByKey[Double](updatePopularityValue,
new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
//set the checkpoint interval to avoid too frequent data checkpointing, which
//may significantly reduce operation throughput
// Interval = 8 * the batch interval; the * 1000 presumably converts seconds to the
// milliseconds Duration expects — confirm.
stateDStream.checkpoint(Duration(8 * processingInterval.toInt * 1000))
//after calculation, we need to sort the result and only show the top 10 hot pages
// Once the latest result is available, sort it and print the 10 pages with the
// highest popularity value.
stateDStream.foreachRDD { rdd => {
// Swap to (value, page) so sortByKey orders by value; false = descending
val sortedData = rdd.map { case (k, v) => (v, k) }.sortByKey(false)
// Take the first 10 and swap back to (page, value)
val topKData = sortedData.take(10).map { case (v, k) => (k, v) }
topKData.foreach(x => {
时间: 2025-01-04 02:09:31