读书笔记:mini-batch学习 ← 斋藤康毅
如果以全部数据为对象求损失函数的和,则计算过程需要花费较长的时间。因此,我们从全部数据中选出一部分,作为全部数据的“近似”。这种学习方式称为mini-batch学习。
下面给出了从MNIST数据集的训练数据中随机抽取10笔数据的代码。
【mini-batch学习示例代码】
import sys,os
sys.path.append(os.pardir)
import numpy as np
from mnist import load_mnist(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True,one_hot_label=True)print(x_train.shape) #(60000, 784)
print(t_train.shape) #(60000, 10)train_size=x_train.shape[0]
batch_size=10
batch_mask=np.random.choice(train_size,batch_size)
x_batch=x_train[batch_mask]
t_batch=t_train[batch_mask]
print(x_batch)
print(t_batch)
运行此代码前,可将之前https://blog.csdn.net/hnjzsyjyj/article/details/128721706例子中下载的 MNIST 数据集的 'train-images-idx3-ubyte.gz'、'train-labels-idx1-ubyte.gz'、't10k-images-idx3-ubyte.gz'、't10k-labels-idx1-ubyte.gz' 等四个文件,以及生成的 mnist.pkl 文件直接复制过来。
其中,mnist.py的代码如下:
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath('__file__'))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...") with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] = _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label']) dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()
此mini-batch学习的代码运行后的结果如下:
(60000, 784)(60000, 10)[[0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.]...[0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.][0. 0. 0. ... 0. 0. 0.]][[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.][0. 0. 0. 0. 0. 0. 0. 0. 1. 0.][0. 0. 0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 0. 0. 1. 0. 0. 0. 0. 0.][0. 0. 0. 0. 0. 0. 0. 0. 0. 1.][0. 1. 0. 0. 0. 0. 0. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0. 0. 0.][0. 0. 0. 0. 0. 0. 0. 0. 1. 0.][1. 0. 0. 0. 0. 0. 0. 0. 0. 0.][0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]
读书笔记:mini-batch学习 ← 斋藤康毅相关推荐
- 斋藤康毅-深度学习入门 学习笔记二
ch02 感知机 Perceptron.py import numpy as np''' def AND(x1, x2):w1, w2, theta = 0.5, 0.5, 0.7tmp = w1 * ...
- 斋藤康毅-深度学习入门 学习笔记三
ch03 神经网络 1. pkl文件的创建与导入 python官方文档对pickle模块的定义:pickle The pickle module implements binary protocols ...
- 读书笔记:手写数字识别 ← 斋藤康毅
求解机器学习问题的步骤可以分为"学习"和"推理"两个阶段. 本例假设"学习"阶段已经完成,并将学习到的权重和偏置参数保存在pickle文件s ...
- 斋藤康毅-深度学习入门 学习笔记四
ch 神经网络的学习 损失函数 1.1 均方误差 import numpy as npdef mean_squared_error(y, t):return 0.5 * np.sum((y - t) ...
- 斋藤康毅-深度学习入门 学习笔记五
ch 误差反向传播法 乘法和加法层的反向传播 class AddLayer:def __init__(self):passdef forward(self, x, y):out = x + yretu ...
- 斋藤康毅-深度学习入门 学习笔记一
ch01 Python入门 basic.py ''' python --versionnote in python3 5/2 = 2.54**2 = 16type(3.4)x = 10 then x ...
- 深度学习入门_斋藤康毅_chapter23
系列文章目录 这是第一部分 文章目录 系列文章目录 前言 一.chapter 1 二.chapter感知机 1.numpy生成数组 三. 神经网络 总结 前言 本来是想通过李沐的网课入门深度学习的,但 ...
- 《深度学习入门--基于python的理论与实现》——斋藤康毅读书笔记
<深度学习入门--基于python的理论与实现>读书笔记(第二章) 写在前面 第二章:感知机 2.1感知机是什么 2.2简单的逻辑电路 2.2.1与门(and gate) 2.2.2与非门 ...
- 《深度学习入门——基于Python的理论与实现》斋藤康毅学习笔记(一)
第一章 (只将自己有疑惑并得到解决的学习内容作以下笔记) 1.python解释器 1.1数组 错误:a[ : -1] 不是获取所有元素 修改:a[ : -1]表示获取从第一个元素到最后一个元素之间的元 ...
最新文章
- 【莓闻】2009年黑莓增长显著 智能手机领域第一
- 3D中OBJ文件格式详解
- 图像多尺度对比增强算法
- springboot 使用webflux响应式开发教程(一)
- volta架构 微型计算机,性能大爆炸 NVIDIA新GPU架构曝光
- python - hadoop,mapreduce demo
- C语言指针学习(续)
- struts2+ajax+json使用实例
- mysql index sub part_mysql中的key和index 理解
- 伯克利、OpenAI等提出基于模型的元策略优化强化学习
- C#程序读取MAC地址的方法
- react打包后图片丢失_宜信技术实践|指尖前端重构(React)技术调研分析
- thymeleaf获取url地址跳转时所带参数
- 【性能】雅虎军规(14条常用)笔记
- android上的单片机编程软件下载,AVR单片机编程软件(AVR_fighter)
- 大数据杀熟 算法_大数据杀熟这事,究竟有多没谱?
- Matlab 黎卡提方程
- oracle循环数据字典,Oracle DUL的工作原理和技术实现
- ais解码_解决ais cassandra问题
- STM32定时器-6步PWM输出