逻辑回归算法识别Minst手写集

  • 逻辑回归算法(LR)及其基本概念
    • 线性回归
    • 逻辑回归
    • “神奇的函数”
      • 预测函数
      • 损失函数
  • Minst手写集
  • 代码实现
    • 具体代码:
    • 运行结果
    • TIPS:
  • 总结

逻辑回归算法(LR)及其基本概念

线性回归

线性回归 是利用连续性变量来估计实际数值(例如房价,呼叫次数和总销售额等)。我们通过线性回归算法找出自变量和因变量间的最佳线性关系,图形上可以确定一条最佳直线。这条最佳直线就是回归线。这个回归关系可以用Y=aX+b 表示。
说人话就是通过已知的数据来拟合一条直线,使数据尽可能落在这条直线上,并且这条直线能对接下来的未知数据进行预测。
举个例子就是像高中统计学的最小二乘法求回归方程

逻辑回归

逻辑回归 其实是一个分类算法而不是回归算法。通常是利用已知的自变量来预测一个离散型因变量的值(像二进制值0/1,是/否,真/假)。简单来说,它就是通过拟合一个逻辑函数(logit fuction)来预测一个事件发生的概率。所以它预测的是一个概率值,自然,它的输出值应该在0到1之间。

说人话就是分类(大多数是分两类)。通过一个神奇的函数来预测接下来未知信息发生的概率。重点:逻辑回归的结果是一个概率

“神奇的函数”

预测函数

找一个预测函数,该函数就是我们需要找的分类函数,它用来预测输入数据的判断结果。这个过程时非常关键的,需要对数据有一定的了解或分析,知道或者猜测预测函数的“大概”形式,比如是线性函数还是非线性函数。

看不懂吧,看不懂就对了。我也看不懂。
它大概意思是说我们要去先确定一个可以分类的函数,拿Minst举例就是说找到一个函数可以把他分成10类。这篇文章我们用的是sigmoid函数

下面是函数表达式和函数图像

LR模型的主要任务是给定一些历史的{X,Y},其中X是样本n个特征值(自变量),Y的取值是{0,1}代表正例与负例,通过对这些历史样本的学习,从而得到一个数学模型,给定一个新的X,能够预测出Y。

从函数图像我们可以看出来,当特征值为0时,对应的概率是0.5.因此LR可以得到一个事件发生的可能性,超过50%则认为事件发生,低于50%则认为事件不发生。

举个例子假如说我们要是识别一个手写数字是不是数字5,结果只有两个, 还是 不是 如果sigmoid的概率大于0.5 那他就是5;反之就不是

损失函数

Minst手写集损失函数loss function)是用来估量模型的预测值f(x)与真实值Y的不一致程度,它是一个非负实值函数,通常使用L(Y, f(x))来表示,损失函数越小。

看不懂吧,看不懂就对了,我也看不懂

说人话就是:
损失函数定义是:衡量模型模型预测的好坏

可能这么说有点小小的抽象 ,那么再解释下,损失函数就是用来表现预测与实际数据的差距程度

比如你高中学的回归方程,真实值和你用回归方程求出来的预测值肯定会有误差,那么我们找到一个函数表达这个误差就是损失函数

因为Minst识别是多分类问题
这里我们用log对数损失函数

图片来自https://blog.csdn.net/u013069552/article/details/113804323?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522164931446116782094892055%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=164931446116782094892055&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2alltop_click~default-2-113804323.142v6pc_search_result_control_group,157v4control&utm_term=%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0&spm=1018.2226.3001.4187

Minst手写集

MNIST 是一个入门级计算机视觉数据集,包含了很多手写(0~9)数字图片。其中Minst手写集训练样本:共60000个,其中55000个用于训练,另外5000个用于验证,测试样本:共10000个

这里是官网链接:
如下图所示,为四张MNIST图片.

这里是Minst的官网:戳我去Minst官网下载


进去之后我们可以发现有四个可以下载的压缩包(都要下,然后解压使用)
分别是训练集图片、训练集标签、测试集图片、测试集标签
其中标签就是图片所对应的数字,用来检测和训练的

关于Minst的相关知识在这里不多做介绍
可以去这个博客看一下Minst数据集介绍
我们需要知道的是一下几点:
1.每张手写图片是由28*28(786)像素组成的图片
2. 文件(那个.gz文件)头信息,依次为魔数、图片数量、每张图片高、每张图片宽

代码实现

下面就直接把代码给大家参考一下,有一些问题我写在注释里了,希望会对大家有帮助

具体代码:

# -*- coding = utf-8 -*-
"""
#@Time:2022年4月1日
#@Author:YuDai
#@File:逻辑回归算法minst.py
#@Software : pycharm
"""from numpy import *
import numpy as np
import time
from scipy.special import expit     # logistic sigmoid函数 是一个逻辑回归算法里需要用到的函数(贼复杂,看都看不懂)
import struct  # 这是一个如何定义格式字符串的库
import math# 读取图片
def read_image(file_name):# 先用二进制方式把文件都读进来file_handle=open(file_name,"rb")  # 以二进制打开文档file_content=file_handle.read()   # 读取到缓冲区中offset=0# 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽 取前4个整数,返回一个元组head = struct.unpack_from('>IIII', file_content, offset)offset += struct.calcsize('>IIII')imgNum = head[1]  # 图片数rows = head[2]   # 宽度cols = head[3]  # 高度images=np.empty((imgNum, 784))     # empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法image_size=rows*cols        # 单个图片的大小# 图像数据像素值的类型为unsigned char型,对应的format格式为B。# 这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)fmt='>' + str(image_size) + 'B'     # 单个图片的formatfor i in range(imgNum):images[i] = np.array(struct.unpack_from(fmt, file_content, offset))offset += struct.calcsize(fmt)return images# 读取标签
def read_label(file_name):file_handle = open(file_name, "rb")  # 以二进制打开文档file_content = file_handle.read()  # 读取到缓冲区中# 和上面一样解析数据head = struct.unpack_from('>II', file_content, 0)  # 取前2个整数,返回一个元组offset = struct.calcsize('>II')labelNum = head[1]  # label数# print(labelNum)bitsString = '>' + str(labelNum) + 'B'  # fmt格式:'>47040000B'label = struct.unpack_from(bitsString, file_content, offset)  # 取data数据,返回一个元组return np.array(label)def loadDataSet():train_x_filename=r"这里是训练集图片的路径"train_y_filename=r"这里是训练集标签的路径"test_x_filename=r"这里是测试集图片的路径"test_y_filename=r"这里是测试集标签的路径"train_x=read_image(train_x_filename)train_y=read_label(train_y_filename)test_x=read_image(test_x_filename)test_y=read_label(test_y_filename)# # # #调试的时候让速度快点,就先减少数据集大小# train_x=train_x[0:1000,:]# train_y=train_y[0:1000]# test_x=test_x[0:500,:]# test_y=test_y[0:500]return train_x, test_x, train_y, test_y# 从这开始就是有关sigmoid函数的运算
# https://blog.csdn.net/qq_39783601/article/details/105557388?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522164924372916780366571823%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=164924372916780366571823&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-105557388.142^v5^pc_search_result_cache,157^v4^control&utm_term=sigmoid&spm=1018.2226.3001.4187
# Sigmoid函数的图像一般来说并不直观,我理解的是对数值越大,函数越逼近1,数值越小,函数越逼近0,将数值结果转化为了0到1之间的概率
# sigmoid函数
def sigmoid(inX):return 1.0/(1+exp(-inX))
# 预测函数
def classifyVector(inX,weights):    # 这里的inX相当于test_data,以回归系数和特征向量作为输入来计算对应的sigmoidprob=sigmoid(sum(inX*weights))if prob>0.5:return 1.0else: return 0.0
# 训练模型
def train_model(train_x, train_y, theta, learning_rate, iterationNum, numClass):    # theta是n+1行的列向量m=train_x.shape[0]n=train_x.shape[1]train_x=np.insert(train_x,0,values=1,axis=1)J_theta = np.zeros((iterationNum,numClass))for k in range(numClass):# print(k)real_y=np.zeros((m,1))index=train_y==k    # index中存放的是train_y中等于0的索引real_y[index]=1     # 在real_y中修改相应的index对应的值为1,先分类0和非0for j in range(iterationNum):# print(j)temp_theta = theta[:,k].reshape((785,1))# h_theta=expit(np.dot(train_x,theta[:,k]))#是m*1的矩阵(列向量),这是概率h_theta = expit(np.dot(train_x, temp_theta)).reshape((60000,1))# 这里的一个问题,将train_y变成0或者1J_theta[j,k] = (np.dot(np.log(h_theta).T,real_y)+np.dot((1-real_y).T,np.log(1-h_theta))) / (-m)temp_theta = temp_theta + learning_rate*np.dot(train_x.T,(real_y-h_theta))# theta[:,k] =learning_rate*np.dot(train_x.T,(real_y-h_theta))theta[:, k] = temp_theta.reshape((785,))return theta    # 返回的theta是n*numClass矩阵def predict(test_x,test_y,theta,numClass):# 这里的theta是学习得来的最好的theta,是n*numClass的矩阵errorCount=0test_x = np.insert(test_x, 0, values=1, axis=1)m = test_x.shape[0]h_theta=expit(np.dot(test_x,theta))#h_theta是m*numClass的矩阵,因为test_x是m*n,theta是n*numClassh_theta_max = h_theta.max(axis=1)  # 获得每行的最大值,h_theta_max是m*1的矩阵,列向量h_theta_max_postion=h_theta.argmax(axis=1) # 获得每行的最大值的labelfor i in range(m):if test_y[i]!=h_theta_max_postion[i]:errorCount+=1error_rate = float(errorCount) / mprint("error_rate", error_rate)return error_ratedef mulitPredict(test_x,test_y,theta,iteration):numPredict=10errorSum=0for k in range(numPredict):errorSum+=predict(test_x,test_y,theta,iteration)print("after %d iterations the average error rate is:%f" % (numPredict, errorSum / float(numPredict)))if __name__=='__main__':print("开始读入训练数据。。。")time1=time.time()train_x, test_x, train_y, test_y = loadDataSet()time2=time.time()print("读入时间cost:",time2-time1,"second")numClass=10 # 控制列向量iteration = 1000 #这里可以改迭代的次数,不同的迭代次数错误率也不一样,迭代次数越多,error 就越小learning_rate = 0.001 # 学习率n=test_x.shape[1]+1theta=np.zeros((n,numClass))# theta=np.random.rand(n,1)#随机构造n*numClass的矩阵,因为有numClass个分类器,所以应该返回的是numClass个列向量(n*1)print("开始训练数据。。。")theta_new = train_model(train_x, train_y, theta, learning_rate, iteration,numClass)time3 = time.time()print("训练时间cost:", time3 - time2, "second")print("开始预测数据。。。。")predict(test_x, test_y, theta_new,iteration)time4=time.time()print("预测时间cost",time4-time3,"second")

运行结果

迭代1次


迭代10次

迭代100次

迭代1000次

迭代1000次的时候跑了快五分钟,就没法演示了

TIPS:

iteration = 100     # 迭代次数
learning_rate = 0.001   # 学习率

关于这两个参数是可以自己调整的
iteration是迭代次数,说白了就是让程序学几遍,学的越多次,准确率自然高

learning_rate 是学习效率,这个数也可以自己调整,但是并不是字面意思越高越好。举个例子:

假设这个小人的目标要走到谷底,那么learning_rate就好比步长,最理想的状态当然是用最少的步数直接到达目的地。但是如果步子迈的太大就会错过谷底。所以这里建议learning_rate不要设太大0.01或者0.001即可

总结

我是纯纯小白,写这个也是为了交作业(bushi)
所以有好多东西没有理解透,比如梯度下降,onehot。。。balabala

逻辑回归算法识别Minst手写集相关推荐

  1. tkinter+socket&MySQL+keras识别minst手写数字

    tkinter + socket + keras + MySQL识别Minst手写数字 环境配置 代码 服务端 客户端 主函数main.py 类Window.py 实验报告部分 一.总体功能说明 1. ...

  2. BP算法实现--minst手写数字数据集识别

    实验步骤 初始化网络架构 网络层数,每层神经元数,连接神经元的突触权重,每个神经元的偏置 构造bp算法函数 对于一个输入数据,前向计算每层的输出值,保存未激活的输出和激活过的输出值,这里用的激活函数是 ...

  3. Tensorflow卷积神经网络识别MINST手写数字

    开发环境: Ubuntu16.04+Tensorflow1.5.0-GPU+CUDN9.0+CUDNN7.0 如果是Debian系列的系统,请参考这篇博客进行安装. 所有完整代码的github地址为: ...

  4. TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)

    TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...

  5. 机器学习算法 08 —— 支持向量机SVM算法(核函数、手写数字识别案例)

    文章目录 系列文章 支持向量机SVM算法 1 SVM算法简介 1.1 引入 1.2 算法定义 2 SVM算法原理 2.1 线性可分支持向量机 2.2 SVM计算过程与算法步骤(有点难,我也没理解透,建 ...

  6. 基于一个线性层的softmax回归模型和MNIST数据集识别自己手写数字

    原博文是用cnn识别,因为我是在自己电脑上跑代码,用不了处理器,所以参考Mnist官网上的一个线性层的softmax回归模型的代码,把两篇文章结合起来识别. 最后效果 源代码识别mnist数据集的准确 ...

  7. TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类

    TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类 目录 设计思路 实现代码 设计思路 更新-- 实现代码 # -*- coding:utf-8 -*- import ten ...

  8. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  9. DL之LiRDNNCNN:利用LiR、DNN、CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测

    DL之LiR&DNN&CNN:利用LiR.DNN.CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测 目录 输出结果 设计思路 核心代码 输出结果 数据集:Da ...

  10. DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练、预测(95%)

    DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练.预测(95%) 目录 数据集展示 输出结果 设计代码 数据集展示 先查看sklearn自带di ...

最新文章

  1. mac mysql的安装
  2. DDD分层架构最佳实践
  3. c++将小数化为二进制_C++版进制转换(十进制,二进制,十六进制整数和小数)
  4. 将表中的值变成字段显示
  5. JMS - QueueBrowser
  6. leetcode-寻找两个正序数组的中位数
  7. linux java amr转mp3_本工具用于将微信语音 amr 格式转换为 mp3 格式以便在 html5 的 audio 标签中进行播放...
  8. 使用cronolog-1.6.2按日期截取Tomcat日志
  9. IDA远程调试Android
  10. linux重启配置文件,rEFInd启动管理器配置文件详解
  11. 基于汇编的 C/C++ 协程 - 背景知识
  12. csdn账号不能合并
  13. Jetson-TX2双声卡TLV320AIC32x4 alsa实现同时录音与播放
  14. Cesium原理篇:5最长的一帧之影像
  15. 软件工程作业一:从产品经理人角度分析微信求职招聘小程序
  16. node抓取58同城信息_如何使用标准库和Node.js轻松抓取网站以获取信息
  17. 基于双层优化的微电网系统规划设计方法matlab程序(yalmip+cplex)
  18. unionpay 云闪付小程序开发包
  19. office2016官方下载 免费完整版
  20. GAN Step By Step -- Step4 CGAN

热门文章

  1. [转]coolfire黑客入门教程系列之(五)
  2. windows下 gcc 下载及使用指南
  3. eNsp——Vlan
  4. Java高级工程师面试总结
  5. WPS简历模板的图标怎么修改_新媒体运营-简历模板范文,【工作经历+项目经验+自我评价】怎么写?...
  6. 统信系统UOS桌面版V20 用户手册
  7. 陕西省地形图与陕西地形高程数据DEM下载
  8. maya mentray_mental ray2016中文版下载|
  9. Java 泛型的实例化总结
  10. C语言学习笔记-1(资料:郝斌老师C语言视频)