打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
我的OpenCV学习笔记(六):使用支持向量机(SVM)

SVM是2000年左右提出的一种新的分类方法,着重解决了小样本分类问题。具体原理可以参看模式识别的书籍。OpenCV中的SVM的实现也是基于大名鼎鼎的SVM 库:http://www.csie.ntu.edu.tw/~cjlin。OpenCV教程中有两个例子,一个是线性可分的,一个是线性不可分的,我对他们做了详尽的注释:

先看线性可分时:

  1. #include <opencv2/core/core.hpp>  
  2. #include <opencv2/highgui/highgui.hpp>  
  3. #include <opencv2/ml/ml.hpp>  
  4.   
  5. using namespace cv;  
  6.   
  7. int main()  
  8. {  
  9.     // Data for visual representation  
  10.     int width = 512, height = 512;  
  11.     Mat image = Mat::zeros(height, width, CV_8UC3);  
  12.   
  13.     // Set up training data  
  14.     float labels[5] = {1.0, -1.0, -1.0, -1.0,1.0};  
  15.     Mat labelsMat(5, 1, CV_32FC1, labels);  
  16.   
  17.   
  18.     float trainingData[5][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501},{501,128} };  
  19.     Mat trainingDataMat(5, 2, CV_32FC1, trainingData);  
  20.   
  21.     //设置支持向量机的参数  
  22.     CvSVMParams params;  
  23.     params.svm_type    = CvSVM::C_SVC;//SVM类型:使用C支持向量机  
  24.     params.kernel_type = CvSVM::LINEAR;//核函数类型:线性  
  25.     params.term_crit   = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);//终止准则函数:当迭代次数达到最大值时终止  
  26.   
  27.     //训练SVM  
  28.     //建立一个SVM类的实例  
  29.     CvSVM SVM;  
  30.     //训练模型,参数为:输入数据、响应、XX、XX、参数(前面设置过)  
  31.     SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);  
  32.       
  33.     Vec3b green(0,255,0), blue (255,0,0);  
  34.     //显示判决域  
  35.     for (int i = 0; i < image.rows; ++i)  
  36.         for (int j = 0; j < image.cols; ++j)  
  37.         {  
  38.                         Mat sampleMat = (Mat_<float>(1,2) << i,j);  
  39.             //predict是用来预测的,参数为:样本、返回值类型(如果值为ture而且是一个2类问题则返回判决函数值,否则返回类标签)、  
  40.             float response = SVM.predict(sampleMat);  
  41.   
  42.             if (response == 1)  
  43.                 image.at<Vec3b>(j, i)  = green;  
  44.             else if (response == -1)   
  45.                  image.at<Vec3b>(j, i)  = blue;  
  46.         }  
  47.   
  48.     //画出训练数据  
  49.     int thickness = -1;  
  50.     int lineType = 8;  
  51.     circle( image, Point(501,  10), 5, Scalar(  0,   0,   0), thickness, lineType);//画圆  
  52.     circle( image, Point(255,  10), 5, Scalar(255, 255, 255), thickness, lineType);  
  53.     circle( image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType);  
  54.     circle( image, Point( 10, 501), 5, Scalar(255, 255, 255), thickness, lineType);  
  55.     circle(image, Point( 501, 128), 5, Scalar(0, 0, 0), thickness, lineType);  
  56.   
  57.     //显示支持向量  
  58.     thickness = 2;  
  59.     lineType  = 8;  
  60.     //获取支持向量的个数  
  61.     int c     = SVM.get_support_vector_count();  
  62.   
  63.     for (int i = 0; i < c; ++i)  
  64.     {  
  65.         //获取第i个支持向量  
  66.         const float* v = SVM.get_support_vector(i);  
  67.         //支持向量用到的样本点,用灰色进行标注  
  68.         circle( image,  Point( (int) v[0], (int) v[1]),   6,  Scalar(128, 128, 128), thickness, lineType);  
  69.     }  
  70.   
  71.     imwrite("result.png", image);        // save the image   
  72.   
  73.     imshow("SVM Simple Example", image); // show it to the user  
  74.     waitKey(0);  
  75.   
  76. }  


 

线性不可分时由于样本较多,训练的时间比较长:

  1. #include <iostream>  
  2. #include <opencv2/core/core.hpp>  
  3. #include <opencv2/highgui/highgui.hpp>  
  4. #include <opencv2/ml/ml.hpp>  
  5. #include "time.h"  
  6.   
  7. using namespace cv;  
  8. using namespace std;  
  9. //程序说明:  
  10. //一共两个样本集每个样本集有100个样本,其中90个是线性可分的,10个线型不可分  
  11. //这200个样本数据储存在trainData内:trainData是一个200行2列的矩阵,其中第一列储存样本的X值,第二列储存的是样本的Y值  
  12. //每一列的前90个元素是第一类的线性可分部分,后90个元素是第二类的线性可分部分,中间的20个元素是线性不可分部分  
  13. //第一类样本的X值分布在整幅图像的[0,0.4]范围内,第二类样本的X值分布在整幅图像的[0.6,1]范围内,中间的[0.4,0.6]是线性不可分的部分;这三部分的Y值都在整幅图像的高度内自由分布  
  14.   
  15.   
  16. //每个样本集的数量  
  17. #define NTRAINING_SAMPLES 100  
  18.   
  19. //其中的线性部分  
  20. #define FRAC_LINEAR_SEP   0.9f  
  21.   
  22. int main()  
  23. {  
  24.     //定义显示结果的图像  
  25.     //图像的宽度、高度  
  26.     const int WIDTH = 512,HEIGHT = 512;  
  27.     Mat image = Mat::zeros(HEIGHT,WIDTH,CV_8UC3);  
  28.   
  29.   
  30.   
  31.     //************第一步:设定训练数据***********  
  32.     //************1.设定数据结构****************  
  33.     //承载训练数据的结构  
  34.     Mat trainData(2*NTRAINING_SAMPLES,2,CV_32FC1);  
  35.     //承载这些数据分类的结构  
  36.     Mat labels(2*NTRAINING_SAMPLES,1,CV_32FC1);  
  37.     //设定随机数种子  
  38.     RNG rng(100);  
  39.     //设定线性可分部分的数据量  
  40.     int nLinearSamples = (int) (NTRAINING_SAMPLES*FRAC_LINEAR_SEP);  
  41.   
  42.     //**************2.设定第一类中的数据*********  
  43.     //从整个数据集中取出前[0,89]行  
  44.     //注:*Range的范围是[a,b)  
  45.     Mat trainClass = trainData.rowRange(0,nLinearSamples);  
  46.     //取出第一列  
  47.     Mat c = trainClass.colRange(0,1);  
  48.     //随机生成X的值:[0,0.4*WIDTH]  
  49.     rng.fill(c,RNG::UNIFORM,Scalar(1),Scalar(0.4*WIDTH));  
  50.     //取出第二列  
  51.     c = trainClass.colRange(1,2);  
  52.     //随机生成Y的值  
  53.     rng.fill(c,RNG::UNIFORM,Scalar(1),Scalar(HEIGHT));  
  54.   
  55.     //**************2.设定第二类的数据*************  
  56.     //从整个数据中取出[110,199]行  
  57.     trainClass = trainData.rowRange(2*NTRAINING_SAMPLES-nLinearSamples,2*NTRAINING_SAMPLES);  
  58.     //取出第一列  
  59.     c = trainClass.colRange(0,1);  
  60.     //随机生成X的值[0.6*WIDTH,WIDTH]  
  61.     rng.fill(c,RNG::UNIFORM,Scalar(0.6*WIDTH),Scalar(WIDTH));  
  62.     //取出第二列  
  63.     c = trainClass.colRange(1,2);  
  64.     //随机生成Y的值  
  65.     rng.fill(c,RNG::UNIFORM,Scalar(1),Scalar(HEIGHT));  
  66.   
  67.     //***************3.设定线性不可分的数据***********  
  68.     //取出[90,109]行  
  69.     trainClass = trainData.rowRange(nLinearSamples,2*NTRAINING_SAMPLES-nLinearSamples);  
  70.     //取出第一列  
  71.     c = trainClass.colRange(0,1);  
  72.     //随机生成X的值[0.4*WIDTH,0.6*WIDTH]  
  73.     rng.fill(c,RNG::UNIFORM,Scalar(0.4*WIDTH),Scalar(0.6*WIDTH));  
  74.     //取出第二列  
  75.     c = trainClass.colRange(1,2);  
  76.     //随机生成Y的值  
  77.     rng.fill(c,RNG::UNIFORM,Scalar(1),Scalar(HEIGHT));  
  78.   
  79.   
  80.     //***************4.为所有数据设置标签**********  
  81.     //前100个数据设为第一类  
  82.     labels.rowRange(0,NTRAINING_SAMPLES).setTo(1);  
  83.     //后100个数据设为第二类  
  84.     labels.rowRange(NTRAINING_SAMPLES,2*NTRAINING_SAMPLES).setTo(2);  
  85.   
  86.   
  87.     //**************第二步:设置SVM参数***********  
  88.     CvSVMParams params;  
  89.     //SVM类型: C-Support Vector Classification  
  90.     params.svm_type     = SVM::C_SVC;  
  91.   
  92.     params.C            = 0.1;  
  93.     //和函数类型:Linear kernel  
  94.     params.kernel_type  = SVM::LINEAR;  
  95.     //终止准则:当迭代次数到达最大值后终止  
  96.     params.term_crit    = TermCriteria(CV_TERMCRIT_ITER,(int) 1e7,1e-6);  
  97.   
  98.   
  99.   
  100.     //**************第三步:训练SVM***********  
  101.     cout<<"开始训练过程"<<endl;  
  102.     //开始计时  
  103.     clock_t start,finish;  
  104.     double duration;  
  105.     start = clock();  
  106.     //*************1.建立一个SVM实例**********  
  107.     CvSVM svm;  
  108.     //*************2.调用训练函数*************  
  109.     svm.train(trainData,labels,Mat(),Mat(),params);  
  110.     //结束计时  
  111.     finish = clock();  
  112.     duration = (double)(finish-start) / CLOCKS_PER_SEC;  
  113.     cout<<"训练过程结束,共耗时:"<<duration<<"秒"<<endl;  
  114.   
  115.   
  116.   
  117.   
  118.     //************第四步:显示判决域************  
  119.     //第一类用绿色;第二类用蓝色  
  120.     Vec3b green(0,100,0),blue(100,0,0);  
  121.     for(int i = 0; i < image.rows; ++i)  
  122.     {  
  123.         for(int j = 0; j < image.cols; ++j)  
  124.         {  
  125.             Mat sampleMat = (Mat_<float>(1,2)<<i,j);  
  126.             float response = svm.predict(sampleMat);  
  127.             if (response == 1)  
  128.             {  
  129.                 image.at<Vec3b>(j,i) = green;  
  130.             }  
  131.             else if (response == 2)  
  132.             {  
  133.                 image.at<Vec3b>(j,i) = blue;  
  134.             }  
  135.         }  
  136.     }  
  137.   
  138.   
  139.   
  140.     //************第五步:显示训练数据************  
  141.     //红色  
  142.     //负数会导致画出的图型是实心的  
  143.     int thick = -1;  
  144.     int lineType = 8;  
  145.     float px,py;  
  146.     //************1.第一类*************  
  147.     for(int i = 0; i < NTRAINING_SAMPLES; ++i)  
  148.     {  
  149.         px = trainData.at<float>(i,0);  
  150.         py = trainData.at<float>(i,1);  
  151.         circle(image,Point((int)px,(int)py),3,Scalar(0,255,0));  
  152.     }  
  153.     //***********2.第二类****************  
  154.     for(int i = NTRAINING_SAMPLES; i < 2*NTRAINING_SAMPLES; ++i)  
  155.     {  
  156.         px = trainData.at<float>(i,0);  
  157.         py = trainData.at<float>(i,1);  
  158.         circle(image,Point((int)px,(int)py),3,Scalar(255,0,0));       
  159.     }  
  160.   
  161.   
  162.   
  163.     //***********第六步:显示支持向量*************  
  164.     thick = 2;  
  165.     lineType = 8;  
  166.     //获取支持向量的个数  
  167.     int x = svm.get_support_vector_count();  
  168.     for(int i = 0; i < x; ++i)  
  169.     {  
  170.         const float* v = svm.get_support_vector(i);  
  171.         circle(image,Point((int)v[0],(int)v[1]),6,Scalar(128,128,128),thick,lineType);  
  172.     }  
  173.     imshow("分类结果",image);  
  174.     waitKey(0);  
  175.     return 0;  
  176. }  


其实我对SVM的理解也只是照猫画虎,当训练数据是高维情况时,也完全不知所措,以后要是需要在这方面有深入研究的话,在仔细考虑吧!

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
【转】 利用Open CV和SVM实现问题识别
ML之SVM:基于Js代码利用SVM算法的实现根据Kaggle数据集预测泰坦尼克号生存人员
MATLAB:多个旅行商问题MTSP算法 (1)
OpenCV学习笔记(四十)——再谈OpenCV数据结构Mat详解
SVM:从理论到OpenCV实践
LibSVM for Python 使用
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服