最近几天刚刚接触机器学习,学完K-Means聚类算法。正好又赶上一个课程项目是识别“手写数字”,因为KMeans能够实现聚类,因此自然而然地想要通过KMeans来实现。

前排提示:这是kmeans聚类的一个失败案例,没有成功聚类,仅供参考。

一,什么是KMeans聚类算法??

非常传统的聚类算法,目的是将一堆数据进行分类。

它的思想很朴素:假设这里有一群点,要将这些点分成两类。要是分成的类很合理的话,那不同类之间的中心点相聚是不是应该足够大,中心点附近的同一类的点是不是应该足够多?

举个例子:

a表示的是一堆原始点,没有处理。要将a聚类成两类,先随便找到两个点,计算所有点到这两个点的距离(欧式距离,曼哈顿距离,闵式距离等等都可以),根据距离最近的原则分配成两类。这时候是不是就能够得到两类的中心点,然后再次重复操作,直到最后聚出来的类不会发生变化。

so easy 是不是

二,使用的手写数字测试集??

我们在这里使用的是mnist测试集。这家伙的知名程度在机器学习中相当于是hello world了。不知道的小伙伴可以去查查。

但是一定有人会问到,mnist测试集应该怎么通过java使用呢?

不用担心,我用Python通过TensorFlow将mnist测试集打包成了txt文件,用java的文件操作直接调用就可以了。

具体效果像这样:

这是28 * 28的二维int数组,每个值介于0到255之间,熟悉图像处理的小伙伴一定知道这是灰度值,0表示最黑,255表示最亮,因此这是黑纸白字的测试集,大家要是自己写测试数据的使用要记着对图片进行预处理,要不然可能会出错。

我将txt命名为:数字名-标号的形式,方便之后训练和测试。

三,java手撕KMeans算法

先摆上一个算法流程图

1.首先定义:

           训练图片(50000 * 28 * 28 的三维数组)

           聚类中心(10 * 28 * 28的三维数组)

           每张图片到聚类中心的距离(50000 * 10 的二维数组)

           旧的类和新的类(ArrayList[] 数组,因为不知道一个类中到底会有多少个图片)

    static float[][][] num = new float[50000][28][28];static float[][][] center = new float[10][28][28];// 聚类中心static long[][] distance = new long[num.length][10];static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类static ArrayList<Integer>[] newKinds = new ArrayList[10];

2.定义方法:

        从Txt文件导入测试数据的方法

public static void getTXT(String path,int img,int x,int y) throws IOException {File file = new File(path);FileInputStream fis = new FileInputStream(file);InputStreamReader isr = new InputStreamReader(fis);BufferedReader br = new BufferedReader(isr);String line;while((line = br.readLine()) != null){boolean isNum = false;for(int i = 0;i < line.length();i ++){if(line.charAt(i) != ' ' && !isNum){// 如果遇到数字isNum = true;float tempNum = 0;// 取数字while(i < line.length() && line.charAt(i) != ' '){tempNum = tempNum * 10 + line.charAt(i) - '0';i++;}isNum = false;if(y < 28){}else{y = 0;x ++;}num[img][x][y] = tempNum;y++;}}}br.close();}

        获得图片到聚类中心距离的方法

    // 得到距离public static long getDistance(float[][] n,float[][] k){long ret = 0;for (int i = 0; i < 28;i ++){for (int j = 0; j < 28; j ++){ret += Math.pow((n[i][j] - k[i][j]),2);}}return ret;}

        得到图片距离最近聚类中心索引的方法

    // 获得数组元素最小值对应的下标public static int getMinIndex(long dis[]){int index = -1;long min = Integer.MAX_VALUE;for(int i = 0; i < 10;i ++){if(dis[i] < min){index = i;min = dis[i];}}return index;}

        比较旧的聚类和新的聚类是否相同的方法

    public static boolean isSame(){for(int i = 0; i < 10 ;i ++){for(int j = 0; j < newKinds[i].size();j ++){if(newKinds[i].size() != oldKinds[i].size()) return false;if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {return false;}}}return true;}

需要注意的是!!!

两个Integer的比较需要通过.intValue()的方法先转换成为int!!!再进行比较,否则会因为内存什么什么奇奇怪怪的原因导致出现130 != 130这种很天真的错误。

我在这里被坑了一次,希望看到这片文章的人能够避一下坑。

3.开始while(true)死循环,直到旧类和新类相等不发生改变

        int kindTime = 0;while(true){// 3.计算每个文件和当前类中心之间的距离for (int i = 0; i < num.length; i++){for (int j = 0; j < 10; j++){distance[i][j] = getDistance(num[i],center[j]);}}// 更新旧类for(int i = 0;i < 10;i ++){oldKinds[i].clear();for(int j = 0 ; j < newKinds[i].size();j ++){oldKinds[i].add(newKinds[i].get(j));}}// 更新新类for (int i = 0; i < 10 ; i ++){newKinds[i].clear();}for (int i = 0; i < num.length; i ++){// 获得距离最小值,将其放到对应的类中newKinds[getMinIndex(distance[i])].add(i);}// 4.更新聚类中心for(int i = 0; i < 10; i ++){for(int x = 0; x < 28; x++){for(int y = 0; y < 28;y ++){center[i][x][y] = getAverage(newKinds[i],x,y);}}}// 5.重复步骤,直到类不再发生改变if(isSame()){break;}System.out.println("第"+kindTime+"次聚类");kindTime++;}

4.保存类中心点

因为如果训练数据不变的话,聚类聚出的中心是不会变化的,所以为了避免之后聚类的重复操作,我们还是将得到的聚类中心点保存成为txt文件放到电脑上比较好。

    // 保存聚类中心点public static void saveKind(int index){FileWriter out = null;String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";File file = new File(path);try {out = new FileWriter(file);//二维数组按行存入到文件中for (int i = 0; i < center[index].length; i++) {for (int j = 0; j < center[index][i].length; j++) {//将每个元素转换为字符串String content = String.valueOf(center[index][i][j]) + " ";out.write(content + "\t");}out.write("\r\n");}out.close();} catch (IOException e) {e.printStackTrace();}}

到现在,所有kmeans要求的操作我们都已经实现了。我们看看效果怎么样吧

1.我从test测试集(刚刚是train训练集)中导入了8000张图片,0到9每个数字各800张。

导入的方式和上文中的相同,这里就不在赘述了。

然后通过刚刚聚出来的类中心对测试数据进行聚类。(因为kmeans是无监督聚类吗,所以我也不知道每个类中心代表的哪个数字)

这是最后聚出来的结果:

发现大问题!!!我将每个类聚到的数字分别列出来。比如第0类,聚到4个数字0,3个数字1……

最后得到的结果,很!不!理!想!

通过分析可以看到,数字1的聚类效果最好,800张图片中有787张被聚到第7类中了,但是第7类也混入了不少其他数字,还有129张2是什么鬼?!

其他的类就更不用说了,混杂了很多数字。

经过缜密思考之后,我认为是k的数值设置的问题,因为我们想要聚类出10个数字,所以很主观地将k设置成为了10,没有思考相同数字,因为书写原因而出现的数字内部聚类的问题。

就像数字0,分别被聚到了第1类和第4类中,这两类很少有其他数字。因此是将数字0进行了分类,把高的0矮的0胖的0瘦的0分开了!而不是将0之外的数字分开。

或许可以通过改变k的值进行改进呢!

这片文章才差不多就是这样了。最后贴上代码。

如果有朋友想要mnist手写数字数据集的txt文件,可以给我留言邮箱信息哦,我抽时间会发送的。

欢迎大佬们批评指正!

// 首先是kmeans聚类的代码
import java.io.*;
import java.util.ArrayList;public class KMeans {// KMeans算法实现手写数字聚类static float[][][] num = new float[50000][28][28];static float[][][] center = new float[10][28][28];// 聚类中心static long[][] distance = new long[num.length][10];static ArrayList<Integer>[] oldKinds = new ArrayList[10];// 旧的聚类static ArrayList<Integer>[] newKinds = new ArrayList[10];public static void main(String[] args) throws IOException {// 1.读取文件System.out.println("导入文件中……");for (int i = 0;i < num.length;i ++){getTXT("D:\\Python\\jupyter\\trains2\\" + Integer.toString(i/5000) + "-" + Integer.toString(i%5000 + 1) + ".txt",i,0,0);if(i % 1000 == 0) System.out.println("已导入文件:" + i);}System.out.println("导入文件成功!!!");// 随机选择聚类中心for(int i = 0; i < 10; i ++){oldKinds[i] = new ArrayList<>();}for(int i = 0 ; i < 10;i ++) {transTwoArray(num[i], center[i]);newKinds[i] = new ArrayList<>();newKinds[i].add(i);}int kindTime = 0;while(true){// 3.计算每个文件和当前类中心之间的距离for (int i = 0; i < num.length; i++){for (int j = 0; j < 10; j++){distance[i][j] = getDistance(num[i],center[j]);}}// 更新旧类for(int i = 0;i < 10;i ++){oldKinds[i].clear();for(int j = 0 ; j < newKinds[i].size();j ++){oldKinds[i].add(newKinds[i].get(j));}}// 更新新类for (int i = 0; i < 10 ; i ++){newKinds[i].clear();}for (int i = 0; i < num.length; i ++){// 获得距离最小值,将其放到对应的类中newKinds[getMinIndex(distance[i])].add(i);}// 4.更新聚类中心for(int i = 0; i < 10; i ++){for(int x = 0; x < 28; x++){for(int y = 0; y < 28;y ++){center[i][x][y] = getAverage(newKinds[i],x,y);}}}// 5.重复步骤,直到类不再发生改变if(isSame()){break;}System.out.println("第"+kindTime+"次聚类");kindTime++;}// 保存聚类中心System.out.println("聚类成功!!!");System.out.println("-------------------------");System.out.println("保存类中心点中……");for(int i = 0; i < 10;i ++){saveKind(i);}System.out.println("保存类中心点成功!!!");}// 读取文件public static void getTXT(String path,int img,int x,int y) throws IOException {File file = new File(path);FileInputStream fis = new FileInputStream(file);InputStreamReader isr = new InputStreamReader(fis);BufferedReader br = new BufferedReader(isr);String line;while((line = br.readLine()) != null){boolean isNum = false;for(int i = 0;i < line.length();i ++){if(line.charAt(i) != ' ' && !isNum){// 如果遇到数字isNum = true;float tempNum = 0;// 取数字while(i < line.length() && line.charAt(i) != ' '){tempNum = tempNum * 10 + line.charAt(i) - '0';i++;}isNum = false;if(y < 28){}else{y = 0;x ++;}num[img][x][y] = tempNum;y++;}}}br.close();}// 转移两个数组public static void transTwoArray(float[][] array1,float[][] array2){for(int i = 0; i < 28;i ++){for (int j = 0; j < 28;j ++){array2[i][j] = array1[i][j];}}}// 得到距离public static long getDistance(float[][] n,float[][] k){long ret = 0;for (int i = 0; i < 28;i ++){for (int j = 0; j < 28; j ++){ret += Math.pow((n[i][j] - k[i][j]),2);}}return ret;}// 获得数组元素最小值对应的下标public static int getMinIndex(long dis[]){int index = -1;long min = Integer.MAX_VALUE;for(int i = 0; i < 10;i ++){if(dis[i] < min){index = i;min = dis[i];}}return index;}// 计算均值public static float getAverage(ArrayList<Integer> arr,int x,int y){float ret = 0;for(int i = 0; i < arr.size(); i ++){ret += num[arr.get(i)][x][y];// 将同一类中所有相同位置元素相加}return ret / arr.size();}// 保存聚类中心点public static void saveKind(int index){FileWriter out = null;String path = "D:\\java\\workSpace\\KMeans\\" + index + "kinds.txt";File file = new File(path);try {out = new FileWriter(file);//二维数组按行存入到文件中for (int i = 0; i < center[index].length; i++) {for (int j = 0; j < center[index][i].length; j++) {//将每个元素转换为字符串String content = String.valueOf(center[index][i][j]) + " ";out.write(content + "\t");}out.write("\r\n");}out.close();} catch (IOException e) {e.printStackTrace();}}// 是否相等public static boolean isSame(){for(int i = 0; i < 10 ;i ++){for(int j = 0; j < newKinds[i].size();j ++){if(newKinds[i].size() != oldKinds[i].size()) return false;if (newKinds[i].get(j).intValue() != oldKinds[i].get(j).intValue() ) {return false;}}}return true;}
}

测试聚类中心的代码

import java.io.*;
import java.util.ArrayList;public class myKMeansTest {static float[][][] kMeans = new float[10][28][28];static float[][][] test = new float[8000][28][28];// 测试数据,每个数字有800张static long[][] distance = new long[8000][10];// 每张图片聚类类中心的距离static ArrayList<Integer>[] kinds = new ArrayList[10];// 每个类中包含的图片索引public static void main(String[] args) throws IOException {System.out.println("-----获取文件中-----");// 读取聚类中心文件for(int i = 0; i < 10;i ++){String img = "D:\\java\\workSpace\\KMeans\\" + i + "kinds.txt";getKMeansTxt(img,i);}// 读取测试文件for(int i = 0;i < 8000;i ++){String img = "D:\\Python\\jupyter\\test\\" + i/800 + "-" + (i%800 + 1) + ".txt";getTestTxt(img,i,0,0);if(i % 800 == 0) System.out.println("已导入数据:"+i);}System.out.println("获取文件成功!!");// 进行测试System.out.println("开始聚类……");for(int i = 0; i < 10;i ++){kinds[i] = new ArrayList<>();}for(int i = 0; i < 8000;i ++){for (int j = 0; j < 10;j ++){distance[i][j] = GoodKMeans.getDistance(kMeans[j],test[i]);// 获得每张图片对应聚类中心的距离}}for(int i= 0;i< 8000;i++){kinds[GoodKMeans.getMinIndex(distance[i])].add(i);// 将图片归为最小距离的类中}System.out.println("聚类成功!!");int[][] ans = new int[10][10];for(int i = 0; i < 10;i ++){for(int j = 0; j < kinds[i].size();j ++){if(kinds[i].get(j) < 800) ans[i][0]++;else if(kinds[i].get(j) >= 800 && kinds[i].get(j) < 1600) ans[i][1]++;else if(kinds[i].get(j) >= 1600 && kinds[i].get(j)< 2400) ans[i][2]++;else if(kinds[i].get(j) >= 2400 && kinds[i].get(j)< 3200) ans[i][3]++;else if(kinds[i].get(j) >= 3200 && kinds[i].get(j)< 4000) ans[i][4]++;else if(kinds[i].get(j) >= 4000 && kinds[i].get(j)< 4800) ans[i][5]++;else if(kinds[i].get(j) >= 4800 && kinds[i].get(j)< 5600) ans[i][6]++;else if(kinds[i].get(j) >= 5600 && kinds[i].get(j)< 6400) ans[i][7]++;else if(kinds[i].get(j) >= 6400 && kinds[i].get(j)< 7200) ans[i][8]++;else if(kinds[i].get(j) >= 7200 && kinds[i].get(j)< 8000) ans[i][9]++;}}for (int i = 0; i < 10;i ++){System.out.print("第"+i+"类中:");for (int j = 0; j < 10;j ++){System.out.print(j+":");System.out.printf("%3d",ans[i][j]);System.out.print("\t");}System.out.println();}}// 获得聚类中心文件public static void getKMeansTxt(String img,int index) throws IOException {File file = new File(img);FileInputStream fis = new FileInputStream(file);InputStreamReader isr = new InputStreamReader(fis);BufferedReader br = new BufferedReader(isr);int x = 0;int y = 0;String line;while((line = br.readLine()) != null){boolean isNum = false;for(int i = 0;i < line.length();i ++){if(line.charAt(i)-'0' <10 && line.charAt(i)-'0' >=0 && !isNum){// 如果遇到数字isNum = true;// 取数字int j = i + 1;while(j < line.length() && line.charAt(j) != ' '){j++;}isNum = false;if(y < 28){}else{y = 0;x ++;}kMeans[index][x][y] = Float.valueOf(line.substring(i,j)).floatValue();i = j;y++;}}}br.close();}// 获得测试文件public static void getTestTxt(String path,int img,int x,int y) throws IOException {File file = new File(path);FileInputStream fis = new FileInputStream(file);InputStreamReader isr = new InputStreamReader(fis);BufferedReader br = new BufferedReader(isr);String line;while((line = br.readLine()) != null){boolean isNum = false;for(int i = 0;i < line.length();i ++){if(line.charAt(i) != ' ' && !isNum){// 如果遇到数字isNum = true;float tempNum = 0;// 取数字while(i < line.length() && line.charAt(i) != ' '){tempNum = tempNum * 10 + line.charAt(i) - '0';i++;}isNum = false;if(y < 28){}else{y = 0;x ++;}test[img][x][y] = tempNum;y++;}}}br.close();}
}

java手撕KMeans算法实现手写数字聚类(失败案例)相关推荐

  1. [ 数据结构 -- 手撕排序算法第三篇 ] 希尔排序

    手撕排序算法系列之:希尔排序. 从本篇文章开始,我会介绍并分析常见的几种排序,大致包括插入排序,冒泡排序,希尔排序,选择排序,堆排序,快速排序,归并排序等. 大家可以点击此链接阅读其他排序算法:排序算 ...

  2. python手撕分水岭算法

    python手撕分水岭算法 1 分水岭算法实现 主要思路就是: 利用一个优先队列与有序队列(有序队列其实可以不用).优先队列是按像素的灰度值排列的,灰度值低的先被淹. 通过统计像素的附近的点的标记种类 ...

  3. [ 数据结构 -- 手撕排序算法第四篇 ] 选择排序

    手撕排序算法系列之第四篇:选择排序. 从本篇文章开始,我会介绍并分析常见的几种排序,大致包括直接插入排序,冒泡排序,希尔排序,选择排序,堆排序,快速排序,归并排序等. 大家可以点击此链接阅读其他排序算 ...

  4. [ 数据结构 -- 手撕排序算法第二篇 ] 冒泡排序

    手撕排序算法系列之:冒泡排序. 从本篇文章开始,我会介绍并分析常见的几种排序,大致包括插入排序,冒泡排序,希尔排序,选择排序,堆排序,快速排序,归并排序等. 大家可以点击此链接阅读其他排序算法:排序算 ...

  5. 手撕python_GitHub - caishiqing/manual: 手撕机器学习

    手撕机器学习 用腻了开源框架,尝试下手撕机器学习模型?写这个手撕机器学习系列,旨在不使用任何开源框架的条件下手推实现各种模型,同时保证高性能. Requirements 适用于python2.7与py ...

  6. 手撕包菜 mysql_手撕包菜搭建

    概述 最近做了两件事,一件事就是买了块1t硬盘,第二件事就是买了个百度云会员,无奈找不到资源下载,那就没办法了,搭建一个磁力链接搜索引擎来爬去链接,然后去找资源. 说道磁力链接搜索引擎,最好的当然是手 ...

  7. OpenCV4学习笔记(55)——基于KNN最近邻算法实现鼠标手写数字识别

    在上一篇博客<OpenCV4学习笔记(54)>中,整理了关于KNN最近邻算法的一些相关内容和一个手写体数字识别的例子.但是上次所实现的手写体数字识别,每次只能固定地输入测试图像进行预测,而 ...

  8. K-means 算法实现二维数据聚类

    所谓聚类分析,就是给定一个元素集合D,其中每个元素具有n个观测属性,对这些属性使用某种算法将D划分成K个子集,要求每个子集内部的元素之间相似度尽可能高,而不同子集的元素相似度尽可能低.聚类分析是一种无 ...

  9. matlab对手写数字聚类的方法_scikitlearn — 聚类

    可以使用模块sklearn.cluster对未标记的数据进行聚类.每个聚类算法都有两种变体:一个是类(class)实现的 fit方法来学习训练数据上的聚类:另一个是函数(function)实现,给定训 ...

最新文章

  1. 新配windows服务器及上边功能的试用体会
  2. linux常用命令-查看文本/cat,tac,more,less,head,tail
  3. hacker:Python通过对简单的WIFI弱口令实现自动实时破解
  4. java await signal_java Condtion await方法和signal方法解析
  5. windy数(BZOJ-1026)
  6. python中 [ 闭包 ] 小结
  7. 如何在excel 单元格中增加换行
  8. 机器学习-吴恩达-笔记-4-神经网络描述
  9. 《LeetCode刷题C/C++版答案》pdf出炉,白瞟党乐坏了
  10. 青岛理工大学QUT期末考试《电子商务概论》思维导图
  11. 程序员面试的注意事项(一):面试的流程
  12. 边界路由linux,路由表构成简介(Destination/Gateway/Genmask/Iface)
  13. erp系统云端服务器,erp系统软件云服务器
  14. FMI飞马网 |【线上直播】京东商城的通用代码质量提升方案
  15. learning java AWT Pannel
  16. namenode无法启动,There appears to be a gap in the edit log. We expected txid 10323, but got txid 10324.
  17. Spark随笔(三):straggler的产生原因
  18. Edge AI边缘智能:Communication-Efficient Edge AI: Algorithms and Systems(未完待续)
  19. python 视频加字幕_【小技巧】用Python给你的视频添加字幕
  20. 关于LeakCanary检测华为手机内存泄漏问题

热门文章

  1. 自由操控声音-相位声码器-变速篇(一)
  2. HX=JE,HX-JE芯片,无感4.9V升压ic电路图PDF应用技术
  3. main c语言中变量的定义,C语言中在main函数中定义的变量是全局变量么_后端开发...
  4. [2018-5-4]BNUZ你们还差得远呢
  5. 3399使用GPIO口模拟i2c升级NT68411
  6. 博通语法纠错技术方案入选ACL2022,论文详细解读
  7. 我们的征途是星辰大海 ( 蓝桥杯~算法提高 )
  8. 技术不是越来越简单,而是框架是你的羁绊
  9. 计算机主板会自动切断电源是怎么回事,电脑开机自动断电怎么办
  10. CSS: Animation CSS:动画 Lynda课程中文字幕