[Pattern Recognition] The K-Nearest Neighbors (KNN) Classification Algorithm


K-Nearest Neighbors (KNN) is an easy-to-understand classification algorithm. In short, it finds the K training samples closest to the point being classified, then checks which class appears most often among those K samples; the point is then assigned to that class.

Steps of the KNN Algorithm

  • Compute the distance between the current point and every point in the dataset of known classes;
  • Select the K points with the smallest distance to the current point;
  • Count how often each class appears among those K points;
  • Return the most frequent class among those K points as the predicted class of the current point (a from-scratch sketch of these steps follows below).
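To make the steps concrete, here is a minimal from-scratch sketch in C++. It is not part of the original post; the Sample struct and the classifyKNN function are illustrative names:

    #include <algorithm>
    #include <cstdio>
    #include <map>
    #include <utility>
    #include <vector>

    // A labeled 2-D training sample (illustrative helper, not an OpenCV type)
    struct Sample { float x, y; int label; };

    // Classify the query point (qx, qy) by majority vote among its k nearest training samples
    int classifyKNN(const std::vector<Sample>& train, float qx, float qy, int k)
    {
        // Step 1: compute the distance from every training point to the query point
        std::vector<std::pair<float,int>> distLabel;        // (squared distance, label)
        for (const Sample& s : train) {
            float dx = s.x - qx, dy = s.y - qy;
            distLabel.push_back({dx*dx + dy*dy, s.label});  // squared Euclidean distance
        }
        // Step 2: select the k points with the smallest distance
        std::partial_sort(distLabel.begin(), distLabel.begin() + k, distLabel.end());
        // Step 3: count how often each class appears among those k points
        std::map<int,int> votes;
        for (int i = 0; i < k; i++)
            votes[distLabel[i].second]++;
        // Step 4: return the most frequent class
        int best = -1, bestCount = 0;
        for (const auto& v : votes)
            if (v.second > bestCount) { best = v.first; bestCount = v.second; }
        return best;
    }

    int main()
    {
        std::vector<Sample> train = {{1,1,0}, {2,1,0}, {1,2,0}, {8,8,1}, {9,8,1}, {8,9,1}};
        std::printf("%d\n", classifyKNN(train, 2, 2, 3)); // prints 0
        std::printf("%d\n", classifyKNN(train, 8, 8, 3)); // prints 1
    }

Note that the squared Euclidean distance is enough here: skipping the square root does not change the ordering of the neighbors.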

Using CvKNearest in OpenCV

OpenCV's CvKNearest class implements simple KNN training and prediction.

    #include <opencv2/opencv.hpp>
    #include <iostream>
    #include <cstdlib>
    #include <ctime>

    using namespace cv;
    using namespace std;

    int main()
    {
        // Ten training samples: the first five belong to class 0, the last five to class 1
        float labels[10] = {0,0,0,0,0,1,1,1,1,1};
        Mat labelsMat(10, 1, CV_32FC1, labels);
        cout << labelsMat << endl;

        // Random 2-D features: class 0 falls in [1,255]^2, class 1 in [255,509]^2
        float trainingData[10][2];
        srand(time(0));
        for (int i = 0; i < 5; i++) {
            trainingData[i][0] = rand() % 255 + 1;
            trainingData[i][1] = rand() % 255 + 1;
            trainingData[i+5][0] = rand() % 255 + 255;
            trainingData[i+5][1] = rand() % 255 + 255;
        }
        Mat trainingDataMat(10, 2, CV_32FC1, trainingData);
        cout << trainingDataMat << endl;

        // Train the classifier (max_k = 2)
        CvKNearest knn;
        knn.train(trainingDataMat, labelsMat, Mat(), false, 2);

        // Data for visual representation: color every pixel by its predicted class
        int width = 512, height = 512;
        Mat image = Mat::zeros(height, width, CV_8UC3);
        Vec3b green(0,255,0), blue(255,0,0);
        for (int i = 0; i < image.rows; ++i) {
            for (int j = 0; j < image.cols; ++j) {
                Mat sampleMat = (Mat_<float>(1,2) << j, i); // feature vector is (x, y)
                float result = knn.find_nearest(sampleMat, 1);
                if (result != 0)
                    image.at<Vec3b>(i, j) = green;
                else
                    image.at<Vec3b>(i, j) = blue;
            }
        }

        // Show the training data: class 0 in black, class 1 in white
        for (int i = 0; i < 5; i++) {
            circle(image, Point(trainingData[i][0], trainingData[i][1]),
                   5, Scalar(0, 0, 0), -1, 8);
            circle(image, Point(trainingData[i+5][0], trainingData[i+5][1]),
                   5, Scalar(255, 255, 255), -1, 8);
        }
        imshow("KNN Simple Example", image); // show it to the user
        waitKey(10000);
        return 0;
    }

This reuses the example from the earlier BP neural network post; the classification result is shown below:

[Figure: classification result of the KNN example]

Besides the sample input, the prediction function find_nearest() takes several other parameters:

    float CvKNearest::find_nearest(const Mat& samples, int k, Mat* results=0,
        const float** neighbors=0, Mat* neighborResponses=0, Mat* dist=0)


That is: samples is a (number of samples) × (number of features) floating-point matrix; k is the number of nearest neighbors to look for; results receives the predicted result for each sample; neighbors is a k × (number of samples) array of pointers to the neighbors (the input is const, and it is not clear why it was designed this way); neighborResponses is a (number of samples) × k matrix holding the labels of each sample's k nearest neighbors; dist is a (number of samples) × k matrix of the distances to each sample's k nearest neighbors.
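To illustrate those optional outputs, here is a minimal self-contained sketch (assuming the legacy OpenCV 2.x C++ API; the toy data is made up for illustration):

    #include <opencv2/opencv.hpp>
    #include <iostream>
    using namespace cv;
    using namespace std;

    int main()
    {
        // Four training samples, two per class
        Mat train  = (Mat_<float>(4, 2) << 0, 0, 10, 10, 400, 400, 410, 410);
        Mat labels = (Mat_<float>(4, 1) << 0, 0, 1, 1);
        CvKNearest knn(train, labels, Mat(), false, 2);  // max_k = 2

        int k = 2;
        Mat samples = (Mat_<float>(2, 2) << 5, 5, 405, 405); // two query points
        Mat results(2, 1, CV_32FC1);            // predicted label per sample
        Mat neighborResponses(2, k, CV_32FC1);  // labels of each sample's k neighbors
        Mat dists(2, k, CV_32FC1);              // distances to each sample's k neighbors

        knn.find_nearest(samples, k, &results, 0, &neighborResponses, &dists);

        cout << "predictions:        " << results << endl;  // expected: 0 and 1
        cout << "neighbor labels:    " << neighborResponses << endl;
        cout << "neighbor distances: " << dists << endl;
        return 0;
    }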

Another Example

The OpenCV reference manual also provides a similar sample, using CvMat-style inputs:

    #include <opencv2/opencv.hpp>  // OpenCV 2.x: pulls in the C API and CvKNearest

    int main( int argc, char** argv )
    {
        const int K = 10;
        int i, j, k, accuracy;
        float response;
        int train_sample_count = 100;
        CvRNG rng_state = cvRNG(-1);
        CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
        CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
        IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
        float _sample[2];
        CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );
        cvZero( img );

        CvMat trainData1, trainData2, trainClasses1, trainClasses2;

        // Form the training samples: two Gaussian clouds centered at (200,200) and (300,300)
        cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
        cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );
        cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
        cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );

        // Label the first half class 1, the second half class 2
        cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
        cvSet( &trainClasses1, cvScalar(1) );
        cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
        cvSet( &trainClasses2, cvScalar(2) );

        // Learn classifier
        CvKNearest knn( trainData, trainClasses, 0, false, K );
        CvMat* nearests = cvCreateMat( 1, K, CV_32FC1 );

        for( i = 0; i < img->height; i++ )
        {
            for( j = 0; j < img->width; j++ )
            {
                sample.data.fl[0] = (float)j;
                sample.data.fl[1] = (float)i;

                // estimate the response and get the neighbors' labels
                response = knn.find_nearest( &sample, K, 0, 0, nearests, 0 );

                // compute the number of neighbors representing the majority
                for( k = 0, accuracy = 0; k < K; k++ )
                {
                    if( nearests->data.fl[k] == response )
                        accuracy++;
                }

                // highlight the pixel depending on the accuracy (or confidence)
                cvSet2D( img, i, j, response == 1 ?
                    (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
                    (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
            }
        }

        // display the original training samples
        for( i = 0; i < train_sample_count/2; i++ )
        {
            CvPoint pt;
            pt.x = cvRound(trainData1.data.fl[i*2]);
            pt.y = cvRound(trainData1.data.fl[i*2+1]);
            cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
            pt.x = cvRound(trainData2.data.fl[i*2]);
            pt.y = cvRound(trainData2.data.fl[i*2+1]);
            cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
        }

        cvNamedWindow( "classifier result", 1 );
        cvShowImage( "classifier result", img );
        cvWaitKey(0);

        cvReleaseMat( &nearests );
        cvReleaseMat( &trainClasses );
        cvReleaseMat( &trainData );
        cvReleaseImage( &img );
        return 0;
    }

Classification result:

[Figure: classification result of the refman example]

The idea behind KNN is easy to understand and very easy to implement, it achieves relatively high classification accuracy, and it is insensitive to outliers. However, its computational cost is high (each prediction requires computing the distance to every training sample), so it is poorly suited to large-scale classification problems.
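One caveat for readers on newer versions: CvKNearest belongs to OpenCV's legacy API and was removed in OpenCV 3, where the cv::ml::KNearest interface replaces it. A rough sketch (not from the original post) of the same train/predict flow with the modern API:

    #include <opencv2/opencv.hpp>
    #include <opencv2/ml.hpp>
    #include <iostream>
    using namespace cv;

    int main()
    {
        // Toy data in the spirit of the first example: class 0 near the origin, class 1 farther out
        Mat trainingDataMat = (Mat_<float>(4, 2) << 10, 10, 20, 20, 400, 400, 410, 410);
        Mat labelsMat = (Mat_<int>(4, 1) << 0, 0, 1, 1);

        Ptr<ml::KNearest> knn = ml::KNearest::create();
        knn->setDefaultK(3);
        knn->train(trainingDataMat, ml::ROW_SAMPLE, labelsMat);

        Mat sample = (Mat_<float>(1, 2) << 15, 15);
        Mat results;
        knn->findNearest(sample, 3, results);  // results holds the predicted label
        std::cout << results << std::endl;     // expected: class 0
        return 0;
    }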

(When reposting, please credit the author and source: http://blog.csdn.net/xiaowei_cqu. Commercial use is not permitted without permission.)
