Large Scale Distributed Semi-Supervised Learning Using Streaming Approximation
官方 Blog 链接:https://research.googleblog.com/2016/10/graph-powered-machine-learning-at-google.html
今天讲的是一个基于 streaming approximation 的大规模分布式半监督学习框架,出自 Google 。
摘要:众所周知,传统的 graph-based 半监督学习方法不适合处理大批量数据和大型标签场景,因为其计算量和他们的 边 |E| 和 直接标签 m 的个数是线性关系。为了处理大型标签尺度问题,最近的工作提出了 sketch-based methods 来预测每一个节点的标签分布,故而将空间复杂度由 O(m) 降到了 O(log m),在一定的条件下。
本文提出一种 新颖的 streaming graph-based SSL approximation 的方法有效的抓住了标签分布的稀疏性(the sparisity),进一步的将空间复杂度降到了 O(1). 与此同时,本文提出一种分布式版本的算法可以处理大批量数据的情况。在实际世界的数据集中的实验,证明所提出的方法比现有方法可以达到明显的内存降低。最后,本文提出一种鲁邦的利用半监督深度学习框架的 graph augmentation strategy,并且在自然语言应用上取得了较好的半监督学习效果。
引言:SSL 是利用少量有标签数据和海量无标签数据去训练一个预测系统(prediction systems)。其研究意义就在于,现有的标注总是少量的,而且标注工作是枯燥耗时的,而无标签数据又是海量的,如何利用有限的有标签数据结合海量无标签数据,进一步的提升现有模型的性能,是一个值得关注的课题。
关于不同 SSL methods 的局限性,主要体现在:昂贵的计算代价! 比如,transductive SVM 和 Graph-based SSL 算法是 SSL 算法中比较出名的一个子类。这些方法的核心 idea 就是构建和平滑一个 graph,利用 点 和 边 去链接他们之间的关系。边权(edge weights)是根据节点之间的相似性得到的。基于标签传递(label propagation)的 Graph-based methods 利用已有的种子节点,通过 graph 去传递其标签信息。这些方法通常收敛的很快,并且他们的时间和空间复杂度和边的个数以及 label 的个数呈线性关系。
但是,有些场景所涉及到的样本数量 和 label 个数真的是非常巨大,常规的基于 graph 的 SSL 方法无法处理。通常,单独的节点用稀疏的标签分布来进行初始化,但是随着迭代次数的增加,他们将变得 dense。Talukdar and Cohen 最近提出一种方法【1】试图克服 label scale problem ,通过一个 Count-Min Sketch 的方法来预测每一个 node 的 label 和他们的 score 。这使得内存复杂度变得非常低。但是,在实际世界的应用中, actual label k 的个数和每一个节点的连接实际上是 sparse 的,尽管总的 label space 是非常 huge 的,也就是说 K 是远小于 m 的。很明显,在实际应用中,考虑到label 的稀疏性可以显著的降低复杂度。
Contributions:
1. 本文提出一种新的 graph propagation algorithm 进行 general purpose SSL 。
2. 该算法可以处理有大量 label 的情况。其核心是,利用一种 approximation 有效的抓住了 标签分布的稀疏性,确保算法可以准确的传递标签。
3. 提出 并行化处理版本的算法,可以很好的处理 large graph sizes.
4. 提出一种 有效的线性时间 构图策略,可以有效的结合多种信号,可以动态的从 sparse 到 dense representation。
5. 特别的,graphs ,节点表示文本信息,仅仅利用 原始文本 和 顶尖的 DL 技术,可能会鲁邦的学习到和这些节点联系的 latent semantic embeddings 。
用这种 embedding 的方式增强原始 graph,然后用 graph SSL 产生了明显的提升。
Graph-based Semi-Supervised Learning :
Preliminary : 目标是产生一个 soft assignment of labels to each node in a graph G=(V,E,W)。
Graph SSL Optimization :
通过最小化下列的目标函数来学习一个 label distribution $Y^\hat$ :
其中,N(v) 代表 节点 v 的近邻节点,U 是所有label 的先验分布。