word2vec study notes (Spark version)

References:

http://ir.dlut.edu.cn/NewsShow.aspx?ID=291

http://www.douban.com/note/298095260/

http://machinelearning.wustl.edu/mlpapers/paper_files/BengioDVJ03.pdf

https://code.google.com/p/word2vec/

word2vec is an important algorithm in NLP. What it does is represent each word as a K-dimensional dense vector. The training set is a plain-text corpus with no punctuation, with words separated by spaces, so word2vec can be viewed as a kind of feature-processing method.

Main advantages:

  • Additive vector operations (the classic example: vec("king") − vec("man") + vec("woman") ≈ vec("queen")).
  • Efficiency: a single machine can process roughly 20 million words per hour.

Google's open-source release is the authoritative implementation ( http://word2vec.googlecode.com/svn/trunk/ ), but I studied the Spark version instead.

I. Background

Distributed representation: a way of representing word features. Training maps each word to a K-dimensional real-valued vector (K is usually a hyperparameter of the model), and the semantic similarity between two words is then judged by the distance between their vectors (e.g. cosine similarity or Euclidean distance).
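For later reference, the distance actually used by the Spark model below (in findSynonyms) is cosine similarity:

\[ \cos(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert} \]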

Language models: n-gram and so on.

II. Model

0. Context is defined by a word window. For a word i, write $u_i$ for its representation as a center word and $v_i$ for its representation when it acts as context for another word (that is, a word gets a separate vector when used as context). Only words inside the word window count as context, and their order does not matter. For example, with a window of 2, the context of "fox" in "the quick brown fox jumps over" is {quick, brown, jumps, over}.

1. The probability model is \[ P=\sum_i \log p(u_i), \] where i ranges over positions (words) in the corpus, i.e. the accumulated log-probability of every word occurring.

2. Taking skip-gram as an example (CBOW simply turns the conditional probability around), the contribution of the word at position i is

\[ \log p(u_i)=\sum_{-c\leq j\leq c,\, j\neq 0} \log p(v_{i+j}\mid u_{i}), \]

which says that the word at position i interacts only with its context.

3. The conditional probability $p(v_{i+j}\mid u_i)$ is realized by a softmax, which converts the K-dimensional vectors into a probability.
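Written out explicitly (this is the standard skip-gram softmax from the word2vec papers, in this post's notation):

\[ p(v_{i+j} \mid u_i) = \frac{\exp\!\left(v_{i+j}^{\top} u_i\right)}{\sum_{w=1}^{W} \exp\!\left(v_{w}^{\top} u_i\right)}, \]

where W is the vocabulary size. The normalization over the whole vocabulary is what makes the naive computation expensive, which motivates the optimizations in the next section.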

III. Optimization

The earliest approach used a full neural-network softmax; hierarchical softmax was later adopted to reduce the time complexity. Many tricks are used as well, such as the precomputed ExpTable.
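As a sketch of the idea (one common way to write it, matching the update rule in the Spark code below): the vocabulary is arranged as a Huffman tree, and the probability of a word is the product of binary decisions along its root-to-leaf path,

\[ p(v \mid u) = \prod_{d=1}^{L(v)-1} \sigma\!\left( \big(1 - 2\,\mathrm{code}_d(v)\big)\; {v'}_{n(v,d)}^{\top} u \right), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}, \]

where $n(v,d)$ is the d-th inner node on the path to $v$ and $\mathrm{code}_d(v) \in \{0,1\}$ is the branch taken at that node. This replaces an O(V) normalization with O(log V) sigmoid evaluations, and the ExpTable is simply a precomputed table of $\sigma(x)$ over $[-\mathrm{MAX\_EXP}, \mathrm{MAX\_EXP}]$.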

IV. Spark source code analysis

The listing below is the MLlib Word2Vec implementation (the class here is renamed Word2VectorEX); my annotations appear as inline comments.

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.feature

import java.lang.{Iterable => JavaIterable}

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer


/**
 *  Entry in vocabulary
 */
private case class VocabWord(
    var word: String,
    var cn: Int,
    var point: Array[Int],
    var code: Array[Int],
    var codeLen: Int
)

/**
 * :: Experimental ::
 * Word2Vec creates vector representation of words in a text corpus.
 * The algorithm first constructs a vocabulary from the corpus
 * and then learns vector representation of words in the vocabulary.
 * The vector representation can be used as features in
 * natural language processing and machine learning algorithms.
 *
 * We used skip-gram model in our implementation and hierarchical softmax
 * method to train the model. The variable names in the implementation
 * matches the original C implementation.
 *
 * For original C implementation, see https://code.google.com/p/word2vec/
 * For research papers, see
 * Efficient Estimation of Word Representations in Vector Space
 * and
 * Distributed Representations of Words and Phrases and their Compositionality.
 */
@Experimental
class Word2VectorEX extends Serializable with Logging {

  private var vectorSize = 100
  private var startingAlpha = 0.025
  private var numPartitions = 1
  private var numIterations = 1
  private var seed = Utils.random.nextLong()

  /**
   * Sets vector size (default: 100).
   */
  def setVectorSize(vectorSize: Int): this.type = {
    this.vectorSize = vectorSize
    this
  }

  /**
   * Sets initial learning rate (default: 0.025).
   */
  def setLearningRate(learningRate: Double): this.type = {
    this.startingAlpha = learningRate
    this
  }

  /**
   * Sets number of partitions (default: 1). Use a small number for accuracy.
   */
  def setNumPartitions(numPartitions: Int): this.type = {
    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
    this.numPartitions = numPartitions
    this
  }

  /**
   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
   * partitions.
   */
  def setNumIterations(numIterations: Int): this.type = {
    this.numIterations = numIterations
    this
  }

  /**
   * Sets random seed (default: a random long integer).
   */
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }

  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40
  private val MAX_SENTENCE_LENGTH = 1000

  /** context words from [-window, window] */
  private val window = 5            // limits the context range

  /** minimum frequency to consider a vocabulary word */
  private val minCount = 5          // words rarer than this are dropped from the vocabulary

  private var trainWordsCount = 0                              // total token count of the corpus (with repetitions)
  private var vocabSize = 0                                    // number of words in the vocabulary
  private var vocab: Array[VocabWord] = null                   // the vocabulary
  private var vocabHash = mutable.HashMap.empty[String, Int]   // reverse index: word -> position in vocab

  private def learnVocab(words: RDD[String]): Unit = {         // build the vocabulary and fill in the four fields above
    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .filter(_.cn >= minCount)
      .collect()
      .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a
      trainWordsCount += vocab(a).cn
      a += 1
    }
    logInfo("trainWordsCount = " + trainWordsCount)
  }

  private def createExpTable(): Array[Float] = {        // lookup table for the exponential/sigmoid
    val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      // entry i holds sigmoid(x) for x evenly spaced over [-MAX_EXP, MAX_EXP]
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }

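  // Note: createBinaryTree builds a Huffman tree over the vocabulary, as in the original C
  // implementation: leaves are vocabulary words weighted by frequency, inner nodes get indices
  // >= vocabSize, and for every word we record the branch taken at each level ("code", 0/1) and
  // the inner-node indices along the root-to-leaf path ("point"). Hierarchical softmax later
  // walks this path instead of normalizing over the whole vocabulary.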
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
    val code = new Array[Int](MAX_CODE_LENGTH)
    val point = new Array[Int](MAX_CODE_LENGTH)
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt
      a += 1
    }
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

    var min1i = 0
    var min2i = 0

    a = 0
    while (a < vocabSize - 1) {
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {
        code(i) = binary(b)
        point(i) = b
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i
      vocab(a).point(0) = vocabSize - 2
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)
        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
      }
      a += 1
    }
  }

  /**
   * Computes the vector representation of each word in vocabulary.
   * @param dataset an RDD of words
   * @return a Word2VecModel
   */
  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VectorModel = {

    val words = dataset.flatMap(x => x)       // flatten into one word sequence; sentence boundaries come from the Iterable elements

    learnVocab(words)                         // build the vocabulary

    createBinaryTree()

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    val bcVocab = sc.broadcast(vocab)
    val bcVocabHash = sc.broadcast(vocabHash)

    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>            // split into sentences; words are represented by Int indices
      new Iterator[Array[Int]] {
        def hasNext: Boolean = iter.hasNext

        def next(): Array[Int] = {
          var sentence = new ArrayBuffer[Int]
          var sentenceLength = 0
          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
            val word = bcVocabHash.value.get(iter.next())
            word match {
              case Some(w) =>
                sentence += w
                sentenceLength += 1
              case None =>
            }
          }
          sentence.toArray
        }
      }
    }

    // Hierarchical Softmax
    val newSentences = sentences.repartition(numPartitions).cache()
    val initRandom = new XORShiftRandom(seed)
    val syn0Global =
      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
    val syn1Global = new Array[Float](vocabSize * vectorSize)
    var alpha = startingAlpha
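    // Note: the distributed training scheme below ships the current syn0Global / syn1Global to
    // every partition (captured in the closure), runs skip-gram SGD with hierarchical softmax
    // locally, and emits only the word vectors each partition actually modified. The driver then
    // sums matching rows with reduceByKey and copies them back into the global arrays, which is
    // why setNumPartitions advises a small number for accuracy.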
    for (k <- 1 to numIterations) {
      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))    // per-partition RNG for the stochastic gradient descent
        val syn0Modify = new Array[Int](vocabSize)
        val syn1Modify = new Array[Int](vocabSize)
        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
            var wc = wordCount
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
              // TODO: discount by iteration?
              alpha =
                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
              if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
            }
            wc += sentence.size
            var pos = 0
            while (pos < sentence.size) {
              val word = sentence(pos)
              val b = random.nextInt(window)
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {
                if (a != window) {
                  val c = pos - window + a
                  if (c >= 0 && c < sentence.size) {
                    val lastWord = sentence(c)
                    val l1 = lastWord * vectorSize
                    // neu1e accumulates the gradient for the input (context) word vector
                    val neu1e = new Array[Float](vectorSize)
                    // Hierarchical softmax
                    var d = 0
                    while (d < bcVocab.value(word).codeLen) {
                      val inner = bcVocab.value(word).point(d)
                      val l2 = inner * vectorSize
                      // Propagate hidden -> output
                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                      if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)
                        // g = (target - prediction) * learning rate, where target = 1 - code(d)
                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
                        syn1Modify(inner) += 1
                      }
                      d += 1
                    }
                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                    syn0Modify(lastWord) += 1
                  }
                }
                a += 1
              }
              pos += 1
            }
            (syn0, syn1, lwc, wc)
        }
        val syn0Local = model._1
        val syn1Local = model._2
        // Only output modified vectors.
        Iterator.tabulate(vocabSize) { index =>
          if (syn0Modify(index) > 0) {
            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
          if (syn1Modify(index) > 0) {
            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten
      }
      val synAgg = partial.reduceByKey { case (v1, v2) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
        v1
      }.collect()
      var i = 0
      while (i < synAgg.length) {
        val index = synAgg(i)._1
        if (index < vocabSize) {
          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
        i += 1
      }
    }
    newSentences.unpersist()

    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
    var i = 0
    while (i < vocabSize) {
      val word = bcVocab.value(i).word
      val vector = new Array[Float](vectorSize)
      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
      word2VecMap += word -> vector
      i += 1
    }

    new Word2VectorModel(word2VecMap.toMap)
  }

  /**
   * Computes the vector representation of each word in vocabulary (Java version).
   * @param dataset a JavaRDD of words
   * @return a Word2VecModel
   */
  def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VectorModel = {
    fit(dataset.rdd.map(_.asScala))
  }

}

/**
 * :: Experimental ::
 * Word2Vec model
 */
@Experimental
class Word2VectorModel private[mllib] (
    private val model: Map[String, Array[Float]]) extends Serializable {

  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
    val norm1 = blas.snrm2(n, v1, 1)
    val norm2 = blas.snrm2(n, v2, 1)
    if (norm1 == 0 || norm2 == 0) return 0.0
    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
  }

  /**
   * Transforms a word to its vector representation
   * @param word a word
   * @return vector representation of word
   */
  def transform(word: String): Vector = {
    model.get(word) match {
      case Some(vec) =>
        Vectors.dense(vec.map(_.toDouble))
      case None =>
        throw new IllegalStateException(s"$word not in vocabulary")
    }
  }

  /**
   * Find synonyms of a word
   * @param word a word
   * @param num number of synonyms to find
   * @return array of (word, similarity)
   */
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    findSynonyms(vector, num)
  }

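  // Note: a word is always its own nearest neighbour, so the word-based overload above requests
  // num + 1 results and the .tail below drops the top one. If this vector-based overload is
  // called with a vector that does not correspond exactly to a vocabulary word, the best genuine
  // match is still dropped.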
  /**
   * Find synonyms of the vector representation of a word
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @return array of (word, cosineSimilarity)
   */
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    // TODO: optimize top-k
    val fVector = vector.toArray.map(_.toFloat)
    model.mapValues(vec => cosineSimilarity(fVector, vec))
      .toSeq
      .sortBy(- _._2)
      .take(num + 1)
      .tail
      .toArray
  }


  def getModel(): Map[String, Array[Float]] = {
    model
  }


}
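
To close, a minimal usage sketch (my own, not part of the Spark source; the input path and the query word are made up for illustration). It assumes a whitespace-tokenized corpus with one sentence per line, matching the corpus format described at the top:

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("word2vec-demo"))

    // One sentence per line, words separated by spaces.
    val sentences = sc.textFile("corpus.txt").map(_.split(" ").toSeq)

    val model = new Word2VectorEX()
      .setVectorSize(100)
      .setNumPartitions(1)
      .fit(sentences)

    // Vector for a single word, and its 10 nearest neighbours by cosine similarity.
    val vec = model.transform("china")
    model.findSynonyms("china", 10).foreach { case (w, sim) => println(s"$w\t$sim") }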