想到那天头条面试时，面试官让我手撸 k-means，奈何好久不用 C++，好多都忘了 == 淡淡的忧伤。
这次刚好赶上机会，可以再试试了。我把它写成了一个项目，包含多个文件。
首先:base.h
#ifndef BASE_H #define BASE_H #include<iostream> #include<opencv2/opencv.hpp> #include<cassert> #include<stdlib.h> class Baseofgeo{ public: float computedis(const std::vector<float> p1,const std::vector<float> p2); void Gekmeasns(std::vector<std::vector<float> >& listA,int K); private: //to do... }; #endif
再上:base.cpp
#include "base.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <ctime>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <vector>

namespace {

// Upper bound on the number of k-means iterations.
constexpr int kMaxIterations = 21;

// Iteration stops once the total point-to-centre distance changes by no more
// than this amount between two consecutive iterations.
constexpr double kMinCostChange = 10.0;

// Euclidean distance over the first three components of a and b.
// Takes references so that callers do not have to copy points around.
double distance3(const std::vector<float>& a, const std::vector<float>& b) {
    const double dx = a[0] - b[0];
    const double dy = a[1] - b[1];
    const double dz = a[2] - b[2];
    return std::sqrt(dx * dx + dy * dy + dz * dz);
}

}  // namespace

/// Clusters the rows of listA into K groups with k-means.
///
/// Each row holds the point coordinates in elements [0..2]; element [3] is
/// overwritten with the index of the cluster the point ends up in (stored as
/// a float). Initial centres are K distinct points drawn at random, so results
/// differ between runs.
///
/// @param listA points to cluster, modified in place; every row needs at
///              least four elements.
/// @param K     number of clusters; must satisfy 1 <= K <= listA.size().
/// @throws std::invalid_argument if K is out of range or a row is too short.
void Baseofgeo::Gekmeasns(std::vector<std::vector<float> >& listA, int K) {
    const std::size_t pointCount = listA.size();
    // K distinct seed points are required, so K cannot exceed the input size.
    if (K <= 0 || static_cast<std::size_t>(K) > pointCount) {
        throw std::invalid_argument(
            "Gekmeasns: K must be between 1 and the number of points");
    }
    for (const std::vector<float>& row : listA) {
        if (row.size() < 4) {
            throw std::invalid_argument(
                "Gekmeasns: every row needs 3 coordinates plus a label slot");
        }
    }
    const std::size_t clusterCount = static_cast<std::size_t>(K);

    // Initialisation: pick K distinct points as the starting centres. Chosen
    // indices are tracked locally so the caller does not have to pre-fill the
    // label column with a sentinel value.
    std::mt19937 rng(static_cast<std::mt19937::result_type>(std::time(nullptr)));
    std::uniform_int_distribution<std::size_t> pickIndex(0, pointCount - 1);
    std::vector<std::vector<float> > centers(clusterCount,
                                             std::vector<float>(3, 0.0f));
    std::vector<char> isSeed(pointCount, 0);
    for (std::size_t c = 0; c < clusterCount;) {
        const std::size_t idx = pickIndex(rng);
        if (isSeed[idx]) {
            continue;  // already used as a centre, draw again
        }
        isSeed[idx] = 1;
        centers[c].assign(listA[idx].begin(), listA[idx].begin() + 3);
        ++c;
    }

    std::vector<int> labels(pointCount, -1);
    std::vector<double> sums(clusterCount * 3, 0.0);
    std::vector<std::size_t> members(clusterCount, 0);

    double previousCost = 0.0;  // total distance after the previous iteration
    double costChange = 100.0;  // starts above the threshold so the loop runs
    for (int iteration = 0;
         iteration < kMaxIterations && costChange > kMinCostChange;
         ++iteration) {
        std::cout << "第" << iteration << "次迭代" << std::endl;

        // Assignment step: label every point with its nearest centre and
        // accumulate the per-cluster coordinate sums in the same pass.
        std::fill(sums.begin(), sums.end(), 0.0);
        std::fill(members.begin(), members.end(), std::size_t(0));
        for (std::size_t i = 0; i < pointCount; ++i) {
            std::vector<float>& point = listA[i];
            double bestDistance = std::numeric_limits<double>::max();
            int bestCluster = -1;
            for (std::size_t c = 0; c < clusterCount; ++c) {
                const double d = distance3(point, centers[c]);
                if (d < bestDistance) {
                    bestDistance = d;
                    bestCluster = static_cast<int>(c);
                }
            }
            labels[i] = bestCluster;
            if (bestCluster < 0) {
                continue;  // no comparable distance (e.g. NaN coordinates)
            }
            point[3] = static_cast<float>(bestCluster);
            const std::size_t base = static_cast<std::size_t>(bestCluster) * 3;
            sums[base + 0] += point[0];
            sums[base + 1] += point[1];
            sums[base + 2] += point[2];
            ++members[static_cast<std::size_t>(bestCluster)];
        }

        // Update step: move each centre to the mean of its members. A cluster
        // that lost all members keeps its previous centre instead of dividing
        // by zero.
        for (std::size_t c = 0; c < clusterCount; ++c) {
            if (members[c] == 0) {
                continue;
            }
            const double n = static_cast<double>(members[c]);
            centers[c][0] = static_cast<float>(sums[c * 3 + 0] / n);
            centers[c][1] = static_cast<float>(sums[c * 3 + 1] / n);
            centers[c][2] = static_cast<float>(sums[c * 3 + 2] / n);
        }

        // Convergence check: total distance of the points to their centres.
        double cost = 0.0;
        for (std::size_t i = 0; i < pointCount; ++i) {
            if (labels[i] >= 0) {
                cost += distance3(listA[i],
                                  centers[static_cast<std::size_t>(labels[i])]);
            }
        }
        if (iteration > 0) {
            // The first iteration has no previous cost to compare against.
            costChange = std::abs(cost - previousCost);
            std::cout << "reserror:" << costChange << std::endl;
        }
        previousCost = cost;
    }
}

/// Returns the Euclidean distance between two 3-component points.
///
/// @param p1 first point; must hold exactly three components.
/// @param p2 second point; must hold exactly three components.
/// @return the straight-line distance between p1 and p2.
float Baseofgeo::computedis(const std::vector<float> p1,
                            const std::vector<float> p2) {
    assert(p1.size() == 3 && p2.size() == 3);
    return static_cast<float>(distance3(p1, p2));
}
主函数嘛:main.cpp
#include<iostream> #include<vector> #include<boost/concept_check.hpp> #include<opencv2/opencv.hpp> #include<time.h> #include<stdlib.h> #include<cassert> #include<memory> #include"base.h" int main(int argc,char** argv){ cv::Mat I=cv::imread("../data/0001.jpg"); cv::imshow("im a pic",I); std::vector<std::vector<float> > listA(I.cols*I.rows,std::vector<float>(4,-1.0)); int nl=d_Ihsv.rows; int nc=d_Ihsv.cols; int ii=0; for(int i=0;i<nl;i++){ for(int j=0;j<nc;j++){ listA[ii][0]=I.at<cv::Vec3f>(i,j)[0]/10.0; listA[ii][1]=I.at<cv::Vec3f>(i,j)[1]/10.0; listA[ii][2]=I.at<cv::Vec3f>(i,j)[2]/10.0; ii++; } } std::shared_ptr<Baseofgeo> basemethod(new Baseofgeo()); basemethod->Gekmeasns(listA,72);//就当我是可视化==,可视化第二期再更 for(auto p:listA){ std::cout<<p[3]<<std::endl; } return 0; }
还有CMakeLists.txt
# Build script for the k-means demo: a static "base" library holding the
# clustering code and a "main" executable that runs it on an image.

# add_compile_options() needs at least 2.8.12; the upper bound keeps the
# project configurable with current CMake releases.
cmake_minimum_required(VERSION 2.8.12...3.10)
project(image)

# Default OpenCV build location; override with -DOpenCV_DIR=<path> when
# OpenCV lives somewhere else.
if(NOT OpenCV_DIR)
  set(OpenCV_DIR "/home/geo/opencv-2.4.13/build")
endif()

add_compile_options(-std=c++11)

find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

add_library(base base.cpp)
target_link_libraries(base ${OpenCV_LIBS})

add_executable(main main.cpp)
target_link_libraries(main base ${OpenCV_LIBS})
时间: 2024-10-10 03:13:38