java实现gbdt

DATA类

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;

public class Data {
	private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
	public ArrayList<ArrayList<String>> getTrainData() {
		return this.trainData;
	}

	public Data() {
		String dataPath="D://javajavajava//dbdt//src//script//data//adult.data.csv";
		Scanner in;
		try {
			in = new Scanner(new File(dataPath));
			while (in.hasNext()) {
				String line=in.nextLine();
				String []strs=line.trim().split(",");
				ArrayList<String> tmp=new ArrayList<>();
				for(int i=0;i<strs.length;i++)
				{
					tmp.add(strs[i]);
				}
				this.trainData.add(tmp);
			}
		} catch (FileNotFoundException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}

	}

	public static void main(String[] args) {
		// TODO Auto-generated method stub
		Data d =new Data();

	}

}

  TREE类

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider;

public class Tree {
	private Tree leftTree=new Tree();
	private Tree rightTree=new Tree();
	private double loss=-1;
	private int attributeSplit=0;
	private String attributeSplitType="";
	boolean isLeaf;
	double leafValue;
	private ArrayList<Integer> leafNodeSet=new ArrayList<>();

	public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData,int idx)
	{
		HashSet<String> mySet=new HashSet<>();
		ArrayList<String> ans =new ArrayList<>();
		for(int i=0;i<trainData.size();i++)
		{
			mySet.add(trainData.get(i).get(idx));
		}

		Iterator<String> it=mySet.iterator();

		while(it.hasNext())
		{
			ans.add(it.next());
		}

		return ans;
	}
	public boolean myCmpLess(String str1,String str2)
	{
		if(Integer.parseInt(str1.trim())<=Integer.parseInt(str2.trim()))
			return true;
		else return false;

	}
	public double computeLoss(ArrayList<Double> values)
	{
		double loss=0;
		for(int i=0;i<values.size();i++)
		{
			loss+=values.get(i);
		}
		double mean=loss/values.size();
		loss=0;
		for(int i=0;i<values.size();i++)
		{
			loss+=Math.pow(values.get(i)-mean,2);
		}
		return Math.sqrt(loss);
	}
	public double getPredictValue(int K, ArrayList<Integer> subIdx,ArrayList<Double> target) {
		double ans=0;
		double sum=0,sum1=0;
		for(int i=0;i<subIdx.size();i++)
		{
			sum+=target.get(subIdx.get(i));
		}
		for(int i=0;i<subIdx.size();i++)
		{
			sum1+=target.get(subIdx.get(i))*(1-target.get(subIdx.get(i)));
		}
		ans=(K-1)/K*sum/sum1;
		return ans;
	}
	public double getPredictValue(Tree root)
	{
		return root.leafValue;
	}
	public double getPredictValue(Tree root,ArrayList<String> instance,Boolean isDigit[])
	{

		if(root.isLeaf)
			return root.leafValue;
		else if(isDigit[root.attributeSplit])
		{
			if(myCmpLess(instance.get(root.attributeSplit).trim(),root.attributeSplitType))
				return getPredictValue(root.leftTree, instance, isDigit);
			return getPredictValue(root.rightTree, instance, isDigit);
		}
		else
		{
			if(instance.get(root.attributeSplit).trim().equals(root.attributeSplitType))
				return getPredictValue(root.leftTree, instance, isDigit);
			return getPredictValue(root.rightTree, instance, isDigit);
		}

	}
	public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int K,int splitPoints, Boolean isDigit[],ArrayList<Integer> subIdx,ArrayList<ArrayList<String>> trainData,ArrayList<Double> target,int maxDepth[],int depth)
	{

		int n=trainData.size();
		int dim=trainData.get(0).size();
		ArrayList<Integer> leftTreeIdx=new ArrayList<>();
		ArrayList<Integer> rightTreeIdx=new ArrayList<>();

		if(depth<maxDepth[0])
		{
			/*
			 * 从所有的attribute中选取最佳的attribute,并且attribute中最佳的分割点,对数据进行分割
			 * */
			double loss=-1;
			ArrayList<Integer> leftNodes=new ArrayList<>();
			ArrayList<Integer> rightNodes=new ArrayList<>();
			int attributeSplit=0;
			String attributeSplitType="";

			for(int i=0;i<dim;i++)//遍历所有的attribute
			{
				//得到该attribute下所有的distinct的值
				ArrayList<String> myAttributeSet=new ArrayList<>();
				ArrayList<String> subDigitAttribute=new ArrayList<>();
				myAttributeSet=getAttributeSet(trainData, i);
				if(isDigit[i])//如果是数字,就从数组中随机选取splitpoints个节点,代表这个属性可以在这splitpoints下进行分割
				{
					while(subDigitAttribute.size()<splitPoints)
					{
						Random r=new Random();
						int tmp=r.nextInt(myAttributeSet.size());
						subDigitAttribute.add(myAttributeSet.get(tmp));
						myAttributeSet.clear();
						myAttributeSet=subDigitAttribute;
					}
				}
				for(int j=0;j<myAttributeSet.size();j++)
				{
					for(int k=0;k<subIdx.size();k++)
					{
						if((!isDigit[i]&&trainData.get(subIdx.get(k)).get(i).trim().equals(myAttributeSet.get(j)))||(isDigit[i]&&myCmpLess(trainData.get(subIdx.get(k)).get(i),myAttributeSet.get(j))))
						{
							leftTreeIdx.add(subIdx.get(k));
						}
						else
						{
							rightTreeIdx.add(subIdx.get(k));
						}
					}
					ArrayList<Double> leftTarget=new ArrayList<>();
					ArrayList<Double> rightTarget=new ArrayList<>();
					for(int k=0;k<leftTreeIdx.size();k++)
						leftTarget.add(target.get(leftTreeIdx.get(k)));
					for(int k=0;k<rightTreeIdx.size();k++)
						rightTarget.add(target.get(rightTreeIdx.get(k)));
					double lossTmp=computeLoss(leftTarget)+computeLoss(rightTarget);
					if(loss<0||loss<lossTmp)
					{
						leftNodes.clear();
						rightNodes.clear();
						for(int k=0;k<leftTreeIdx.size();k++)
							leftNodes.add(leftTreeIdx.get(k));
						for(int k=0;k<rightTreeIdx.size();k++)
							rightNodes.add(rightTreeIdx.get(k));
						attributeSplit=i;
						attributeSplitType=myAttributeSet.get(j);
					}

				}

			}

			Tree tmpTree=new Tree();
			tmpTree.attributeSplit=attributeSplit;
			tmpTree.attributeSplitType=attributeSplitType;
			tmpTree.loss=loss;
			tmpTree.isLeaf=false;
			tmpTree.leftTree=constructTree(leafNodes,leafValues,K,splitPoints, isDigit, leftNodes, trainData, target, maxDepth, depth+1);
			tmpTree.leftTree=constructTree(leafNodes,leafValues,K,splitPoints, isDigit, rightNodes, trainData, target, maxDepth, depth+1);
			return tmpTree;

		}
		else
		{
			Tree tmpTree=new Tree();
			tmpTree.isLeaf=true;
			tmpTree.leafValue=getPredictValue(K, subIdx, target);
			for(int i=0;i<subIdx.size();i++)
				tmpTree.leafNodeSet.add(subIdx.get(i));
			leafNodes.add(subIdx);
			leafValues.add(tmpTree.leafValue);
			return tmpTree;
		}
	}

	public static void main(String[] args) {
		// TODO Auto-generated method stub
		Tree aTree=new Tree();
	}

}

  

GBDT类

import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;

public class GBDT {

	private ArrayList<ArrayList<String>> datas=new ArrayList<ArrayList<String>>();
	private ArrayList<String> labelSets=new ArrayList<>();
	private ArrayList<ArrayList<Double>> F=new ArrayList<ArrayList<Double>>();
	private ArrayList<ArrayList<Double>> residual=new ArrayList<ArrayList<Double>>();
	private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
	private ArrayList<Integer> labelTrainData=new ArrayList<Integer>();
	private int K;
	private Boolean isDigit[];
	private int dim;
	private int n;
	private double learningRate;

	private ArrayList<ArrayList<Tree>> trees=new ArrayList<ArrayList<Tree>>(); //存放所有的树

	private int max_iter;
	private double sampleRate;
	private int maxDepth;
	private int splitPoints;

	public void computeResidual(ArrayList<Integer> subId)
	{
		for(int i=0;i<subId.size();i++)
		{
			int idx=subId.get(i);
			int y=0;
			if(this.labelTrainData.get(idx)==-1) y=0;
			else y=1;
			double sum=Math.exp(this.F.get(idx).get(0))+Math.exp(this.F.get(idx).get(1));
			double p1=Math.exp(this.F.get(idx).get(0))/sum,p2=Math.exp(this.F.get(idx).get(1))/sum;
			this.residual.get(idx).set(0, y-p1);
			this.residual.get(idx).set(1, y-p2);
		}
	}
	public ArrayList<Integer> myrandom(int maxNum,int num)
	{
		ArrayList<Integer> ans=new ArrayList<>();
		Set<Integer> mySet=new HashSet<>();
		while(mySet.size()<num)
		{
			Random r=new Random();
			int tmp=r.nextInt(maxNum);
			mySet.add(tmp);
		}
		Iterator<Integer> it=mySet.iterator();
		while(it.hasNext())
		{
			ans.add(it.next());
		}
		return ans;
	}

	public GBDT()
	{
		this.max_iter=50;
		this.sampleRate=0.8;
		this.K=2;//2分类问题
		this.maxDepth=6;
		this.splitPoints=3;
		this.learningRate=0.01;
		getData();
	}

	public void train()
	{
		for(int i=0;i<max_iter;i++)
		{
			ArrayList<Integer> subSet=new ArrayList<>();
			int numSubset=(int)(this.n*this.sampleRate);
			subSet=myrandom(this.n,numSubset);
			computeResidual(subSet);
			ArrayList<Double> target=new ArrayList<>();
			ArrayList<Tree> tmpTree=new ArrayList<>();
			int maxdepths[]={this.maxDepth};
			for(int j=0;j<this.K;j++)
			{
				target.clear();
				for(int k=0;k<subSet.size();k++)
				{
					target.add(residual.get(subSet.get(k)).get(j));
				}
				ArrayList<ArrayList<Integer>> leafNodes=new ArrayList<ArrayList<Integer>>();
				ArrayList<Double> leafValues=new ArrayList<>();
				Tree treeSub=new Tree();
				Tree iterTree=treeSub.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);
				tmpTree.add(iterTree);
				updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);
			}

			trees.add(tmpTree);
		}
	}

	public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx,ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int label,Tree root)
	{
		ArrayList<Integer> remainIdx=new ArrayList<>();
		int arr[]=new int[this.n];
		for(int i=0;i<this.n;i++)
			arr[i]=i;
		for(int i=0;i<subIdx.size();i++)
		{
			arr[subIdx.get(i)]=-1;
		}
		//求出不是用来训练树的余下集合
		for(int i=0;i<this.n;i++)
		{
			if(arr[i]!=-1)
				remainIdx.add(i);
		}
		for(int i=0;i<leafNodes.size();i++)
		{
			for(int j=0;j<leafNodes.get(i).size();j++)
			{
				this.F.get(leafNodes.get(i).get(j)).set(label, this.F.get(leafNodes.get(i).get(j)).get(label)+this.learningRate*root.getPredictValue(root));
			}
		}
		for(int i=0;i<remainIdx.size();i++)
		{
			double leafV=root.getPredictValue(root,this.trainData.get(remainIdx.get(i)),isDigit);
			this.F.get(remainIdx.get(i)).set(label, this.F.get(remainIdx.get(i)).get(label)+this.learningRate*leafV);
		}

	}

	public boolean checkDigit(String str) {
		for(int i=0;i<str.length();i++)
		{
			if(!(str.charAt(i)>=‘0‘&&str.charAt(i)<=‘9‘))
			{
				return false;
			}
		}
		return true;
	}

	public void getData() {
		Data d =new Data();
		this.datas=d.getTrainData();
		this.dim=this.datas.get(0).size()-1;
		this.isDigit=new Boolean[this.dim];
		//遍历所有样本,去掉中间含有不是正常的数据
		for(int i=0;i<this.datas.get(0).size()-1;i++)
			labelSets.add(this.datas.get(0).get(i));
		//保证数据的第一行是正确的,来判断,特征哪些纬度是数字,哪些纬度是字符串
		for(int i=0;i<this.dim;i++)
		{
			if(checkDigit(this.datas.get(0).get(i)))
				this.isDigit[i]=true;
			else this.isDigit[i]=false;
		}
		//如果字符串==?说明是异常数据,这里做数据的清理
		for(int i=1;i<this.datas.size();i++)
		{
			ArrayList<String> tmp=new ArrayList<>();
			boolean flag=true;
			for(int j=0;j<this.dim;j++)
			{
				if(datas.get(i).get(j).trim().equals("?"))
				{
					flag=false;
					break;
				}
			}
			if(!flag) continue;
			if(datas.get(i).get(this.dim).trim().equals("?")) continue;
			trainData.add(tmp);
			if(datas.get(i).get(this.dim).trim().equals("<=50K"))
				labelTrainData.add(-1);
			else
				labelTrainData.add(1);

		}
		this.n=this.labelTrainData.size();

		for(int i=0;i<this.datas.get(0).size()-1;i++)
			labelSets.add(this.datas.get(0).get(i));

		//初始化F矩阵为全0,F矩阵是n*2,是2分类问题,如果要多分类,改下这里就可以了
		for(int i=0;i<this.n;i++)
		{
			ArrayList<Double> arrTmp=new ArrayList<Double>();
			for(int j=0;j<2;j++)
			{
				arrTmp.add(0.0);
			}
			this.F.add(arrTmp);
			this.residual.add(arrTmp);
		}

	}

	public static void main(String[] args) {
		GBDT dGbdt=new GBDT();
		dGbdt.getData();
		System.err.println(dGbdt.n);

	}
}

  

时间: 2024-10-25 08:14:38

java实现gbdt的相关文章

梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python) http://blog.csdn.net/liulingyuan6/article/details/53426350 梯度迭代树 算法简介: 梯度提升树是一种决策树的集成算法.它通过反复迭代训练决策树来最小化损失函数.决策树类似,梯度提升树具有可处理类别特征.易扩展到多分类问题.不需特征缩放等性质.Spark.ml通过使用现有decision tree工具来实现. 梯度提升树依次迭代训练一系列的

决策树和基于决策树的集成方法(DT,RF,GBDT,XGB)复习总结

摘要: 1.算法概述 2.算法推导 3.算法特性及优缺点 4.注意事项 5.实现和具体例子 内容: 1.算法概述 1.1 决策树(DT)是一种基本的分类和回归方法.在分类问题中它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布,学习思想包括ID3,C4.5,CART(摘自<统计学习方法>). 1.2 Bagging :基于数据随机重抽样的集成方法(Ensemble methods),也称为自举汇聚法(boostrap aggregating),整个数据集是

随机森林和GBDT的学习

参考文献:http://www.zilhua.com/629.html http://www.tuicool.com/articles/JvMJve http://blog.sina.com.cn/s/blog_573085f70101ivj5.html 我的数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm我的算法库:https://github.com/linyiqun/lyq-algorithms-lib 前言 提到森林,就不得不联

Spark2.0机器学习系列之6:GBDT(梯度提升决策树)、GBDT与随机森林差异、参数调试及Scikit代码分析

概念梳理 GBDT的别称 GBDT(Gradient Boost Decision Tree),梯度提升决策树.     GBDT这个算法还有一些其他的名字,比如说MART(Multiple Additive Regression Tree),GBRT(Gradient Boost Regression Tree),Tree Net等,其实它们都是一个东西(参考自wikipedia – Gradient Boosting),发明者是Friedman. 研究GBDT一定要看看Friedman的pa

90天,从Java转机器学习面试总结

前 言 辗转几年Java开发,换了几份工作,没一个稳定的学习.工作过程.中间也相亲几次,都是没啥结果.换工作频繁也严重打乱了和姑娘接触的节奏.糟心工作连着遇到几次,也怪自己眼光有问题. 2018也找了2次工作,中间有4.5个月没有工作.看了个世界杯,看了个亚运会.也怪自己这段时间一直是换工作.找工作,节奏太乱了.当然,节奏不乱也可能不会比现在好吧.谁说的准呢? Java转机器学习--为啥呢? 主要有以下三方面原因: 1.Java感觉遇到瓶颈.Spring.Mytatis.设计模式等等,源码看不动

Java多线程学习(吐血超详细总结)

林炳文Evankaka原创作品.转载请注明出处http://blog.csdn.net/evankaka 目录(?)[-] 一扩展javalangThread类 二实现javalangRunnable接口 三Thread和Runnable的区别 四线程状态转换 五线程调度 六常用函数说明 使用方式 为什么要用join方法 七常见线程名词解释 八线程同步 九线程数据传递 本文主要讲了java中多线程的使用方法.线程同步.线程数据传递.线程状态及相应的一些线程函数用法.概述等. 首先讲一下进程和线程

Java TM 已被阻止,因为它已过时需要更新的解决方法

公司的堡垒机需要通过浏览器登陆,且该堡垒机的网站需要Java的支持,最近通过浏览器登陆之后总是提示"java TM 已被阻止,因为它已过时需要更新的解决方法"导致登陆之后不能操作, 但是操作系统中确实已经安装了比较新的JDK,安装的JDK版本是jdk-7u67-windows-i586,因为太烦人,所以决定搞清楚报错的原因,一劳永逸,彻底解决这个问题 准备工作:安装JDK,安装版本jdk-7u67-windows-i586.exe,因为机器的Eclipse还依赖64位的JDK,所以另安

Java四种线程池newCachedThreadPool,newFixedThreadPool,newScheduledThreadPool,newSingleThreadExecutor

介绍new Thread的弊端及Java四种线程池的使用,对Android同样适用.本文是基础篇,后面会分享下线程池一些高级功能. 1.new Thread的弊端 执行一个异步任务你还只是如下new Thread吗? Java new Thread(new Runnable() { @Override public void run() { // TODO Auto-generated method stub } }).start(); 1 2 3 4 5 6 7 new Thread(new

由@NotNull 注解引出的关于Java空指针的控制(转)

Java 小技巧和在java应用避免NullPonintException的最佳方法 在java应用程序中,一个NullPonintException(空指针异常)是最好解决(问题)的方法.同时,空指针也是写健壮的顺畅运行的代码的关键.“预防好过治疗”这句话也同样适用于令人不爽的NullPonintException.通过应用防御性的编码技术和在遵守多个部分之间的约定,你可以再很大程度上避免NullPointException.下面的这些java小技巧可以最小化像!=null这种检查的代码.作为