NumPy实现简单的神经网络分析Mnist手写数字库(三)之划分迷你批(mini-batch)
NumPy实现简单的神经网络分析Mnist手写数字库(三)之划分迷你批(mini-batch)
- 划分迷你批(mini-batch)
- 引言
- 迷你批(mini-batch)简介
- 经典梯度下降
- 随机梯度下降
- 迷你批梯度下降
- 划分迷你批
- 迷你批的使用要点
- 迷你批的划分
- (1)数据集随机打乱(shuffle)
- (2)分割数据集
- (3)在主函数中调用get_mini_batches()
- 小结
划分迷你批(mini-batch)
引言
在上一节-数据预处理中,
我们对读取到的Mnist手写数字数据进行了预处理。在本节中,我们将把预处理过的数据划分为迷你批(mini-batch)。
迷你批(mini-batch)简介
经典梯度下降
在机器学习中,最著名的优化算法莫过于梯度下降法(Gradient Descent)。早期人们会直接把整个数据集喂给算法。
(1)优点:准确,每一次迭代之后的代价函数(cost function)一定不会比迭代之前的高。
(2)缺点:就是计算成本大。导致训练的速度慢。
随机梯度下降
与之对应的就是随机梯度下降(Stochastic Gradient Descent)。每次只对一个样例(example)做梯度下降。而所谓随机,就是训练的方向随机。不再沿着梯度变化最大的方向。但是大体趋势是朝着代价函数的极小值前进。
(1)优点:一次梯度下降的运算量小
(2)不能收敛到极小值,只能在附近徘徊。训练的过程随机,走了很多弯路。
迷你批梯度下降
介于以上两者之间的就是迷你批梯度下降(mini-batch Gradient Descent)。进行一次梯度下降的单位是一个迷你批。迷你批的大小在1和整个数据的样本数之间。往往取2的幂次,诸如64,256,1024等。这个大小也是一个超参数。
(1)优点:收敛快,运算量小。
(2)缺点:需要提前划分迷你批;多了一个超参数,增加了模型复杂性。
不过总体来说利大于弊,是三者中最好的方法。
划分迷你批
迷你批的使用要点
1.迷你批是整个数据集互斥的子集
2.大小几乎都相同,除了数据集大小不能被迷你批大小整除的情况,会有一个迷你批稍短
3.周期(epoch)是指遍历整个数据集的过程。每经过一个周期,就需要重新随机划分迷你批。
迷你批的划分
先新建一个Python文件
"""
mini_batch.py打乱,分割数据集,返回迷你批"""
import numpy as np
(1)数据集随机打乱(shuffle)
def shuffle(X, Y):"""打乱数据集(X,Y)参数:X -- 图像数据,float32类型的矩阵Y -- 独热(one-hot)标签,uint8类型的矩阵返回:shuffles -- 字典,{"X_shuffle": X_shuffle, "Y_shuffle": Y_shuffle}"""#取数据集大小m = X.shape[1]#随机生成一个索引顺序permutation = list(np.random.permutation(m))#把X,Y打乱成相同顺序X_shuffle = X[:, permutation]Y_shuffle = Y[:, permutation]#打乱的数据集存在字典里shuffles = {"X_shuffle": X_shuffle, "Y_shuffle": Y_shuffle}return shuffles
(2)分割数据集
def get_mini_batches(X, Y, mini_batch_size):"""把数据集按照迷你批大小进行分割参数:X -- 图像数据,float32类型的矩阵Y -- 独热(one-hot)标签,uint8类型的矩阵mini_batch_size -- 迷你批大小返回:mini_batches -- 元素为(X,Y)元组的列表"""#调用刚才的函数shuffles = shuffle(X, Y)#取数据集大小num_examples = shuffles["X_shuffle"].shape[1]#计算完整迷你批的个数num_complete = num_examples // mini_batch_size#建立一个空列表,存储迷你批mini_batches = []#分配完整的迷你批for i in range(num_complete):mini_batches.append([shuffles["X_shuffle"]\[:, i*mini_batch_size:(i+1)*mini_batch_size], \shuffles["Y_shuffle"]\[:, i*mini_batch_size:(i+1)*mini_batch_size]])#如果需要的话,分配不完整的迷你批if 0 == num_examples % mini_batch_size:passelse:mini_batches.append([shuffles["X_shuffle"]\[:, num_complete*mini_batch_size:], \shuffles["Y_shuffle"]\[:, num_complete*mini_batch_size:]])return mini_batches
(3)在主函数中调用get_mini_batches()
注意在每个周期中都应该调用一次,得到新的划分。在之后的小节中我们会看到它的用法。这里只是一个测试。
mini_batches = mini_batches(X_train, Y_train_one_hot, 64)
得到的mini_batch的结构
每个元素的结构
小结
在本节中,我们写了划分迷你批的函数,在之后的训练中,我们会使用它。另外,迷你批仅用于训练,在测试和预测的时候不使用。
NumPy实现简单的神经网络分析Mnist手写数字库(三)之划分迷你批(mini-batch)相关推荐
- 基于MXNet实现MNIST手写数字体识别
MNIST手写数字集:包含训练集和测试集,训练集有60000个样本,测试集有10000个样本. MNIST手写数字训练代码分为:训练参数配置.数据读取.网络结构搭建.模型训练 import mxnet ...
- 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(双向LSTM,附完整代码和数据集)
写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 在https://blog.csdn.net/A ...
- 【原创】深度学习第7弹:小D识数字(MNIST手写数字集)
目录 一.前文回顾 二.MNIST手写数字数据集 1.什么是MNIST手写数字数据集 2.MNIST手写数字数据集下载 三.重构神经网络 1.为什么要重构神经网络 2.重构什么样的神经网络 四.识别数 ...
- 【图像分类】基于PyTorch搭建LSTM实现MNIST手写数字体识别(单向LSTM,附完整代码和数据集)
写在前面: 首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 提起LSTM大家第一反应是在NLP的数据集上比较 ...
- Mnist手写数字集的识别和可视化
当时看老师留的作业的时候发现要求读取历史数据,但网上有没找到,我自己找了手册啥的,查到了几个参数,就分享下,希望各位大佬不要看不上哈... 那就先把程序从零实现: import numpy as np ...
- pytorch实现手写数字识别_Paddle和Pytorch实现MNIST手写数字集识别对比
一.简介 1. Paddle PaddlePaddle是百度自主研发的集深度学习核心框架.工具组件和服务平台为一体的技术领先.功能完备的开源深度学习平台,有全面的官方支持的工业级应用模型,涵盖自然语言 ...
- MNIST手写数字体分类--KNN matlab实现
关于数据集神马的,请直接参考:http://blog.csdn.net/wangyuquanliuli/article/details/11606435 这里直接给出KNN matlab的实现 tra ...
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
点击"阅读原文"直接打开[北京站 | GPU CUDA 进阶课程]报名链接 作者,周乘,华中科技大学电子与信息工程系在读. 前言 tensorflow中文社区对官方文档进行了完整翻 ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 (zz)
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 我想写一系列深度学习的简单实战教程,用mxnet做实现平台的实例代码简单讲解深度学习常用的一些技术方向和实战样例.这 ...
最新文章
- 从2D到3D的目标检测综述
- Java数据结构和算法(二)——数组
- [MySQL 优化] 移除多余的checksum
- Java 8:长期支持的堡垒
- Hdu-6243 2017CCPC-Final A.Dogs and Cages 数学
- Chrome 开发工具 (Chrome Developer Tools):Network Panel说明
- MySQL—Mysql与MariaDB启停命令的区别
- 小米集团本周再回购1920万港元股票
- 英伟达:今年显卡将继续供不应求 尽量保证供应普通玩家
- excel 文件导入数据库(java)
- LAYUI 树形表格(tree table)
- cad填充密度怎么调整_CAD填充比例调好了,填充物数量怎么调,就是密度怎么调?...
- 语音播放与录音 (五分钟学会用 非常全面)
- 服务器虚拟cpu,服务器虚拟化 vcpu与内存配比
- 移动端2倍图和3倍图的处理方法
- 计算机网络简历技能填写,计算机网络专业个人简历个人技能范文
- svn 冲突 Error:Node remains in conflict
- 打开服务器网页要5秒,网页优化技巧 如何把网页加载时间控制在1.5秒以内
- Java 旋转、翻转图片工具类(附代码) | Java工具类
- HTTP 、HTTPS