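3. Spark MLlib Deep Learning Convolution Neural Network (Deep Learning: Convolutional Neural Network) 3.3
Chapter 3 Convolution Neural Network
3 Example
3.1 Test Data
Use the data from the previous example, or create a new image-recognition dataset.
3.2 CNN Example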
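The loading code in 3.2 expects each line of `train_d.txt` to hold tab-separated numbers. The first 10 values form the one-hot digit label, and the remaining 784 values are the 28×28 pixel intensities in the 0-255 range, which are divided by 255.0 when loaded. If you build your own image data, the sketch below writes one line in that layout. `TrainLineFormat` and `toTrainLine` are hypothetical names used only for illustration. Because the loader reshapes the pixel row with Breeze's column-major `reshape`, pixels written row by row will be read back transposed. This is harmless as long as training and test data use the same order.

```scala
object TrainLineFormat {
  val NumClasses = 10
  val Side = 28

  /** Hypothetical helper: encodes one labelled image as a tab-separated line
    * (10 one-hot label values followed by 784 raw pixel values, 0-255). */
  def toTrainLine(digit: Int, pixels: Array[Int]): String = {
    require(digit >= 0 && digit < NumClasses, s"digit must be in [0, $NumClasses)")
    require(pixels.length == Side * Side, s"expected ${Side * Side} pixels")
    // one-hot label: 1.0 at the digit's index, 0.0 elsewhere
    val label = Array.tabulate(NumClasses)(i => if (i == digit) 1.0 else 0.0)
    (label.map(_.toString) ++ pixels.map(_.toString)).mkString("\t")
  }

  def main(args: Array[String]): Unit = {
    val pixels = Array.fill(Side * Side)(0)
    pixels(100) = 255
    val line = toTrainLine(3, pixels)
    println(line.split("\t").length) // 794 = 10 label values + 784 pixels
  }
}
```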
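import org.apache.log4j.{Level, Logger}
import breeze.linalg.{DenseMatrix => BDM}
// The CNN and NeuralNet classes must also be imported; their package is not shown here.
// `sc` is the SparkContext predefined in spark-shell.
// 2 Test data
Logger.getRootLogger.setLevel(Level.WARN)
val data_path = "/user/tmp/deeplearn/train_d.txt"
val examples = sc.textFile(data_path).cache()
val train_d1 = examples.map { line =>
  // each line: 10 one-hot label values, then 784 pixel values (0-255), tab-separated
  val f1 = line.split("\t")
  val f = f1.map(f => f.toDouble)
  val y = f.slice(0, 10)
  val x = f.slice(10, f.length)
  // label as a 1x10 row matrix; pixels reshaped to 28x28 and scaled to [0, 1]
  (new BDM(1, y.length, y), (new BDM(1, x.length, x)).reshape(28, 28) / 255.0)
}
val train_d = train_d1.map(f => (f._1, f._2))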
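To check the input layout without a cluster, the body of the `map` above can be mirrored in plain Scala. This is a sketch for illustration only. `TrainLineParser` and `parseLine` are hypothetical names, and a nested array stands in for the Breeze `BDM` used in the example. Breeze matrices are column-major, so `reshape(28, 28)` places pixel `x(c * 28 + r)` at row `r`, column `c`, and the sketch reproduces that.

```scala
object TrainLineParser {
  val NumClasses = 10
  val Side = 28

  /** Parses one tab-separated line into (one-hot label, 28x28 image).
    * image(row)(col) is filled column by column, mimicking reshape(28, 28)
    * on Breeze's column-major DenseMatrix. */
  def parseLine(line: String): (Array[Double], Array[Array[Double]]) = {
    val f = line.split("\t").map(_.toDouble)
    val y = f.slice(0, NumClasses)
    val x = f.slice(NumClasses, f.length)
    require(x.length == Side * Side, s"expected ${Side * Side} pixel values, got ${x.length}")
    // column-major: element (r, c) comes from x(c * Side + r); scale to [0, 1]
    val image = Array.tabulate(Side, Side)((r, c) => x(c * Side + r) / 255.0)
    (y, image)
  }
}
```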
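// 3 Set the training parameters and build the model
// opts: iteration step size, number of iterations, cross-validation fraction
val opts = Array(100.0, 1.0, 0.0)
train_d.cache()
val numExamples = train_d.count()
println(s"numExamples = $numExamples.")
// Layer types: "i" = input, "c" = convolution, "s" = subsampling (pooling).
// outputmaps / kernelsize / scale give, per layer, the number of feature maps,
// the convolution kernel size and the subsampling factor (0.0 where unused).
val CNNmodel = new CNN().
  setMapsize(new BDM(1, 2, Array(28.0, 28.0))).
  setTypes(Array("i", "c", "s", "c", "s")).
  setLayer(5).
  setOnum(10).
  setOutputmaps(Array(0.0, 6.0, 0.0, 12.0, 0.0)).
  setKernelsize(Array(0.0, 5.0, 0.0, 5.0, 0.0)).
  setScale(Array(0.0, 0.0, 2.0, 0.0, 2.0)).
  setAlpha(1.0).
  setBatchsize(50.0).
  setNumepochs(1.0).
  CNNtrain(train_d, opts)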
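The settings above describe an input layer followed by two convolution/subsampling pairs. The first pair uses 6 maps with 5×5 kernels and 2×2 subsampling. The second uses 12 maps with 5×5 kernels and 2×2 subsampling. A 10-way output follows (`setOnum(10)`). The plain-Scala sketch below traces the feature-map geometry. It assumes "valid" convolution (output side = input side - kernel + 1) and non-overlapping mean pooling (output side = input side / scale). These are common choices for this kind of CNN, but they are assumptions, not the internals of the `CNN` class. `CnnGeometry`, `convValid`, `meanPool` and `traceSides` are hypothetical names.

```scala
object CnnGeometry {
  /** 'Valid' 2-D convolution: the kernel is flipped (true convolution) and only
    * positions where it fits entirely inside the input are kept. */
  def convValid(in: Array[Array[Double]], k: Array[Array[Double]]): Array[Array[Double]] = {
    val (kr, kc) = (k.length, k(0).length)
    val outR = in.length - kr + 1
    val outC = in(0).length - kc + 1
    Array.tabulate(outR, outC) { (r, c) =>
      var s = 0.0
      for (i <- 0 until kr; j <- 0 until kc)
        s += in(r + i)(c + j) * k(kr - 1 - i)(kc - 1 - j)
      s
    }
  }

  /** Non-overlapping mean pooling with a square window of side `scale`. */
  def meanPool(in: Array[Array[Double]], scale: Int): Array[Array[Double]] =
    Array.tabulate(in.length / scale, in(0).length / scale) { (r, c) =>
      var s = 0.0
      for (i <- 0 until scale; j <- 0 until scale)
        s += in(r * scale + i)(c * scale + j)
      s / (scale * scale)
    }

  /** Map side length after each layer of a spec such as i-c-s-c-s. */
  def traceSides(input: Int, types: Seq[String], kernel: Seq[Int], scale: Seq[Int]): Seq[Int] =
    types.indices.scanLeft(input) { (side, l) =>
      types(l) match {
        case "c" => side - kernel(l) + 1 // valid convolution shrinks by kernel - 1
        case "s" => side / scale(l)      // pooling divides by the scale factor
        case _   => side                 // input layer: unchanged
      }
    }.tail

  def main(args: Array[String]): Unit = {
    val types = Seq("i", "c", "s", "c", "s")
    val sides = traceSides(28, types, kernel = Seq(0, 5, 0, 5, 0), scale = Seq(0, 0, 2, 0, 2))
    println(sides.mkString(" -> ")) // 28 -> 24 -> 12 -> 8 -> 4
    // assuming the last layer's 12 maps of 4x4 are flattened before the 10 outputs
    println(s"features = ${sides.last * sides.last * 12}") // features = 192

    val img = Array.fill(28, 28)(1.0)
    val k = Array.fill(5, 5)(1.0 / 25)
    val pooled = meanPool(convValid(img, k), 2)
    println(s"${pooled.length} x ${pooled(0).length}") // 12 x 12
  }
}
```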
// 4 Model test
val CNNforecast = CNNmodel.predict(train_d)
val CNNerror = CNNmodel.Loss(CNNforecast)
println(s"CNNerror = $CNNerror.")
val printf1 = CNNforecast.map(f => (f.label.data(0), f.predict_label.data(0))).take(200)
println("Prediction results (actual, predicted, error):")
for (i <- 0 until printf1.length)
  println(printf1(i)._1 + "\t" + printf1(i)._2 + "\t" + (printf1(i)._2 - printf1(i)._1))
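The printout above compares only the first component (`data(0)`) of each label and prediction. The loading code and `setOnum(10)` suggest that `label` holds the 10-element one-hot target and `predict_label` the 10 network outputs. If so, a misclassification rate is often more informative: take the index of the largest value as the digit and count mismatches. The plain-Scala sketch below works on pairs of arrays. `ClassificationError`, `argmax` and `errorRate` are hypothetical helpers, not part of the `CNN` class. On the RDD, they could be applied to pairs such as `(f.label.data, f.predict_label.data)` collected to the driver, assuming those matrices are not views over a larger array.

```scala
object ClassificationError {
  /** Index of the largest value, i.e. the predicted class (first index wins on ties). */
  def argmax(v: Array[Double]): Int = v.indices.maxBy(i => v(i))

  /** Fraction of (label, output) pairs whose argmax indices disagree. */
  def errorRate(pairs: Seq[(Array[Double], Array[Double])]): Double =
    if (pairs.isEmpty) 0.0
    else pairs.count { case (y, out) => argmax(y) != argmax(out) }.toDouble / pairs.size

  def main(args: Array[String]): Unit = {
    def oneHot(d: Int) = Array.tabulate(10)(i => if (i == d) 1.0 else 0.0)
    val pairs = Seq(
      (oneHot(3), Array.tabulate(10)(i => if (i == 3) 0.9 else 0.01)), // correct: predicts 3
      (oneHot(7), Array.tabulate(10)(i => if (i == 1) 0.8 else 0.05))  // wrong: predicts 1
    )
    println(errorRate(pairs)) // 0.5
  }
}
```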
// NOTE: `mynn` and `nnopts` are not defined in this example. This NN part appears to be
// carried over from an earlier example and needs them (the NN layer sizes, layer count
// and initial weights, plus its training options) before it can run.
println(mynn._2)
for (i <- 0 to mynn._1.length - 1) {
  print(mynn._1(i) + "\t")
}
println()
println("mynn_W1")
val tmpw1 = mynn._3(0)
for (i <- 0 to tmpw1.rows - 1) {
  for (j <- 0 to tmpw1.cols - 1) {
    print(tmpw1(i, j) + "\t")
  }
  println()
}
val NNmodel = new NeuralNet().
  setSize(mynn._1).
  setLayer(mynn._2).
  setActivation_function("sigm").
  setOutput_function("sigm").
  setInitW(mynn._3).
  NNtrain(train_d, nnopts)
// 5 NN model test
val NNforecast = NNmodel.predict(train_d)
val NNerror = NNmodel.Loss(NNforecast)
println(s"NNerror = $NNerror.")
val printf1 = NNforecast.map(f => (f.label.data(0), f.predict_label.data(0))).take(200)
println("Prediction results (actual, predicted, error):")
for (i <- 0 until printf1.length)
  println(printf1(i)._1 + "\t" + printf1(i)._2 + "\t" + (printf1(i)._2 - printf1(i)._1))
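Please credit the source when reposting:
Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission.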