自己实现的SVM源码

首先是Data类

import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * Loads labelled training samples from a tab-separated text file.
 *
 * Each non-blank line holds the feature values followed by an integer
 * label in the last column, e.g. {@code 3.5<TAB>-1.2<TAB>1}.
 */
public class Data {
	/** File read by the no-argument {@link #getTrainData()}. */
	private static final String DEFAULT_TRAIN_FILE = "G://download//testSet.txt";

	/**
	 * Reads the training set from the default file location.
	 *
	 * @return map from feature vector to label; empty when the file is missing
	 */
	public Map<List<Double>, Integer> getTrainData() {
		return getTrainData(DEFAULT_TRAIN_FILE);
	}

	/**
	 * Reads a training set from {@code path}.
	 *
	 * Because the result is keyed by feature vector, two lines with identical
	 * features collapse into one entry (the later label wins).
	 *
	 * @param path location of the tab-separated training file
	 * @return map from feature vector to label; empty when the file is missing
	 * @throws NumberFormatException if a column is not a valid number
	 */
	public Map<List<Double>, Integer> getTrainData(String path) {
		Map<List<Double>, Integer> data = new HashMap<List<Double>, Integer>();

		// try-with-resources guarantees the underlying file handle is released.
		try (Scanner in = new Scanner(new File(path))) {
			while (in.hasNextLine()) {
				String line = in.nextLine().trim();
				if (line.isEmpty()) {
					// Blank lines (typically a trailing newline) carry no sample.
					continue;
				}
				String[] columns = line.split("\t");
				int labelIndex = columns.length - 1;
				List<Double> point = new ArrayList<>();
				for (int i = 0; i < labelIndex; i++) {
					point.add(Double.parseDouble(columns[i]));
				}
				data.put(point, Integer.parseInt(columns[labelIndex]));
			}
		} catch (FileNotFoundException e) {
			// Best effort, as before: report the problem and return what we have (nothing).
			System.err.println("Training file not found: " + path + " (" + e.getMessage() + ")");
		}

		return data;
	}

	/** Smoke test: loads the default training file. */
	public static void main(String[] args)
	{
		Data data = new Data();
		data.getTrainData();
	}
}

  SVM类:

import java.awt.print.Printable;
import java.io.FileNotFoundException;
import java.io.ObjectInputStream.GetField;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;

public class SVM {
	private List<ArrayList<Double>> trainData;
	private List<Integer> labelTrainData;
	private double sigma;
	private double C;
	private List<Double> alpha;
	private double b;
	private List<Double> E;
	private int N;
	private int dim;
	private double tol;
	private double eta;
	private double eps;
	private double eps2;

	public boolean satisfyKkt(int id)
	{
		double ypgx=this.labelTrainData.get(id)*getGx(this.trainData.get(id));//y*g(x)
		if(Math.abs(this.alpha.get(id))<=this.eps)
		{
			if(ypgx-1<-this.tol) return false;
		}
		else if(Math.abs(this.alpha.get(id)-this.C)<=this.eps)
		{
			if(ypgx-1>this.tol) return false;
		}
		else {
			if(Math.abs(ypgx-1)>this.tol) return false;
		}
		return true;
	}

	public void updateE() {

		for(int i=0;i<this.N;i++)
		{
			double Ei=getGx(this.trainData.get(i))-this.labelTrainData.get(i);
			this.E.set(i, Ei);
		}
	}

	public double kernelLinear(List<Double> X,List<Double> Y) {
		//linear kernel function
		int len=Y.size();
		double s=0;
		for(int i=0;i<len;i++)
			s+=X.get(i)*Y.get(i);
		return s;
	}

	public double kernelRBF(List<Double> X,List<Double> Y)
	{
		//gauss kernel function

		int len=Y.size();
		double s=0;
		for(int i=0;i<len;i++)
			s+=(X.get(i)-Y.get(i))*(X.get(i)-Y.get(i));
		s=Math.exp(-s/(2*Math.pow(this.sigma, 2)));
		return s;
	}

	public double getGx(List<Double> X)
	{
		//calculate wx+b value
		double s=0;
		for(int i=0;i<this.N;i++)
		{
			//for debug
			double debug1=kernelRBF(X, this.trainData.get(i));
			double debug2=this.alpha.get(i);

			s+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(X, this.trainData.get(i));
		}
		s+=this.b;
		return s;
	}

	public int update(int x1,int x2)
	{
		double low=0;
		double high=0;
		if(this.labelTrainData.get(x1)==this.labelTrainData.get(x2))
		{
			low=Math.max(0, this.alpha.get(x1)+this.alpha.get(x2)-this.C);
			high=Math.min(this.C, this.alpha.get(x2)+this.alpha.get(x1));
		}
		else
		{
			low=Math.max(0, this.alpha.get(x2)-this.alpha.get(x1));
			high=Math.min(this.C, this.alpha.get(x2)-this.alpha.get(x1)+this.C);
		}
		double newAlpha2=this.alpha.get(x2)+this.labelTrainData.get(x2)*(this.E.get(x1)-this.E.get(x2))/this.eta;
		double newAlpha1=0;

		if(newAlpha2>high) newAlpha2=high;
		else if(newAlpha2<low) newAlpha2=low;
		newAlpha1=this.alpha.get(x1)+this.labelTrainData.get(x1)*this.labelTrainData.get(x2)*(this.alpha.get(x2)-newAlpha2);

		if(Math.abs(newAlpha1)<=this.eps)
			newAlpha1=0;
		if(Math.abs(newAlpha2)<=this.eps)
			newAlpha2=0;
		if(Math.abs(newAlpha1-this.C)<=this.eps)
			newAlpha1=this.C;
		if(Math.abs(newAlpha2-this.C)<=this.eps)
			newAlpha2=this.C;
		if(Math.abs(newAlpha1-this.alpha.get(x1))<=this.eps2)
			return 0;
		if(Math.abs(newAlpha2-this.alpha.get(x2))<=this.eps2)
			return 0;

		double b1=-this.E.get(x1)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x1))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x1))*(newAlpha2-this.alpha.get(x2))+this.b;
		double b2=-this.E.get(x2)-this.labelTrainData.get(x1)*kernelRBF(this.trainData.get(x1), this.trainData.get(x2))*(newAlpha1-this.alpha.get(x1))-this.labelTrainData.get(x2)*kernelRBF(this.trainData.get(x2), this.trainData.get(x2))*(newAlpha2-this.alpha.get(x2))+this.b;

		if(newAlpha1>0&&newAlpha1<this.C)
			this.b=b1;
		else if(newAlpha2>0&&newAlpha2<this.C)
			this.b=b2;
		else
			this.b=(b1+b2)/2;

		this.alpha.set(x1,newAlpha1);
		this.alpha.set(x2,newAlpha2);
		updateE();
		return 1;
	}
	public int selectAlpha2(int x1) {

		int x2=-1;
		double maxDiff=-1;
		//first select x2 from 0<a<c to max(E(x1)-E(x2))

		for(int i=0;i<this.N;++i)
		{
			if(Math.abs(this.alpha.get(i))<=this.eps||Math.abs(this.alpha.get(i)-this.C)<=this.eps) continue;
			double diff=Math.abs(this.E.get(x1)-this.E.get(i));
			if(diff>maxDiff)
			{
				maxDiff=diff;
				x2=i;
			}
		}

		//second calculate eta (eta!=0)
		if(x2!=-1)
		{
			this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(x2), this.trainData.get(x2))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(x2));
			if(eta!=0) return x2;
		}

		//third if cannot find in the whole train set
		for(int i=0;i<this.N;i++)
		{
			if(i==x1) continue;
			this.eta=kernelRBF(this.trainData.get(x1), this.trainData.get(x1))+kernelRBF(this.trainData.get(i), this.trainData.get(i))-2*kernelRBF(this.trainData.get(x1), this.trainData.get(i));
			if(Math.abs(this.eta)>this.eps) return i;
		}
		return -1;

	}

	public void SMO() {
		//to solve alpha
		int numChanged=0;
		int cnt=0;
		while(true)
		{
			cnt++;
			System.out.println(cnt);

			numChanged=0;
			for(int x1=0;x1<this.N;++x1)
			{
				if(Math.abs(this.alpha.get(x1))<=this.eps||Math.abs(this.alpha.get(x1)-this.C)<=this.eps) continue;
				if(!satisfyKkt(x1))
				{
					int x2=selectAlpha2(x1);
					if(x2==-1) continue;
					numChanged+=update(x1, x2);
				}
			}
			if(numChanged==0)
			{
				for(int x1=0;x1<this.N;++x1)
				{
					if(!satisfyKkt(x1))
					{
						int x2=selectAlpha2(x1);
						if(x2==-1) continue;
						update(x1, x2);
						numChanged++;
					}
				}
			}
			if(numChanged==0)
				break;
		}
	}

	public SVM() {
		//load train data

		Data data=new Data();
		Map<List<Double>, Integer> Datas=data.getTrainData();
		int totalData=Datas.size();
		this.trainData=new ArrayList<ArrayList<Double>>();
		this.labelTrainData=new ArrayList<Integer>();
		this.alpha=new ArrayList<Double>();
		this.E=new ArrayList<Double>();

		int i=0;
		for(Map.Entry<List<Double>, Integer> entry: Datas.entrySet())
		{
			this.trainData.add((ArrayList<Double>) entry.getKey());
			this.labelTrainData.add(entry.getValue());
			this.alpha.add(0.0);
			this.E.add(0.0-this.labelTrainData.get(i));
			i++;
		}
		this.N=this.labelTrainData.size();
		this.dim=this.trainData.get(0).size();

		this.sigma=12;//sigma=1
		this.C=0.5;//c=6
		this.b=0.0;
		this.tol=0.001;
		this.eta=0;
		this.eps=0.0000001;
		this.eps2=0.00001;
	}

	public double getB() {
		//get b value
		return this.b;
	}
	public double[] getLinearW() {
		double []w=new double[this.N];
		for(int i=0;i<this.N;i++)
		{
			for(int j=0;j<this.dim;j++)
			{
				w[j]+=this.alpha.get(i)*this.labelTrainData.get(i)*this.trainData.get(i).get(j);
			}
		}
		return w;
	}

	public int predict(List<Double> x)
	{
		int ans=1;
		double sum=0;
		for(int i=0;i<this.N;i++)
		{
			sum+=this.alpha.get(i)*this.labelTrainData.get(i)*kernelRBF(x, this.trainData.get(i));
		}
		sum+=b;
		if(sum>0)
			ans=1;
		else
			ans=-1;

		return ans;
	}
	public static void main(String[] args) throws FileNotFoundException {

		SVM s=new SVM();
		s.SMO();
		PrintWriter out=new PrintWriter("g://download//resultpoints.txt");
		for(int i=0;i<s.N;i++)
		{
			out.write((s.trainData.get(i).get(0)).toString());
			out.write("\t");
			out.write((s.trainData.get(i).get(1)).toString());
			out.write("\t");
			out.write(Integer.toString(s.predict(s.trainData.get(i))));
			out.write("\n");
		}
		out.close();
		//if is linear kernel ,we can get w,just like wx+b=0,then we can directly get line fuction
		double w[]=s.getLinearW();
		System.out.println(w[0]+" "+w[1]+" "+s.b+"======");
	}

}

  

用线性核函数实现的SVM得到的分类结果

画图,是用python代码

from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Plot the classified points written by the Java program together with the
# separating line of the linear-kernel model.

# Each row is "x1<TAB>x2<TAB>label"; the file only needs to be open while reading.
with open("g://download/myresult.txt") as f1:
    data = f1.readlines()

plt.figure(figsize=(8, 5), dpi=80)
axes = plt.subplot(111)
type1_x = []
type1_y = []
type2_x = []
type2_y = []
for line in data:
    line = line.strip()
    if not line:
        # Skip blank lines (e.g. a trailing newline) instead of failing in float('').
        continue
    x = line.split('\t')
    x1 = float(x[0])
    x2 = float(x[1])
    x3 = int(x[2])

    # Label 1 goes to the first group, everything else to the second.
    if x3 == 1:
        type1_x.append(x1)
        type1_y.append(x2)
    else:
        type2_x.append(x1)
        type2_y.append(x2)

type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
type2 = axes.scatter(type2_x, type2_y, s=40, c='green')

# w and b as printed by the Java program; the boundary is W1*x1 + W2*x2 + B = 0.
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437
x = np.linspace(-4, 10, 200)
y = (-W1 / W2) * x + (-B / W2)
axes.plot(x, y, 'b', lw=3)

plt.xlabel('x1')
plt.ylabel('x2')

# The file holds predicted labels 1 / -1, so name the groups accordingly.
axes.legend((type1, type2), ('1', '-1'), loc=1)
plt.show()

#0.8148005405344305 -0.27263471796762484 -3.8392586254518437

  用高斯核,当C=6,sigma=1时候

高斯核,当c=0.5,sigma=1时候

当C=0.5,sigma=12时候

说明C和sigma的取值对高斯核SVM的分类结果影响很大

时间: 2024-11-10 02:13:51

自己实现的SVM源码的相关文章

SVM与C++源码实现

1. 推导出函数间隔最小 2. 约束优化函数变形至如下形式 /* min 1/2*||w||^2 s.t.  (w[i]*x[i] + b[i] - y[i]) >= 0; */ 3. 对偶函数 /* min(para alpha) 1/2*sum(i)sum(j)(alpha[i]*alpha[j]*y[i]*y[j]*x[i]*x[j]) - sum(alpha[i]) s.t. sum(alpha[i] * y[i]) = 0 C>= alpha[i] >= 0 * 4. 根据KK

HOG+SVM(OpenCV案例源码train_HOG.cpp解读)

有所更改,参数不求完备,但求实用.源码参考D:\source\opencv-3.4.9\samples\cpp\train_HOG.cpp [解读参考]https://blog.csdn.net/xiao__run/article/details/82902267 [HOG原理]https://livezingy.com/hogdescriptor-in-opencv3-1/ https://cloud.tencent.com/developer/article/1434949 #include

hog源码分析

http://www.cnblogs.com/tornadomeet/archive/2012/08/15/2640754.html 在博客目标检测学习_1(用opencv自带hog实现行人检测) 中已经使用了opencv自带的函数detectMultiScale()实现了对行人的检测,当然了,该算法采用的是hog算法,那么hog算法是怎样实现的呢?这一节就来简单分析一下opencv中自带 hog源码. 网上也有不少网友对opencv中的hog源码进行了分析,很不错,看了很有收获.比如: htt

实验报告: 人脸识别方法回顾与实验分析 【OpenCV测试方法源码】

趁着还未工作,先把过去做的东西整理下出来~   Github源码:https://github.com/Blz-Galaxy/OpenCV-Face-Recognition (涉及个人隐私,源码不包含测试样本,请谅解~) 对实验结果更感兴趣的朋友请直接看 第5章 [摘要]这是一篇关于人脸识别方法的实验报告.报告首先回顾了人脸识别研究的发展历程及基本分类:随后对人脸识别技术方法发展过程中一些经典的流行的方法进行了详细的阐述:最后作者通过设计实验对比了三种方法的识别效果并总结了人脸识别所面临的困难与

spark.mllib源码阅读-优化算法1-Gradient

Spark中定义的损失函数及梯度,在看源码之前,先回顾一下机器学习中定义了哪些损失函数,毕竟梯度求解是为优化求解损失函数服务的. 监督学习问题是在假设空间F中选取模型f作为决策函数,对于给定的输入X,由f(X)给出相应的输出Y,这个输出的预测值f(X)与真实值Y可能一致也可能不一致,用一个损失函数(lossfunction)或代价函数(cost function)来度量预测错误的程度.损失函数是f(X)和Y的非负实值函数,记作L(Y, f(X)). 统计学习中常用的损失函数有以下几种: (1)

Mahout源码目录说明

Mahout源码目录说明 mahout项目是由多个子项目组成的,各子项目分别位于源码的不同目录下,下面对mahout的组成进行介绍: 1.mahout-core:核心程序模块,位于/core目录下: 2.mahout-math:在核心程序中使用的一些数据通用计算模块,位于/math目录下: 3.mahout-utils:在核心程序中使用的一些通用的工具性模块,位于/utils目录下: 上述三个部分是程序的主题,存储所有mahout项目的源码. 另外,mahout提供了样例程序,分别在taste-

struct2源码解读(1)之struts2启动

struct2源码解读(1)之struts启动 之前用struct2.spring.hibernate在开发一个电子商城,每天都在重复敲代码,感觉对struct2.spring.hibernate的理解都在使用层面上,虽然敲了几个月代码,但是技术水平还是得不到显著提高.于是就想着研究下struct2.spring.hibernate的源码,研究完后不仅对struct2.spring.hibernate加深了了解,同时也加强了java的学习,例如xml的解析,字符操作,线程等等,受益匪浅.想着当初

OpenCV2.4.9源码分析——Support Vector Machines

引言 本文共分为三个部分,第一个部分介绍SVM的原理,我们全面介绍了5种常用的SVM算法:C-SVC.ν-SVC.单类SVM.ε-SVR和ν-SVR,其中C-SVC和ν-SVC不仅介绍了处理两类分类问题的情况,还介绍处理多类问题的情况.在具体求解SVM过程中,我们介绍了SMO算法和广义SMO算法.第二个部分我们给出了OpenCV中SVM程序的详细注解.第三个部分我们给出了一个基于OpenCV的SVM算法的简单应用实例. 由于这篇文章太长,公式很多,把文章复制到这里,阅读体验会很差,因此,我把这篇

【转】近200篇机器学习&深度学习资料分享(含各种文档,视频,源码等)

编者按:本文收集了百来篇关于机器学习和深度学习的资料,含各种文档,视频,源码等.而且原文也会不定期的更新,望看到文章的朋友能够学到更多. <Brief History of Machine Learning> 介绍:这是一篇介绍机器学习历史的文章,介绍很全面,从感知机.神经网络.决策树.SVM.Adaboost 到随机森林.Deep Learning. <Deep Learning in Neural Networks: An Overview> 介绍:这是瑞士人工智能实验室 Ju