k-means算法又称k-均值算法,是机器学习聚类算法中的一种,是一种基于形心的划分方法,其中每个簇的中心都用簇中所有对象的均值来表示。其思想如下:

输入:

  • k:簇的数目;
  • D:包含n个对象的数据集。

输出:k个簇的集合。

方法:

  1. 从D中随机选择几个对象作为起始质心;
  2. 对每个质心,计算每个数据到各个质心的距离,并把这些点分配到离该质心最短的距离的簇;
  3. 对每个簇,计算簇中所有点的均值并将此均值作为新的质心;
  4. 将数据点按照新的中心重新聚类;
  5. 重复【步骤3】,直到质心不再发生变化(新的质心和原来的质心相等);
  6. 输出聚类结果。

算法实现:

木羊的k-means算法实现包括5各类。其中,DBConnection.java用于连接数据库,SelectData.java用于从数据库里读取数据,Point.java存放点对象模型,ManagePoint.java是对点的操作,Kmeans.java是算法的核心思想及主函数入口。以下分别给出各个类的详细代码:

DBConnection.java

数据集获取,在机器学习数据集获取官方网站UCI中点击打开链接,木羊已经把该数据集从txt文档中插入到数据库,并去除了最后一列(花类别)。读者若不熟悉数据库的读写,请百度。若木羊有时间,会在后面的博文中补充把txt文档内容读到数据库中的内容。

<span style="font-size:18px;">package db;import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;/*** * 数据库连接类* */
public class DBConnection {public static final String driver = "com.mysql.jdbc.Driver";public static final String url = "jdbc:mysql://localhost:3306/mydb";public static final String user = "root";public static final String pwd = "123";public static Connection dBConnection() {Connection con = null;try {// 加载mysql驱动器Class.forName(driver);// 建立数据库连接con = DriverManager.getConnection(url, user, pwd);} catch (ClassNotFoundException e) {// TODO Auto-generated catch blockSystem.out.println("加载驱动器失败");e.printStackTrace();} catch (SQLException e) {// TODO Auto-generated catch blockSystem.out.println("注册驱动器失败");e.printStackTrace();}return con;}
}</span>

数据库中的数据字段如下(共有150条数据):

SelectData.java

package dao;import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;import model.Point;
import db.DBConnection;/*** * 取出数据* * @return pointList* */
public class SelectData {public static final String SELECT = "select* from iris_Kmeans";public ArrayList<Point> getPoints() throws SQLException {ArrayList<Point> pointsList = new ArrayList<Point>();Connection con = DBConnection.dBConnection();ResultSet rs;// 创建一个PreparedStatement对象PreparedStatement pstmt = con.prepareStatement(SELECT);rs = pstmt.executeQuery();while (rs.next()) {Point point = new Point();point.setX(rs.getDouble(2));point.setY(rs.getDouble(3));point.setZ(rs.getDouble(4));point.setW(rs.getDouble(5));pointsList.add(point);}System.out.println("数据集: " + pointsList);pstmt.close();rs.close();con.close();return pointsList;}
}

Point.java

此处要注意重写equal和hashcode方法以便后面质心的比较。

package model;public class Point {private double x;private double y;private double z;private double w;public double getX() {return x;}public void setX(double x) {this.x = x;}public double getY() {return y;}public void setY(double y) {this.y = y;}public double getZ() {return z;}public void setZ(double z) {this.z = z;}public double getW() {return w;}public void setW(double w) {this.w = w;}public Point() {}public Point(double x, double y, double z, double w) {super();this.x = x;this.y = y;this.z = z;this.w = w;}@Overridepublic String toString() {return "Point [ x=" + x + ", y=" + y + ", z=" + z + ", w=" + w + "]";}@Overridepublic boolean equals(Object obj) {Point point = (Point) obj;if (this.getX() == point.getX() && this.getY() == point.getY()&& this.getZ() == point.getZ() && this.getW() == point.getW()) {return true;}return false;}@Overridepublic int hashCode() {return (int) (x + y + z + w);}
}

ManagePoint.java

该类包含了3个方法,分别用于计算两个点的欧氏距离,比较前后两个质心是否相同,更新质心。

package util;import java.util.ArrayList;
import java.util.Map;import model.Point;public class ManagePoint {/*** * 计算两点之间的距离* * @param p*            第一个点* @param q*            第二个点* @return distance* */public double getDistance(Point p, Point q) {double dx = p.getX() - q.getX();double dy = p.getY() - q.getY();double dz = p.getZ() - q.getZ();double dw = p.getW() - q.getW();double distance = Math.sqrt(dx * dx + dy * dy + dz * dz + dw * dw);return distance;}/*** 判断前后两个质心是否相同* * @param nowCenterCluster*            现在的质心* @param lastCenterCluster*            上一次的质心* @return boolean* */public boolean isEqual(Map<Point, ArrayList<Point>> lastCenterCluster,Map<Point, ArrayList<Point>> nowCenterCluster) {boolean contain = false;if (lastCenterCluster == null)return false;else {for (Point point : nowCenterCluster.keySet()) {contain = lastCenterCluster.containsKey(point);}if (contain)return true;}return false;}/*** * 计算新的质心* * @param value*            map中的值,存放簇中的所有点* @return point* */public Point getNewCenter(ArrayList<Point> value) {double sumX = 0, sumY = 0, sumZ = 0, sumW = 0;for (Point point : value) {sumX += point.getX();sumY += point.getY();sumZ += point.getZ();sumW += point.getW();}System.out.println("新的质心: (" + sumX / value.size() + "," + sumY/ value.size() + "," + sumZ / value.size() + "," + sumW/ value.size() + ")");Point point = new Point();point.setX(sumX / value.size());point.setY(sumY / value.size());point.setZ(sumZ / value.size());point.setW(sumW / value.size());return point;}
}

Kmeans.java

木羊把簇存在hashmap里,其中key存放该簇的质心,value存放该簇的所有点。特别注意的是,为了使最终聚类相对较理想,随机选择的三个初始质心应该在[0-50)、[50-100)、[100-150]三个区间内。

package util;import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;import model.Point;
import dao.SelectData;public class Kmeans {public Map<Point, ArrayList<Point>> executeKmeans(int k) {ArrayList<Point> dataList = new ArrayList<Point>();// 存放原始数据Map<Point, ArrayList<Point>> nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();// 当前质心及其簇内的点Map<Point, ArrayList<Point>> lastCenterClusterMap = null;// 上一个质心及其簇内的点try {dataList = new SelectData().getPoints();// 随机创建K个点作为起始质心Random rd = new Random();int[] initIndex = { 50, 50, 50 };int[] tempIndex = { 0, 50, 100 };System.out.println("起始质心下标: ");for (int i = 0; i < k; i++) {int index = rd.nextInt(initIndex[i]) + tempIndex[i];System.out.println("第" + (i + 1) + "个 : " + index);nowCenterClusterMap.put(dataList.get(index),new ArrayList<Point>());}// 输出起始质心System.out.println("起始质心: ");for (Point point : nowCenterClusterMap.keySet())System.out.println("key:  " + point);// 将数据点point加入配到离其最近的map的value中ManagePoint managePoint = new ManagePoint();while (true) {for (Point point : dataList) {double shortestDistance = Double.MAX_VALUE;// 初始化最短距离为Double的最大值Point key = null;for (Entry<Point, ArrayList<Point>> entry : nowCenterClusterMap.entrySet()) {// 计算质心与各点间的距离double distance = managePoint.getDistance(entry.getKey(), point);if (distance < shortestDistance) {shortestDistance = distance;key = entry.getKey();}}nowCenterClusterMap.get(key).add(point);}// 如果新的质心与上次的质心相等,则退出整个循环if (managePoint.isEqual(lastCenterClusterMap,nowCenterClusterMap)) {System.out.println("相等了。");break;}// 更新质心lastCenterClusterMap = nowCenterClusterMap;nowCenterClusterMap = new HashMap<Point, ArrayList<Point>>();System.out.println("------------------------------------------------------------------");for (Entry<Point, ArrayList<Point>> entry : lastCenterClusterMap.entrySet()) {nowCenterClusterMap.put(managePoint.getNewCenter(entry.getValue()),new ArrayList<Point>());}}} catch (SQLException e) {// TODO Auto-generated catch blockSystem.out.println("数据库操作失败");e.printStackTrace();}return nowCenterClusterMap;}public static void main(String[] args) {int K = 3;// 分为三个类Map<Point, ArrayList<Point>> result = new Kmeans().executeKmeans(K);// 输出分类System.out.println("===========聚类结果: ============");for (Entry<Point, ArrayList<Point>> entry : result.entrySet()) {System.out.println("\n" + "稳定的质心: " + entry.getKey());System.out.println("该簇的大小: " + entry.getValue().size());System.out.println("簇里的点:" + entry.getValue());}}
}

以上代码均从MyEclipse上复制粘贴而来,亲测可运行,结果如下:

经测试,无论初始质心被随机选择成哪3个,最终稳定的质心都不变。

(欢迎讨论。代码尚有不完善之处,请多多指教。转载请注明出处。)

java实现k-means算法(用的鸢尾花iris的数据集,从mysq数据库中读取数据)相关推荐

  1. java中unicode显示乱码_Java 已知Java系统编码是GBK,jtextarea从一编码为Unicode的文本中读取数据,出现乱码,怎么正常显示?...

    Java 已知Java系统编码是GBK,jtextarea从一编码为Unicode的文本中读取数据,出现乱码,怎么正常显示? 关注:159  答案:2  mip版 解决时间 2021-02-03 12 ...

  2. kmeans改进 matlab,基于距离函数的改进k―means 算法

    摘要:聚类算法在自然科学和和社会科学中都有很普遍的应用,而K-means算法是聚类算法中经典的划分方法之一.但如果数据集内相邻的簇之间离散度相差较大,或者是属性分布区间相差较大,则算法的聚类效果十分有 ...

  3. 算法2.2 已知线性表LA和LB中的数据元素按值非递减有序排列,现要求将LA和LB归并为一个新的线性表LC,且LC中的数据元素仍按值非递减有序排列。

    数据结构(C语言版)严蔚敏 吴伟民 算法2.2 已知线性表LA和LB中的数据元素按值非递减有序排列,现要求将LA和LB归并为一个新的线性表LC,且LC中的数据元素仍按值非递减有序排列.例如,设 LA= ...

  4. 用java向mysql数据库中插入数据为空

    利用java面向对像编程,向数据库中插入数据时.遇到插入的数据为空的情况.在此做一小结: 1.数据库连接正正常 2.sql语句没有问题 3.程序没有报异常 4.代码: import java.util ...

  5. Java实现Excel导入数据库,数据库中的数据导入到Excel

    前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家.点击跳转到教程. 实现的功能: Java实现Excel导入数据库,如果存在就更新 数据库中的数据导入到Excel 1. ...

  6. java将数据写入csv文件,从csv文件中读取数据

    全栈工程师开发手册 (作者:栾鹏) java教程全解 java将数据写入csv文件,从csv文件中读取数据 测试代码 public static void main(String[] arges){/ ...

  7. Java向数据库中插入数据出错时怎么避免插入错误数据

    Java向数据库中插入数据出错时怎么避免插入错误数据 对于初学者,向数据库写数据时,当程序输错,会有错误的数据写入了数据库,这是可以用捕获异常回滚的方法避免这种情况的发生 代码如下: /** 完成增删 ...

  8. java显示数据库_java查询数据库中的数据并显示

    java查询数据库中的数据并显示 关注:93  答案:2  mip版 解决时间 2021-01-17 16:29 提问者笑低了眉眼 2021-01-17 04:11 button.addSelecti ...

  9. mye连接mysql数据库_MySQL_如何在Java程序中访问mysql数据库中的数据并进行简单的操作,在上篇文章给大家介绍了Myeclip - phpStudy...

    如何在Java程序中访问mysql数据库中的数据并进行简单的操作 在上篇文章给大家介绍了Myeclipse连接mysql数据库的方法,通过本文给大家介绍如何在Java程序中访问mysql数据库中的数据 ...

最新文章

  1. 五分钟理解yield在python中的简单用法,让你不再迷惑
  2. 2021年春季学期-信号与系统-第一次作业参考答案-第五题
  3. python数字列表in_Python入门基础之数字字符串与列表
  4. 跟多导出数据库的方法
  5. MFC工作笔记0003---WindowsAPI与MFC的关系
  6. input子系统驱动学习之中的一个
  7. centos绑定多个ip CentOS一个网卡设置多个IP
  8. 视频教程-CCNA之TCP/IP协议栈精讲-思科认证
  9. 计算机控制技术毕业论文题目,计算机控制方面论文选题 计算机控制论文题目怎样定...
  10. 《黑客之道》- 全网最详细的kali系统安装教程
  11. 一文读懂 delete和delete[ ]
  12. Peoplesoft Pentest
  13. JavaScript学习第十九天
  14. CSAPP:第二章——信息的表示和处理
  15. 破解ESX主机ROOT帐户密码。
  16. 如何开启任务计划程序
  17. mac时间机器占用大量系统盘空间且在访达中无法找到
  18. 欧拉-伯努利梁横向振动2
  19. C语言常见问题——++i与i++详解
  20. 苹果icloud邮箱抓取

热门文章

  1. HDU4405 期望
  2. 如何在子网中访问上层网络的计算机文件夹
  3. SmartFoxServer学习总结(转载)
  4. 基于android平台的24点游戏设计与实现需求分析,基于Android平台的24点游戏设计与实现需求分析_毕业设计论文.doc...
  5. swagger内部类_API管理工具Swagger介绍及Springfox原理分析
  6. cx oracle 配置,cx_Oracle的配置啊。。终于搞出来了
  7. mysql触发器区分新增 修改_MySQL触发器 , 判断更新操作前后数据是否改变
  8. java 入参 是 枚举_java 枚举 参数传递
  9. oracle 快照用途,Oracle快照原理及实现总结
  10. error: storage size of ‘threads’ isn’t known