jrae Source Code Walkthrough (Part 1)

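jrae is a Java implementation of the algorithm proposed in the paper "Semi-Supervised Recursive Autoencoders for Predicting Sentiment Distributions": a semi-supervised recursive autoencoder used to predict sentiment classes. See the paper for details. The code is on GitHub at https://github.com/sancha/jrae.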

Bird's-eye view

Training flow in the main function

FineTunableTheta tunedTheta = rae.train(params); // train the network weights from the parameters and data
tunedTheta.Dump(params.ModelFile);

System.out.println("RAE trained. The model file is saved in "
    + params.ModelFile);

// Feature extractor built on the trained weights
RAEFeatureExtractor fe = new RAEFeatureExtractor(params.EmbeddingSize,
    tunedTheta, params.AlphaCat, params.Beta, params.CatSize,
    params.Dataset.Vocab.size(), rae.f);

// Extract features from the training data
List<LabeledDatum<Double, Integer>> classifierTrainingData = fe
    .extractFeaturesIntoArray(params.Dataset, params.Dataset.Data,
        params.TreeDumpDir);

// Train a softmax classifier on those features and report the training accuracy
SoftmaxClassifier<Double, Integer> classifier = new SoftmaxClassifier<Double, Integer>();
Accuracy TrainAccuracy = classifier.train(classifierTrainingData);
System.out.println("Train Accuracy :" + TrainAccuracy.toString());

Key interfaces and their implementations

1. Minimizer<T extends DifferentiableFunction>

public interface Minimizer<T extends DifferentiableFunction> {

  /**
   * Attempts to find an unconstrained minimum of the objective
   * <code>function</code> starting at <code>initial</code>, within
   * <code>functionTolerance</code>.
   *
   * @param function          the objective function
   * @param functionTolerance a <code>double</code> value
   * @param initial           a initial feasible point
   * @return Unconstrained minimum of function
   */
  double[] minimize(T function, double functionTolerance, double[] initial);
  double[] minimize(T function, double functionTolerance, double[] initial, int maxIterations);

}

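As the Javadoc says, this interface finds a minimum of the given objective function. The objective must be differentiable everywhere and must implement the DifferentiableFunction interface. functionTolerance is the convergence tolerance, initial is the starting point, and maxIterations caps the number of iterations.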

public interface DifferentiableFunction extends Function {
  double[] derivativeAt(double[] x);
}

public interface Function {
  int dimension();
  double valueAt(double[] x);
}
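
To make the contract concrete, here is a minimal sketch of how the three interfaces fit together. ShiftedQuadratic and GradientDescentMinimizer are hypothetical names invented for this illustration, and plain fixed-step gradient descent stands in for the real optimizer; none of this is code from jrae.

```java
// Interfaces repeated from the post so the sketch compiles on its own.
interface Function {
  int dimension();
  double valueAt(double[] x);
}

interface DifferentiableFunction extends Function {
  double[] derivativeAt(double[] x);
}

interface Minimizer<T extends DifferentiableFunction> {
  double[] minimize(T function, double functionTolerance, double[] initial);
  double[] minimize(T function, double functionTolerance, double[] initial, int maxIterations);
}

// Hypothetical objective: f(x) = sum_i (x_i - 3)^2, minimized at x_i = 3.
class ShiftedQuadratic implements DifferentiableFunction {
  private final int n;
  ShiftedQuadratic(int n) { this.n = n; }
  public int dimension() { return n; }
  public double valueAt(double[] x) {
    double sum = 0.0;
    for (double xi : x) sum += (xi - 3.0) * (xi - 3.0);
    return sum;
  }
  public double[] derivativeAt(double[] x) {
    double[] g = new double[x.length];
    for (int i = 0; i < x.length; i++) g[i] = 2.0 * (x[i] - 3.0);
    return g;
  }
}

// Hypothetical stand-in for QNMinimizer: fixed-step gradient descent (not L-BFGS).
class GradientDescentMinimizer implements Minimizer<DifferentiableFunction> {
  private final double stepSize;
  GradientDescentMinimizer(double stepSize) { this.stepSize = stepSize; }

  public double[] minimize(DifferentiableFunction f, double tol, double[] initial) {
    return minimize(f, tol, initial, 10000); // assumed default iteration cap
  }

  public double[] minimize(DifferentiableFunction f, double tol, double[] initial,
      int maxIterations) {
    double[] x = initial.clone();
    double prev = f.valueAt(x);
    for (int it = 0; it < maxIterations; it++) {
      double[] g = f.derivativeAt(x);
      for (int i = 0; i < x.length; i++) x[i] -= stepSize * g[i];
      double cur = f.valueAt(x);
      // Stop once the objective changes by less than the tolerance.
      if (Math.abs(prev - cur) < tol) break;
      prev = cur;
    }
    return x;
  }
}

class MinimizerDemo {
  public static void main(String[] args) {
    double[] x = new GradientDescentMinimizer(0.1)
        .minimize(new ShiftedQuadratic(2), 1e-10, new double[] {0.0, 0.0});
    System.out.println("x[0] = " + x[0] + ", x[1] = " + x[1]);
  }
}
```

The optimizer only ever sees valueAt and derivativeAt, so anything that can report a cost and a gradient for a weight vector can be trained through the same interface.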

The QNMinimizer class implements this interface and optimizes the objective function with the L-BFGS algorithm. Its class comment reads:

/**
 * This code is part of the Stanford NLP Toolkit.
 *
 *
 * An implementation of L-BFGS for Quasi Newton unconstrained minimization.
 *
 * The general outline of the algorithm is taken from: <blockquote> <i>Numerical
 * Optimization</i> (second edition) 2006 Jorge Nocedal and Stephen J. Wright
 * </blockquote> A variety of different options are available.
 *
 * <h3>LINESEARCHES</h3>
 *
 * BACKTRACKING: This routine simply starts with a guess for step size of 1. If
 * the step size doesn't supply a sufficient decrease in the function value the
 * step is updated through step = 0.1*step. This method is certainly simpler,
 * but doesn't allow for an increase in step size, and isn't well suited for
 * Quasi Newton methods.
 *
 * MINPACK: This routine is based off of the implementation used in MINPACK.
 * This routine finds a point satisfying the Wolfe conditions, which state that
 * a point must have a sufficiently smaller function value, and a gradient of
 * smaller magnitude. This provides enough to prove theoretically quadratic
 * convergence. In order to find such a point the linesearch first finds an
 * interval which must contain a satisfying point, and then progressively
 * reduces that interval all using cubic or quadratic interpolation.
 *
 *
 * SCALING: L-BFGS allows the initial guess at the hessian to be updated at each
 * step. Standard BFGS does this by approximating the hessian as a scaled
 * identity matrix. To use this method set the scaleOpt to SCALAR. A better way
 * of approximate the hessian is by using a scaling diagonal matrix. The
 * diagonal can then be updated as more information comes in. This method can be
 * used by setting scaleOpt to DIAGONAL.
 *
 *
 * CONVERGENCE: Previously convergence was gauged by looking at the average
 * decrease per step dividing that by the current value and terminating when
 * that value because smaller than TOL. This method fails when the function
 * value approaches zero, so two other convergence criteria are used. The first
 * stores the initial gradient norm |g0|, then terminates when the new gradient
 * norm, |g| is sufficiently smaller: i.e., |g| < eps*|g0| the second checks
 * if |g| < eps*max( 1 , |x| ) which is essentially checking to see if the
 * gradient is numerically zero.
 *
 * Each of these convergence criteria can be turned on or off by setting the
 * flags: <blockquote><code>
 * private boolean useAveImprovement = true;
 * private boolean useRelativeNorm = true;
 * private boolean useNumericalZero = true;
 * </code></blockquote>
 *
 * To use the QNMinimizer first construct it using <blockquote><code>
 * QNMinimizer qn = new QNMinimizer(mem, true)
 * </code>
 * </blockquote> mem - the number of previous estimate vector pairs to store,
 * generally 15 is plenty. true - this tells the QN to use the MINPACK
 * linesearch with DIAGONAL scaling. false would lead to the use of the criteria
 * used in the old QNMinimizer class.
 */

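OK. My earlier post explains how L-BFGS works; this class implements that algorithm with some changes to the details. I will skip the implementation itself for now and come back to it later.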
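
The BACKTRACKING option described in the Javadoc is simple enough to sketch. This version starts with a step of 1 and multiplies it by 0.1 until the function value decreases sufficiently. The Armijo-style test, the constant c = 1e-4, and the cap of 50 tries are assumptions of this sketch, not values taken from QNMinimizer.

```java
import java.util.function.ToDoubleFunction;

class BacktrackingSketch {
  // Returns a step t with f(x + t*dir) <= f(x) + c * t * (grad . dir),
  // shrinking t by a factor of 10 after every rejected trial.
  static double lineSearch(ToDoubleFunction<double[]> f, double[] x,
      double[] grad, double[] dir) {
    final double c = 1e-4;      // assumed sufficient-decrease constant
    final double shrink = 0.1;  // step = 0.1 * step, as in the Javadoc
    double fx = f.applyAsDouble(x);
    double slope = 0.0;         // directional derivative grad . dir
    for (int i = 0; i < x.length; i++) slope += grad[i] * dir[i];

    double t = 1.0;             // initial guess for the step size
    double[] trial = new double[x.length];
    for (int tries = 0; tries < 50; tries++) {
      for (int i = 0; i < x.length; i++) trial[i] = x[i] + t * dir[i];
      if (f.applyAsDouble(trial) <= fx + c * t * slope) return t;
      t *= shrink;
    }
    return t; // give up after 50 shrinks and return the tiny step
  }

  public static void main(String[] args) {
    // f(x) = 10 * x^2 at x = 1: gradient 20, steepest-descent direction -20.
    ToDoubleFunction<double[]> f = v -> 10.0 * v[0] * v[0];
    double t = lineSearch(f, new double[] {1.0}, new double[] {20.0},
        new double[] {-20.0});
    System.out.println("accepted step = " + t);
  }
}
```

Because the step can only shrink, a poor initial guess is never corrected upward, which is the limitation the Javadoc points out for quasi-Newton methods.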

2. DifferentiableFunction

DifferentiableFunction, defined above, represents a differentiable function. The abstract class MemoizedDifferentiableFunction implements it and factors out some common code:

public abstract class MemoizedDifferentiableFunction implements DifferentiableFunction {
	protected double[] prevQuery, gradient;
	protected double value;
	protected int evalCount;

	protected void initPrevQuery()
	{
		prevQuery = new double[ dimension() ];
	}

	protected boolean requiresEvaluation(double[] x)
	{
		if(DoubleArrays.equals(x,prevQuery))
			return false;

		System.arraycopy(x, 0, prevQuery, 0, x.length);
		evalCount++;
		return true;
	}

	@Override
	public double[] derivativeAt(double[] x){
		if(DoubleArrays.equals(x,prevQuery))
			return gradient;
		valueAt(x);
		return gradient;
	}
}

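The shared logic does three things. It stores the argument of the last query, so a repeated query with the same argument returns the cached result directly. It counts how many evaluations were actually performed. And it implements the derivative flow: call valueAt to obtain the current value $f(x)$, then return the gradient (the derivative). valueAt is left to subclasses, with the convention that a subclass computes $f'(x)$ while it computes $f(x)$ and stores it in the gradient field.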
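
The convention is easier to see with a toy subclass. The sketch below is a simplified stand-in for MemoizedDifferentiableFunction: it uses java.util.Arrays.equals instead of the project's DoubleArrays helper and clones the query instead of copying into a preallocated array. SquaredNorm is a hypothetical subclass invented for this illustration.

```java
import java.util.Arrays;

// Simplified stand-in for the memoizing base class shown above.
abstract class MemoizedFunctionSketch {
  protected double[] prevQuery, gradient;
  protected double value;
  protected int evalCount;

  abstract int dimension();
  abstract double valueAt(double[] x);

  // True only when x differs from the last query; records x and counts the evaluation.
  protected boolean requiresEvaluation(double[] x) {
    if (prevQuery != null && Arrays.equals(x, prevQuery)) return false;
    prevQuery = x.clone();
    evalCount++;
    return true;
  }

  double[] derivativeAt(double[] x) {
    valueAt(x); // recomputes only if x is new; otherwise returns the cached value
    return gradient;
  }
}

// Hypothetical subclass: f(x) = sum_i x_i^2 with gradient 2x,
// both computed in a single pass inside valueAt.
class SquaredNorm extends MemoizedFunctionSketch {
  private final int n;
  SquaredNorm(int n) { this.n = n; }
  int dimension() { return n; }

  double valueAt(double[] x) {
    if (!requiresEvaluation(x)) return value;
    value = 0.0;
    gradient = new double[x.length];
    for (int i = 0; i < x.length; i++) {
      value += x[i] * x[i];
      gradient[i] = 2.0 * x[i];
    }
    return value;
  }
}

class MemoDemo {
  public static void main(String[] args) {
    SquaredNorm f = new SquaredNorm(2);
    double[] x = {1.0, 2.0};
    f.valueAt(x);
    f.derivativeAt(x); // served from the cache
    f.valueAt(x);      // served from the cache
    System.out.println("evaluations = " + f.evalCount);
  }
}
```

An optimizer typically asks for the value and the gradient at the same point, so the cache saves one full pass over the data per iteration; in the demo evalCount stays at 1 after three calls.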

The two subclasses are RAECost and SoftmaxCost.

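SoftmaxCost computes, for a given set of samples, the error of a given set of weights; its derivative gives the gradient along which that error can be reduced. It corresponds to a two-layer network whose input layer is the features and whose output layer is the label, with softmax as the transfer function (an energy function).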
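
As a rough illustration of what such a cost computes, the sketch below evaluates the cross-entropy error and its gradient for one example in a two-layer softmax network. It is a textbook formulation under assumed conventions (one weight row per category, no bias, no regularization) and is not the repository's SoftmaxCost; the class and method names are made up.

```java
class SoftmaxCostSketch {
  // Softmax over raw scores, shifted by the maximum for numerical stability.
  static double[] softmax(double[] z) {
    double max = Double.NEGATIVE_INFINITY;
    for (double v : z) max = Math.max(max, v);
    double sum = 0.0;
    double[] p = new double[z.length];
    for (int k = 0; k < z.length; k++) {
      p[k] = Math.exp(z[k] - max);
      sum += p[k];
    }
    for (int k = 0; k < z.length; k++) p[k] /= sum;
    return p;
  }

  // Cross-entropy cost -log p[label] for one example.
  // w has one row per category; grad receives dCost/dW in the same shape.
  static double cost(double[][] w, double[] features, int label, double[][] grad) {
    int cats = w.length;
    double[] z = new double[cats];
    for (int k = 0; k < cats; k++)
      for (int j = 0; j < features.length; j++) z[k] += w[k][j] * features[j];

    double[] p = softmax(z);
    for (int k = 0; k < cats; k++) {
      // Gradient of cross-entropy w.r.t. score k is p[k] - 1{k == label}.
      double delta = p[k] - (k == label ? 1.0 : 0.0);
      for (int j = 0; j < features.length; j++) grad[k][j] = delta * features[j];
    }
    return -Math.log(p[label]);
  }

  public static void main(String[] args) {
    double[][] w = new double[2][2];   // all-zero weights
    double[][] grad = new double[2][2];
    double c = cost(w, new double[] {1.0, 2.0}, 0, grad);
    System.out.println("cost = " + c + ", grad[0][0] = " + grad[0][0]);
  }
}
```

With all-zero weights both categories get probability 0.5, so the cost is ln 2 and the gradient pushes the correct category's weights toward the features.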

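RAECost likewise computes the error of given weights on given samples. Here the error is the sum of the error from building the recursive tree and the error from classifying the label, and the derivative is the gradient, again the sum of the two gradients.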
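
For reference, the paper writes this combined error as a weighted sum at each tree node, and the training objective as its average plus a regularizer. The notation below follows the paper; how the code's AlphaCat and Beta parameters map onto $\alpha$ and the other constants is not verified here and should be treated as an open assumption.

```latex
% Per-node error: reconstruction error plus classification (cross-entropy) error
E(\theta) = \alpha \, E_{rec}(\theta) + (1 - \alpha) \, E_{cE}(\theta)

% Objective over N labeled sentences (x, t), with L2 regularization
J(\theta) = \frac{1}{N} \sum_{(x,t)} E(x, t; \theta) + \frac{\lambda}{2} \lVert \theta \rVert^2

% Gradient: the sum of the two error gradients, averaged, plus the regularizer term
\frac{\partial J}{\partial \theta}
  = \frac{1}{N} \sum_{(x,t)} \left( \alpha \frac{\partial E_{rec}}{\partial \theta}
  + (1 - \alpha) \frac{\partial E_{cE}}{\partial \theta} \right) + \lambda \theta
```

Since differentiation is linear, the gradient of the combined error is the same weighted sum of the two gradients, which is why RAECost can accumulate both contributions in a single pass.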

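When the Minimizer interface is called to run the optimization, the first argument passed in is the RAECost object; training is finished when the optimization is finished.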

References:

http://www.socher.org/index.php/Main/Semi-SupervisedRecursiveAutoencodersForPredictingSentimentDistributions

Time: 2024-10-27 18:46:12
