线性回归的实现其实很简单,关键在于参数 theta 的训练。下面是我的 Java 代码实现:
我使用的训练集如下(每行前面的列为特征 x1、x2,最后一列为目标值 y,各列以空格分隔):
1.0 2.0 1.5
2.0 3.5 3.4
-1.2 2.0 3.5
4.7 3.2 4.5
2.3 -2.5 5.4
下面是类与函数的实现:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;

/**
 * Linear regression trained with batch gradient descent.
 *
 * <p>The training file holds one sample per non-blank line: the feature values followed by the
 * target value y, separated by whitespace. Each sample is stored internally with an extra leading
 * column x0 = 1.0 (the intercept term), so a row of {@code trainData} is
 * {@code [1.0, x1, ..., xn, y]}.
 *
 * @author CassieRyu
 */
public class Linear {

    /** Splits a data line into tokens on runs of whitespace (spaces, tabs). */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    private double[][] trainData; // training set; column 0 is x0 = 1.0, last column is y
    private int row;              // number of samples
    private int column;           // number of features + 2 (the added x0 column and y)

    private double[] theta;       // parameter vector: one entry per column except y
    private double alpha;         // learning rate (step size)
    private int iteration;        // number of gradient-descent iterations per training run

    /**
     * Loads the training set from {@code fileName} and trains theta.
     *
     * @param fileName path to a whitespace-separated training file
     * @param alpha    learning rate
     * @param ite      number of gradient-descent iterations
     * @throws UncheckedIOException     if the file cannot be read
     * @throws IllegalArgumentException if the file holds no samples, rows have differing numbers
     *                                  of values, or a value is not a number
     */
    public Linear(String fileName, double alpha, int ite) {
        // Read the file once instead of once per row count / column count / load.
        List<String[]> samples = readSamples(fileName);
        if (samples.isEmpty()) {
            throw new IllegalArgumentException("training file has no samples: " + fileName);
        }
        row = samples.size();
        column = samples.get(0).length + 1; // +1 for the added x0 column

        trainData = new double[row][column];
        fillTrainData(samples);

        this.alpha = alpha;
        this.iteration = ite;

        theta = new double[column - 1]; // one parameter per column except y
        initializeTheta();

        trainedTheta();
    }

    /**
     * Returns the number of samples in the training file.
     *
     * @param fileName path to the training file
     * @return number of non-blank lines
     * @throws UncheckedIOException if the file cannot be read
     */
    public int getRowFromFile(String fileName) {
        return readSamples(fileName).size();
    }

    /**
     * Returns the number of values per sample in the training file (features plus y, without x0).
     *
     * @param fileName path to the training file
     * @return number of values on the first non-blank line, or 0 if the file has none
     * @throws UncheckedIOException if the file cannot be read
     */
    public int getColumnFromFile(String fileName) {
        List<String[]> samples = readSamples(fileName);
        return samples.isEmpty() ? 0 : samples.get(0).length;
    }

    /**
     * Reloads {@code trainData} from {@code fileName}; the file must have the same shape as the
     * one this model was built from.
     *
     * @param fileName path to the training file
     * @throws UncheckedIOException     if the file cannot be read
     * @throws IllegalArgumentException if the shape differs or a value is not a number
     */
    public void loadTrainData(String fileName) {
        fillTrainData(readSamples(fileName));
    }

    /** Reads the non-blank lines of {@code fileName}, each split into whitespace-separated tokens. */
    private static List<String[]> readSamples(String fileName) {
        List<String[]> samples = new ArrayList<>();
        try (BufferedReader reader =
                Files.newBufferedReader(Paths.get(fileName), StandardCharsets.UTF_8)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (!trimmed.isEmpty()) { // tolerate blank/trailing lines
                    samples.add(WHITESPACE.split(trimmed));
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException("cannot read training file: " + fileName, e);
        }
        return samples;
    }

    /** Parses tokenized samples into {@code trainData}, setting x0 = 1.0 on every row. */
    private void fillTrainData(List<String[]> samples) {
        if (samples.size() != row) {
            throw new IllegalArgumentException(
                    "expected " + row + " samples but found " + samples.size());
        }
        for (int i = 0; i < row; i++) {
            String[] tokens = samples.get(i);
            if (tokens.length != column - 1) {
                throw new IllegalArgumentException("sample " + i + " has " + tokens.length
                        + " values, expected " + (column - 1));
            }
            trainData[i][0] = 1.0; // x0: intercept term
            for (int j = 1; j < column; j++) {
                // NumberFormatException (an IllegalArgumentException) on non-numeric input
                trainData[i][j] = Double.parseDouble(tokens[j - 1]);
            }
        }
    }

    /** Sets every theta entry to 1.0 as the starting point for gradient descent. */
    public void initializeTheta() {
        Arrays.fill(theta, 1.0);
    }

    /**
     * Runs {@code iteration} rounds of batch gradient descent on theta.
     *
     * <p>Each round computes the residuals h(x_i) - y_i once with the current theta, then updates
     * every theta[j] by {@code alpha * sum_i (h(x_i) - y_i) * x_ij}. The gradient is summed over
     * samples, not averaged, so the effective step grows with the number of samples. The
     * {@code iteration} field is not consumed; calling this again runs another full set of rounds.
     */
    public void trainedTheta() {
        for (int round = 0; round < iteration; round++) {
            double[] residuals = getDerivation(); // must be recomputed with the current theta
            for (int j = 0; j < theta.length; j++) {
                double gradient = 0.0;
                for (int i = 0; i < row; i++) {
                    gradient += residuals[i] * trainData[i][j];
                }
                theta[j] -= alpha * gradient;
            }
        }
    }

    /**
     * Returns the residual h(x_i) - y_i for every sample under the current theta.
     *
     * @return array of length {@code row}
     */
    public double[] getDerivation() {
        double[] residuals = new double[row];
        for (int i = 0; i < row; i++) {
            residuals[i] = getHypothesisFunc(i) - trainData[i][column - 1];
        }
        return residuals;
    }

    /**
     * Returns h(x_i) = sum_k theta[k] * x_ik for training sample {@code i}.
     *
     * @param i zero-based sample (row) index
     * @return the model's prediction for that sample
     */
    public double getHypothesisFunc(int i) {
        double result = 0.0;
        for (int k = 0; k < column - 1; k++) {
            result += theta[k] * trainData[i][k];
        }
        return result;
    }

    /** Prints the training set, one row per line, including the x0 and y columns. */
    public void printTrainData() {
        // print (not printf): these strings are data, not format strings
        System.out.print("\n训练集:\n");
        for (int i = 0; i < row; i++) {
            System.out.print("第" + i + "行:");
            for (int j = 0; j < column; j++) {
                System.out.print(trainData[i][j] + " ");
            }
            System.out.print("\n");
        }
        System.out.print("\n");
    }

    /** Prints the current theta vector on one line. */
    public void printTheta() {
        System.out.print("Theta集:\n");
        for (int j = 0; j < column - 1; j++) {
            System.out.print(theta[j] + " ");
        }
        System.out.print("\n");
    }

    /**
     * Predicts y for a new sample.
     *
     * @param newData feature vector including the leading x0 = 1.0, i.e. {@code [1.0, x1, ..., xn]};
     *                only the first {@code theta.length} entries are used
     * @return theta . newData
     * @throws NullPointerException     if {@code newData} is null
     * @throws IllegalArgumentException if {@code newData} has fewer entries than theta
     */
    public double predict(double[] newData) {
        Objects.requireNonNull(newData, "newData");
        if (newData.length < theta.length) {
            throw new IllegalArgumentException("expected " + theta.length
                    + " values (including x0 = 1.0) but got " + newData.length);
        }
        double h = 0.0;
        for (int i = 0; i < theta.length; i++) {
            h += newData[i] * theta[i];
        }
        return h;
    }
}
LinearRegression
根据模型进行测试:
命令行中的输入为:2.3 4.6
1 public class LinearMain { 2 3 public static void main(String[] args){ 4 5 String fileName = "C:\\Users\\CassieLiu\\Desktop\\train.txt"; 6 Linear lin = new Linear(fileName,0.005,100); 7 lin.printTrainData(); 8 lin.printTheta(); 9 10 //进行预测,数值在命令行参数里面 11 int len = args.length; 12 if(len!=(lin.getColumnFromFile(fileName)-1)){ //测试数据没有y值 13 System.out.printf("请输入对应该模型的样本!\n"); 14 return; 15 } 16 else{ 17 double[] arg = new double[len+1]; 18 arg[0]=1.0; //给x0赋值 19 for(int i=0;i<len;i++){ 20 arg[i+1] = Double.parseDouble(args[i]); 21 } 22 23 double result = lin.predict(arg); 24 System.out.printf("根据模型预测出的值为:"+result); 25 } 26 27 28 } 29 }
main
最后的输出结果为:
如有不足之处,欢迎指出!
本文为博主原创博文,未经许可请勿转载!
时间: 2024-12-11 17:20:14