实现代码:import structfrom numpy import *import numpy as npimport timedef read_image(file_name):    #先用二进制方式把文件都读进来    file_handle=open(file_name,"rb")  #以二进制打开文档    file_content=file_handle.read()   #读取到缓冲区中    offset=0    head = struct.unpack_from('>IIII', file_content, offset)  # 取前4个整数,返回一个元组    offset += struct.calcsize('>IIII')    imgNum = head[1]  #图片数    rows = head[2]   #宽度    cols = head[3]  #高度

    images=np.empty((imgNum , 784))#empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法    image_size=rows*cols#单个图片的大小    fmt='>' + str(image_size) + 'B'#单个图片的format

    for i in range(imgNum):        images[i] = np.array(struct.unpack_from(fmt, file_content, offset))        # images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))        offset += struct.calcsize(fmt)    return images

#读取标签def read_label(file_name):    file_handle = open(file_name, "rb")  # 以二进制打开文档    file_content = file_handle.read()  # 读取到缓冲区中

    head = struct.unpack_from('>II', file_content, 0)  # 取前2个整数,返回一个元组    offset = struct.calcsize('>II')

    labelNum = head[1]  # label数    # print(labelNum)    bitsString = '>' + str(labelNum) + 'B'  # fmt格式:'>47040000B'    label = struct.unpack_from(bitsString, file_content, offset)  # 取data数据,返回一个元组    return np.array(label)

def loadDataSet():    #mnist    train_x_filename="train-images-idx3-ubyte"    train_y_filename="train-labels-idx1-ubyte"    test_x_filename="t10k-images-idx3-ubyte"    test_y_filename="t10k-labels-idx1-ubyte"

    # #fashion mnist    # train_x_filename="fashion-train-images-idx3-ubyte"    # train_y_filename="fashion-train-labels-idx1-ubyte"    # test_x_filename="fashion-t10k-images-idx3-ubyte"    # test_y_filename="fashion-t10k-labels-idx1-ubyte"

    train_x=read_image(train_x_filename)#60000*784 的矩阵    train_y=read_label(train_y_filename)#60000*1的矩阵    test_x=read_image(test_x_filename)#10000*784    test_y=read_label(test_y_filename)#10000*1

    train_x=normalize(train_x)    test_x=normalize(test_x)    # #调试的时候让速度快点,就先减少数据集大小    # train_x=train_x[0:1000,:]    # train_y=train_y[0:1000]    # test_x=test_x[0:500,:]    # test_y=test_y[0:500]

    return train_x, test_x, train_y, test_y

def  normalize(data):#图片像素二值化,变成0-1分布    m=data.shape[0]    n=np.array(data).shape[1]    for i in range(m):        for j in range(n):            if data[i,j]!=0:                data[i,j]=1            else:                data[i,j]=0    return data

#(1)计算先验概率及条件概率def train_model(train_x,train_y,classNum):#classNum是指有10个类别,这里的train_x是已经二值化,    m=train_x.shape[0]    n=train_x.shape[1]    # prior_probability=np.zeros(n)#先验概率    prior_probability=np.zeros(classNum)#先验概率    conditional_probability=np.zeros((classNum,n,2))#条件概率    #计算先验概率和条件概率    for i in range(m):#m是图片数量,共60000张        img=train_x[i]#img是第i个图片,是1*n的行向量        label=train_y[i]#label是第i个图片对应的label        prior_probability[label]+=1#统计label类的label数量(p(Y=ck),下标用来存放label,prior_probability[label]除以n就是某个类的先验概率        for j in range(n):#n是特征数,共784个            temp=img[j].astype(int)#img[j]是0.0,放到下标去会显示错误,只能用整数

            conditional_probability[label][j][temp] += 1

            # conditional_probability[label][j][img[j]]+=1#统计的是类为label的,在每个列中为1或者0的行数为多少,img[j]的值要么就是0要么就是1,计算条件概率

    #将概率归到[1.10001]    for i in range(classNum):        for j in range(n):            #经过二值化的图像只有0,1两种取值            pix_0=conditional_probability[i][j][0]            pix_1=conditional_probability[i][j][1]

            #计算0,1像素点对应的条件概率            probability_0=(float(pix_0)/float(pix_0+pix_1))*10000+1            probability_1 = (float(pix_1)/float(pix_0 + pix_1)) * 10000 + 1

            conditional_probability[i][j][0]=probability_0            conditional_probability[i][j][1]=probability_1    return prior_probability,conditional_probability

#(2)对给定的x,计算先验概率和条件概率的乘积def cal_probability(img,label,prior_probability,conditional_probability):    probability=int(prior_probability[label])#先验概率    n=img.shape[0]    # print(n)    for i in range(n):#应该是特征数        probability*=int(conditional_probability[label][i][img[i].astype(int)])

    return probability

#确定实例x的类,相当于argmaxdef predict(test_x,test_y,prior_probability,conditional_probability):#传进来的test_x或者是train_x都是二值化后的    predict_y=[]    m=test_x.shape[0]    n=test_x.shape[1]    for i in range(m):        img=np.array(test_x[i])#img已经是二值化以后的列向量        label=test_y[i]        max_label=0        max_probability= cal_probability(img,0,prior_probability,conditional_probability)        for j in range(1,10):#从下标为1开始,因为初始值是下标为0            probability=cal_probability(img,j,prior_probability,conditional_probability)            if max_probability<probability:                max_probability=probability                max_label=j        predict_y.append(max_label)#用来记录每行最大概率的label    return np.array(predict_y)

def cal_accuracy(test_y,predict_y):    m=test_y.shape[0]    errorCount=0.0    for i in range(m):        if test_y[i]!=predict_y[i]:            errorCount+=1    accuracy=1.0-float(errorCount)/m    return accuracy

if __name__=='__main__':    classNum=10    print("Start reading data...")    time1=time.time()    train_x, test_x, train_y, test_y=loadDataSet()    train_x=normalize(train_x)    test_x=normalize(test_x)

    time2=time.time()    print("read data cost",time2-time1,"second")

    print("start training data...")    prior_probability, conditional_probability=train_model(train_x,train_y,classNum)    for i in range(classNum):        print(prior_probability[i])#输出一下每个标签的总共数量    time3=time.time()    print("train data cost",time3-time2,"second")

    print("start predicting data...")    predict_y=predict(test_x,test_y,prior_probability,conditional_probability)    time4=time.time()    print("predict data cost",time4-time3,"second")

    print("start calculate accuracy...")    acc=cal_accuracy(test_y,predict_y)    time5=time.time()    print("accuarcy",acc)    print("calculate accuarcy cost",time5-time4,"second")结果截图:输出的5923.0.。。这些是我输出一下每个类别的图片有几张。

调用自己写的朴素贝叶斯函数正确率是84.12%,调用sklearn中的BernoulliNB函数,正确率是84.27%

调用sklearn中的BernoulliNB函数的代码如下:

结果截屏:

优化:加入主成分分析方法,进行降维操作,代码如下:

结果截屏:

待修改中!

参考链接;https://blog.csdn.net/wds2006sdo/article/details/51967839

转载于:https://www.cnblogs.com/BlueBlue-Sky/p/9382577.html

python朴素贝叶斯分类MNIST数据集相关推荐

  1. python朴素贝叶斯分类器实现_用scikit-learn实现朴素贝叶斯分类器

    朴素贝叶斯(Naive Bayes Classifier)是一种「天真」的算法(假定所有特征发生概率是独立的),同时也是一种简单有效的常用分类算法.关于它的原理,参见朴素贝叶斯分类器的应用. scik ...

  2. python朴素贝叶斯分类示例_Python实现的朴素贝叶斯分类器示例

    本文实例讲述了Python实现的朴素贝叶斯分类器.分享给大家供大家参考,具体如下: 因工作中需要,自己写了一个朴素贝叶斯分类器. 对于未出现的属性,采取了拉普拉斯平滑,避免未出现的属性的概率为零导致整 ...

  3. [转载] 朴素贝叶斯python实现预测_Python实现朴素贝叶斯分类器的方法详解

    参考链接: Python朴素贝叶斯分类器 本文实例讲述了Python实现朴素贝叶斯分类器的方法.分享给大家供大家参考,具体如下: 贝叶斯定理 贝叶斯定理是通过对观测值概率分布的主观判断(即先验概率)进 ...

  4. 朴素贝叶斯分类器之分类实操

    python 朴素贝叶斯分类器之分类实操 基本概念 鲁棒性 Huber从稳健统计的角度系统地给出了鲁棒性3个层面的概念: 1.是模型具有较高的精度或有效性,这也是对于机器学习中所有学习模型的基本要求: ...

  5. 多项式朴素贝叶斯分类器_多项式朴素贝叶斯分类器的主题预测

    多项式朴素贝叶斯分类器 In Analytics Vidhya, Hackathon, there was a problem statement for text prediction of top ...

  6. 基于朴素贝叶斯分类器的西瓜数据集 2.0 预测分类_第十章:利用Python实现朴素贝叶斯模型

    免责声明:本文是通过网络收集并结合自身学习等途径合法获取,仅作为学习交流使用,其版权归出版社或者原创作者所有,并不对涉及的版权问题负责.若原创作者或者出版社认为侵权,请联系及时联系,我将立即删除文章, ...

  7. 基于jupyter notebook的python编程-----MNIST数据集的的定义及相关处理学习

    基于jupyter notebook的python编程-----MNIST数据集的相关处理 一.MNIST定义 1.什么是MNIST数据集 2.python如何导入MNIST数据集并操作 3.接下来, ...

  8. 朴素贝叶斯分类器原理解析与python实现

    贝叶斯分类器是以贝叶斯原理为基础的分类器的总称,是一种生成式模型,朴素贝叶斯分类器是其中最简单的一种.要高明白贝叶斯分类器的原理,首先得明白一些基本概念. 预备知识 基本概念 先验概率:根据统计/经验 ...

  9. 朴素贝叶斯python实现预测_Python实现朴素贝叶斯分类器的方法详解

    本文实例讲述了Python实现朴素贝叶斯分类器的方法.分享给大家供大家参考,具体如下: 贝叶斯定理 贝叶斯定理是通过对观测值概率分布的主观判断(即先验概率)进行修正的定理,在概率论中具有重要地位. 先 ...

最新文章

  1. k8s入门系列之介绍篇
  2. 还找不到想要的文章吗?公众号搜索方法大全
  3. spoj Simple Average
  4. php单词出现频率,PHP编程计算文件或数组中单词出现频率的方法
  5. 为了OFFER系列 | 牛客网美团点评数据分析刷题
  6. X509Certificate2 本地正常,放到线上内部错误
  7. JS+CSS实现Dock menu(MacOS菜单导航效果)
  8. Android GsmCellLocation.getCellLocation返回NULL
  9. 阿里巴巴连接池mysql_阿里巴巴连接池(Druid)
  10. 影响英语单词拼写的6大因素
  11. 用一段CSS代码找回属于童年的哆啦A梦欢度六一附源码在线展示
  12. 面试官说:Spring这几个问题你回答下,月薪3万,下周来上班!
  13. 文件锁(二)——文件锁的读锁和写锁
  14. Word插入Latex公式的几种方式~(TeXsWord、EqualX、Aurora、向Office插入LaTeX公式的工具)...
  15. 定西稳定高速的服务器,中国移动宽带甘肃定西的dns服务器地址
  16. 菲尔普斯:200自决赛会很艰苦 满意预赛成绩
  17. html5 bdi 不起作用,html bdi标签的使用详解
  18. 第三人称和第一人称互相切换【Low版】
  19. 第二节 数据CRUD操作与连接查询和子查询(包含练习)
  20. passing 'unsigned char [150]' to parameter of type 'char *' converts between pointers to integer typ

热门文章

  1. 没人说得清深度学习的原理 只是把它当作一个黑箱来使
  2. FreeTextBox 3.1.6 的实践总结和几个问题
  3. 深度学习《Photo Editing》
  4. 调参方法论:如何提高机器学习模型的性能?
  5. Syslink Control使用技巧
  6. IDEA远程部署调试Java应用程序
  7. Java7并发编程指南——第三章:线程同步辅助类
  8. 725 - Division
  9. QT Basic 014 Model/View programming (模型、视图编程)
  10. python进程池一个进程卡住_python进程池,每个进程都有超时