jrae源码解析(二)

本文细述上文引出的RAECost和SoftmaxCost两个类。

SoftmaxCost

我们已经知道,SoftmaxCost类在给定features和label的情况下(超参数给定),衡量给定权重($hidden\times catSize$)的误差值$cost$,并指出当前的权重梯度。看代码。

@Override
	public double valueAt(double[] x)
	{
		if( !requiresEvaluation(x) )
			return value;
		int numDataItems = Features.columns;

		int[] requiredRows = ArraysHelper.makeArray(0, CatSize-2);
		ClassifierTheta Theta = new ClassifierTheta(x,FeatureLength,CatSize);
		DoubleMatrix Prediction = getPredictions (Theta, Features);

		double MeanTerm = 1.0 / (double) numDataItems;
		double Cost = getLoss (Prediction, Labels).sum() * MeanTerm;
		double RegularisationTerm = 0.5 * Lambda * DoubleMatrixFunctions.SquaredNorm(Theta.W);

		DoubleMatrix Diff = Prediction.sub(Labels).muli(MeanTerm);
	    DoubleMatrix Delta = Features.mmul(Diff.transpose());

	    DoubleMatrix gradW = Delta.getColumns(requiredRows);
	    DoubleMatrix gradb = ((Diff.rowSums()).getRows(requiredRows));

	    //Regularizing. Bias does not have one.
	    gradW = gradW.addi(Theta.W.mul(Lambda));

	    Gradient = new ClassifierTheta(gradW,gradb);
	    value = Cost + RegularisationTerm;
	    gradient = Gradient.Theta;
		return value;
	}

public DoubleMatrix getPredictions (ClassifierTheta Theta, DoubleMatrix Features)    {        int numDataItems = Features.columns;        DoubleMatrix Input = ((Theta.W.transpose()).mmul(Features)).addColumnVector(Theta.b);        Input = DoubleMatrix.concatVertically(Input, DoubleMatrix.zeros(1,numDataItems));        return Activation.valueAt(Input);     }

是个典型的2层神经网络,没有隐层,首先根据features预测labels,预测结果用softmax归一化,然后根据误差反向传播算出权重梯度。

此处增加200字。

这个典型的2层神经网络,label为一列向量,目标label置1,其余为0;转换函数为softmax函数,输出为每个label的概率。

计算cost的函数为getLoss,假设目标label的预测输出为$p^*$,则每个样本的cost也即误差函数为:

$$cost=E(p^*)=-\log(p^*)$$

根据前述的神经网络后向传播算法,我们得到($j$为目标label时,否则为0):

$$\frac{\partial E}{\partial w_{ij}}=\frac{\partial E}{\partial p_j}\frac{\partial h_j}{\partial net_j}x_i=-\frac{1}{p_j}p_j(1-p_j)x_i=-(1-p_j)x_i=-(label_j-p_j)feature_i$$

因此我们便理解了下面代码的含义:

DoubleMatrix Delta = Features.mmul(Diff.transpose());

RAECost

先看实现代码:

@Override
	public double valueAt(double[] x)
	{
		if(!requiresEvaluation(x))
			return value;

		Theta Theta1 = new Theta(x,hiddenSize,visibleSize,dictionaryLength);
		FineTunableTheta Theta2 = new FineTunableTheta(x,hiddenSize,visibleSize,catSize,dictionaryLength);
		Theta2.setWe( Theta2.We.add(WeOrig) );

		final RAEClassificationCost classificationCost = new RAEClassificationCost(
				catSize, AlphaCat, Beta, dictionaryLength, hiddenSize, Lambda, f, Theta2);
		final RAEFeatureCost featureCost = new RAEFeatureCost(
				AlphaCat, Beta, dictionaryLength, hiddenSize, Lambda, f, WeOrig, Theta1);

		Parallel.For(DataCell,
			new Parallel.Operation<LabeledDatum<Integer,Integer>>() {
				public void perform(int index, LabeledDatum<Integer,Integer> Data)
				{
					try {
						LabeledRAETree Tree = featureCost.Compute(Data);
						classificationCost.Compute(Data, Tree);
					} catch (Exception e) {
						System.err.println(e.getMessage());
					}
				}
		});

		double costRAE = featureCost.getCost();
		double[] gradRAE = featureCost.getGradient().clone();

		double costSUP = classificationCost.getCost();
		gradient = classificationCost.getGradient();

		value = costRAE + costSUP;
		for(int i=0; i<gradRAE.length; i++)
			gradient[i] += gradRAE[i];

		System.gc();	System.gc();
		System.gc();	System.gc();
		System.gc();	System.gc();
		System.gc();	System.gc();

		return value;
	}

cost由两部分组成,featureCost和classificationCost。程序遍历每个样本,用featureCost.Compute(Data)生成一个递归树,同时累加cost和gradient,然后用classificationCost.Compute(Data, Tree)根据生成的树计算并累加cost和gradient。因此关键类为RAEFeatureCost和RAEClassificationCost。

RAEFeatureCost类在Compute函数中调用RAEPropagation的ForwardPropagate函数生成一棵树,然后调用BackPropagate计算梯度并累加。具体的算法过程,下一章分解。

时间: 2024-10-13 03:02:54

jrae源码解析(二)的相关文章

Spring 源码解析之HandlerAdapter源码解析(二)

Spring 源码解析之HandlerAdapter源码解析(二) 前言 看这篇之前需要有Spring 源码解析之HandlerMapping源码解析(一)这篇的基础,这篇主要是把请求流程中的调用controller流程单独拿出来了 解决上篇文章遗留的问题 getHandler(processedRequest) 这个方法是如何查找到对应处理的HandlerExecutionChain和HandlerMapping的,比如说静态资源的处理和请求的处理肯定是不同的HandlerMapping ge

chenglei1986/DatePicker源码解析(二)

接上一篇文章chenglei1986/DatePicker源码解析(一),我们继续将剩余的部分讲完,其实剩余的内容,就是利用Numberpicker来组成一个datePicker,代码非常的简单 为了实现自定义布局的效果,我们给Datepciker定制了一个layout,大家可以定制自己的layout <?xml version="1.0" encoding="utf-8"?> <LinearLayout xmlns:android="h

erlang下lists模块sort(排序)方法源码解析(二)

上接erlang下lists模块sort(排序)方法源码解析(一),到目前为止,list列表已经被分割成N个列表,而且每个列表的元素是有序的(从大到小) 下面我们重点来看看mergel和rmergel模块,因为我们先前主要分析的split_1_*对应的是rmergel,我们先从rmergel查看,如下 ....................................................... split_1(X, Y, [], R, Rs) -> rmergel([[Y, X

AFNetworking2.0源码解析&lt;二&gt;

本篇我们继续来看看AFNetworking的下一个模块 — AFURLRequestSerialization. AFURLRequestSerialization用于帮助构建NSURLRequest,主要做了两个事情: 1.构建普通请求:格式化请求参数,生成HTTP Header. 2.构建multipart请求. 分别看看它在这两点具体做了什么,怎么做的. 1.构建普通请求 A.格式化请求参数 一般我们请求都会按key=value的方式带上各种参数,GET方法参数直接加在URL上,POST方

volley源码解析(二)--Request&lt;T&gt;类的介绍

在上一篇文章中,我们已经提到volley的使用方式和设计的整体思路,从这篇文章开始,我就要结合具体的源码来给大家说明volley功能的具体实现. 我们第一个要介绍的类是Request<T>这个一个抽象类,我将Request称为一个请求,通过继承Request<T>来自定义request,为volley提供了更加灵活的接口. Request<T>中的泛型T,是指解析response以后的结果.在上一篇文章中我们知道,ResponseDelivery会把response分派

Mybatis 源码解析(二) - Configuration.xml解析

文章个人学习源码所得,若存在不足或者错误之处,请大家指出. 上一章中叙述了Configuration.xml流化到Mybatis内存中的过程,那么接下来肯定就是Configuration.xml文件解析操作,在Mybatis中,这个解析的操作由SqlSesssionFactoryBuilder负责.接下来我们看看SqlSessionFactoryBuilder的方法签名: SqlSessionFactoryBuilder提供了9个签名方法,其中前8个方法都是Configuration.xml的解

jQuery 源码解析(二十五) DOM操作模块 html和text方法的区别

html和text都可以获取和修改DOM节点里的内容,方法如下: html(value)     ;获取匹配元素集合中的一个元素的innerHTML内容,或者设置每个元素的innerHTML内容,                ;value可选,可以是html代码或返回html代码的函数,如果没有参数则获取匹配元素集合中第一个元素的innerHTML内容 text(text)         ;获取匹配元素集合中所有元素合并后的文本内容,或者设置每个元素的文本内容,封装了createTextNo

第37篇 Asp.Net源码解析(二)--详解HttpApplication

这篇文章花了点时间,差点成烂到电脑里面,写的过程中有好几次修改,最终的这个版本也不是很满意,东西说的不够细,还需要认真的去看下源码才能有所体会,先这样吧,后面有时间把细节慢慢的再修改.顺便对于开发的学习,个人是觉得源码的阅读是最快的提高方式,当然阅读不是走马观花,应该多次阅读. 上次说到获得HttpApplication对象的创建,创建完成后调用InitInternal方法,这个方法任务比较多,也比较长,这里就不贴全码了,一个一个过程的去说: 初始化HttpModule 对于HttpModule

dmytrodanylyk/circular-progress-button源码解析(二)

源码下载http://download.csdn.net/detail/kangaroo835127729/8755815 在上篇文章http://blog.csdn.net/crazy__chen/article/details/46278423中,我主要讲述了circular-progress-button状态切换的动画过程,接下来我们看一个最特殊的状态,就是加载状态,这个状态会显示一个圆环来表示当前加载的进度,但是其实circular-progress-button提供给了我们两个选择,一