如果以全部数据为对象求损失函数的和,则计算过程需要花费较长的时间。因此,我们从全部数据中选出一部分,作为全部数据的“近似”。这种学习方式称为mini-batch学习。

下面给出了从MNIST数据集的训练数据中随机抽取10笔数据的代码。

【mini-batch学习示例代码】

import sys,os
sys.path.append(os.pardir)
import numpy as np
from mnist import load_mnist(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True,one_hot_label=True)print(x_train.shape)  #(60000, 784)
print(t_train.shape)  #(60000, 10)train_size=x_train.shape[0]
batch_size=10
batch_mask=np.random.choice(train_size,batch_size)
x_batch=x_train[batch_mask]
t_batch=t_train[batch_mask]
print(x_batch)
print(t_batch)

运行此代码前,可将之前https://blog.csdn.net/hnjzsyjyj/article/details/128721706例子中下载的 MNIST 数据集的 'train-images-idx3-ubyte.gz'、'train-labels-idx1-ubyte.gz'、't10k-images-idx3-ubyte.gz'、't10k-labels-idx1-ubyte.gz' 等四个文件,以及生成的 mnist.pkl 文件直接复制过来。

其中,mnist.py的代码如下:

try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath('__file__'))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")    with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] =  _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])    dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()

此mini-batch学习的代码运行后的结果如下:

(60000, 784)(60000, 10)[[0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.]...[0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.]][[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.][0. 0. 0. 0. 0. 0. 0. 0. 1. 0.][0. 0. 0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 0. 0. 1. 0. 0. 0. 0. 0.][0. 0. 0. 0. 0. 0. 0. 0. 0. 1.][0. 1. 0. 0. 0. 0. 0. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0. 0. 0.][0. 0. 0. 0. 0. 0. 0. 0. 1. 0.][1. 0. 0. 0. 0. 0. 0. 0. 0. 0.][0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]

读书笔记:mini-batch学习 ← 斋藤康毅相关推荐

  1. 斋藤康毅-深度学习入门 学习笔记二

    ch02 感知机 Perceptron.py import numpy as np''' def AND(x1, x2):w1, w2, theta = 0.5, 0.5, 0.7tmp = w1 * ...

  2. 斋藤康毅-深度学习入门 学习笔记三

    ch03 神经网络 1. pkl文件的创建与导入 python官方文档对pickle模块的定义:pickle The pickle module implements binary protocols ...

  3. 读书笔记:手写数字识别 ← 斋藤康毅

    求解机器学习问题的步骤可以分为"学习"和"推理"两个阶段. 本例假设"学习"阶段已经完成,并将学习到的权重和偏置参数保存在pickle文件s ...

  4. 斋藤康毅-深度学习入门 学习笔记四

    ch 神经网络的学习 损失函数 1.1 均方误差 import numpy as npdef mean_squared_error(y, t):return 0.5 * np.sum((y - t) ...

  5. 斋藤康毅-深度学习入门 学习笔记五

    ch 误差反向传播法 乘法和加法层的反向传播 class AddLayer:def __init__(self):passdef forward(self, x, y):out = x + yretu ...

  6. 斋藤康毅-深度学习入门 学习笔记一

    ch01 Python入门 basic.py ''' python --versionnote in python3 5/2 = 2.54**2 = 16type(3.4)x = 10 then x ...

  7. 深度学习入门_斋藤康毅_chapter23

    系列文章目录 这是第一部分 文章目录 系列文章目录 前言 一.chapter 1 二.chapter感知机 1.numpy生成数组 三. 神经网络 总结 前言 本来是想通过李沐的网课入门深度学习的,但 ...

  8. 《深度学习入门--基于python的理论与实现》——斋藤康毅读书笔记

    <深度学习入门--基于python的理论与实现>读书笔记(第二章) 写在前面 第二章:感知机 2.1感知机是什么 2.2简单的逻辑电路 2.2.1与门(and gate) 2.2.2与非门 ...

  9. 《深度学习入门——基于Python的理论与实现》斋藤康毅学习笔记(一)

    第一章 (只将自己有疑惑并得到解决的学习内容作以下笔记) 1.python解释器 1.1数组 错误:a[ : -1] 不是获取所有元素 修改:a[ : -1]表示获取从第一个元素到最后一个元素之间的元 ...

最新文章

  1. 【莓闻】2009年黑莓增长显著 智能手机领域第一
  2. 3D中OBJ文件格式详解
  3. 图像多尺度对比增强算法
  4. springboot 使用webflux响应式开发教程(一)
  5. volta架构 微型计算机,性能大爆炸 NVIDIA新GPU架构曝光
  6. python - hadoop,mapreduce demo
  7. C语言指针学习(续)
  8. struts2+ajax+json使用实例
  9. mysql index sub part_mysql中的key和index 理解
  10. 伯克利、OpenAI等提出基于模型的元策略优化强化学习
  11. C#程序读取MAC地址的方法
  12. react打包后图片丢失_宜信技术实践|指尖前端重构(React)技术调研分析
  13. thymeleaf获取url地址跳转时所带参数
  14. 【性能】雅虎军规(14条常用)笔记
  15. android上的单片机编程软件下载,AVR单片机编程软件(AVR_fighter)
  16. 大数据杀熟 算法_大数据杀熟这事,究竟有多没谱?
  17. Matlab 黎卡提方程
  18. oracle循环数据字典,Oracle DUL的工作原理和技术实现
  19. ais解码_解决ais cassandra问题
  20. STM32定时器-6步PWM输出

热门文章

  1. JavaScript 判断 Chrome 内核的 360 浏览器
  2. springboot的多模块开发
  3. 如何使用CSS实现硬件加速?
  4. python日出日落时间实现和详解
  5. 未来世界的精彩,你难以想象
  6. 正态分布(高斯分布)的由来(公式推导)
  7. Path Finder for Mac(系统文件管理器)
  8. 服务器装win10性能,服务器可以装win10吗
  9. 19年英语及计算机统考时间,2019年英语四六级考试时间及考试科目【已公布】
  10. B19 - 999、大数据组件学习⑯ - ElasticSearch