神经网络搭建六步法扩展
神经网络搭建六步法扩展
1.自制数据集,应对特定应用
之前我们使用“MNIST数据集”,可以直接导入使用“tf.keras.datasets.mnist.load_data()”,并非所有的数据都有制作玩呗可以直接导入的数据集,因此我么你需要学会自己制作数据集。
以“MNIST”中的数据为例,我们给出:
60000张训练图片、训练图片的“图片名字 数字”形式的txt文件、10000张测试数据、测试图片的“图片名字 数字”形式的txt文件
代码实现
训练数据和测试数据的制作步骤是一样的,因此我们可以写一个函数来处理。
def generateds(path, txt):######return x, y_
对于制作流程,我们的思路是;
- 读取出txt文件中的内容(contents)
- 用x和y_两个list存放特征数据和标签
- 遍历每一行的内容
- 对每一行,根据图片名字和图片文件夹路径读入图片,转换成合适的数据结构,放入x中;将标签放入y_中
- 返回x, y_
def generateds(path, txt):f = open(txt, 'r')contents = f.readlines()f.close() # 记得关闭文件哦x, y_ = [], []for content in contents:value = content.split() # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表img_path = path + value[0] # 图片完整路径img = Image.open(img_path) # 读入图片img = np.array(img.convert('L')) # 图片变为8位宽灰度值的np.array格式img = img / 255. # 归一化x.append(img)y_.append(value[1])print('loading : ' + content)x = np.array(x) # 变为np.array格式y_ = np.array(y_) # 变为np.array格式y_ = y_.astype(np.int64) # 变为64位整型return x, y_ # 返回输入特征x,返回标签y_
我们可以把上面制作数据集的代码和“六步法”的代码何在一起,先检查是否存在数据集,如果存在直接load,如果不存在,先制作,然后把数据集存储到磁盘以备之后使用。(制作数据集非常耗时间)
import tensorflow as tf
from PIL import Image
import numpy as np
import ostrain_path = './mnist_image_label/mnist_train_jpg_60000/'
train_txt = './mnist_image_label/mnist_train_jpg_60000.txt'
x_train_savepath = './mnist_image_label/mnist_x_train.npy'
y_train_savepath = './mnist_image_label/mnist_y_train.npy'test_path = './mnist_image_label/mnist_test_jpg_10000/'
test_txt = './mnist_image_label/mnist_test_jpg_10000.txt'
x_test_savepath = './mnist_image_label/mnist_x_test.npy'
y_test_savepath = './mnist_image_label/mnist_y_test.npy'def generateds(path, txt):f = open(txt, 'r')contents = f.readlines()f.close()x, y_ = [], [for content in contents:value = content.split(img_path = path + value[0]img = Image.open(img_path)img = np.array(img.convert('L'))img = img / 255. x.append(img) y_.append(value[1]) print('loading : ' + content) x = np.array(x) y_ = np.array(y_) y_ = y_.astype(np.int64) return x, y_ # 返回输入特征x,返回标签y_# 如果数据集存在就直接使用,如果不存在就 先生成 再存储
if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(x_test_savepath) and os.path.exists(y_test_savepath):print('-------------Load Datasets-----------------')x_train_save = np.load(x_train_savepath)y_train = np.load(y_train_savepath)x_test_save = np.load(x_test_savepath)y_test = np.load(y_test_savepath)x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:print('-------------Generate Datasets-----------------')x_train, y_train = generateds(train_path, train_txt)x_test, y_test = generateds(test_path, test_txt)print('-------------Save Datasets-----------------')x_train_save = np.reshape(x_train, (len(x_train), -1))x_test_save = np.reshape(x_test, (len(x_test), -1))np.save(x_train_savepath, x_train_save)np.save(y_train_savepath, y_train)np.save(x_test_savepath, x_test_save)np.save(y_test_savepath, y_test)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)model.summary()
2.数据增强,增大数据量
有些时候,我们的数据量太小,需要使用数据增强来丰富我们的数据集,数据增强就是把原有的数据做一些变化(如图片的旋转,镜像等)来扩充数据。
步骤:
image_gen_train=tf.keras.preprocessing.image.ImageDataGenerator
( 增 强 方法)
image_gen_train.fit(x_train)
常用增强方法:
- 缩放系数:rescale=所有数据将乘以提供的值
- 随机旋转:rotation_range=随机旋转角度数范围
- 宽度偏移:width_shift_range=随机宽度偏移量
- 高度偏移:height_shift_range=随机高度偏移量
- 水平翻转:horizontal_flip=是否水平随机翻转
- 随机缩放:zoom_range=随机缩放的范围 [1-n,1+n]
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGeneratormnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 1. !!!!!!!!!!!!!!!!!!
# 给数据增加一个维度,从(60000, 28, 28)reshape为(60000, 28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) image_gen_train=ImageDataGenerator(rescale=1./1.,rotation_range=45,width_shift_range=.15,height_shift_range=.15,horizontal_flip=False,zoom_range=0.5
)
image_gen_train.fit(x_train)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])# 2.!!!!!!!!!!!!!!!!!
model.fit(image_gen_train.flow(x_train,y_train,batch_size=32),epochs=5,validation_data=(x_test,y_test),validation_freq=1)model.summary()
注意:
- 数据增强函数的输入要求是4维,需要通过reshape调整
- model.fit输入训练数据和batch时要是用“image_gen_train.flow(x_train,y_train,batch_size=)”
3.断点续训,存取模型
3.1读取模型
读取模型使用model.load_weights()
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)
3.2保存模型
借助tensorflow给出的回调函数,直接保存参数和网络
具体用法如下:
#上接“六步法”中的前四部checkpoint_save_path = "./checkpoint/mnist.ckpt"
# 如果模型已经被保存过 导入模型继续训练
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)# 保存模型
cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history=model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])#下面是“六步法”的最后一步
4.参数提取,写至文本
4.1提取可训练参数
model.trainable_variables
返回模型中可训练的参数。
4.2设置 print 输出格式
如果数据量太大,print输出时默认会将中间的数据省略,用“省略号”表示。
np.set_printoptions
(
precision=小数点后按四舍五入保留几位,
threshold=数组元素数量少于或等于门槛值,打印全部元素;否则打印门槛值+1 个元素,中间用省略号补充
)
把threshold设置为np.inf即可输出全部数据元素。
代码如下,需要关注的地方我用!标注:
import tensorflow as tf
import os
import numpy as np# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
np.set_printoptions(threshold=np.inf)mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])model.summary()# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
print(model.trainable_variables)
file=open('./weights.txt','w')
for v in model.trainable_variables:file.write(str(v.name)+'\n')file.write(str(v.shape)+'\n')file.write(str(v.numpy())+'\n')
file.close()
weights.txt内容:
5.acc/loss 可视化,查看效果
使用变量history接收model.fit返回值,之后history.history获取acc/loss值。
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])acc=history.history['sparse_categorical_accuracy'] # 训练数据正确率
val_acc=history.history['val_sparse_categorical_accuracy'] # 测试数据正确率
loss=history.history['loss'] # 训练数据损失值
val_loss=history.history['val_loss'] # 测试数据损失值
# 画图
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
6、应用程序,给图识物
这里会给神经网络输入一条特征数据(手写数字图),让学习后的神经网络来预测这个数字是多少。
eg.
6.1前向传播执行应用
使用model.predict
(输入数据, batch_size=整数),返回前向传播计算结果。
因为我们这里是使用已经训练好的模型,所以需要:
- 构建一样的网络结构
- 加载保存的模型
- model.predict(x_predict)进行预测
from PIL import Image
import numpy as np
import tensorflow as tf# 模型保存地址
model_save_path = './checkpoint/mnist.ckpt'# 1. 构建一样的网络结构
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')])# 2. 加载保存的模型
model.load_weights(model_save_path)preNum = int(input("input the number of test pictures:"))for i in range(preNum):image_path = input("the path of test picture:")img = Image.open(image_path)# 预处理 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!# 变为28*28img = img.resize((28, 28), Image.ANTIALIAS)# 变为灰度图img_arr = np.array(img.convert('L'))# 变为黑底白字# 方法一#img_arr = 255-img_arr# 方法二for i in range(28):for j in range(28):if img_arr[i][j] < 200:img_arr[i][j] = 255else:img_arr[i][j] = 0# 归一化img_arr = img_arr / 255.0# 增加一个维度,变为1*28*28 为了和训练时使用的数据的维度一致x_predict = img_arr[tf.newaxis, ...]# 3. 预测result = model.predict(x_predict)# pred是预测认为是0-9各个数字的概率pred = tf.argmax(result, axis=1)print('\n')# pred是张量,需要用tf.printtf.print(pred)
这里需要注意,因为我们训练神经网络时使用的是28*28的灰度图,黑底白字,所以需要把我们这里需要预测的图片(上图)转换为“28*28的灰度图,黑底白字”格式。
其中,变为黑底白字可以使用简单的方法一:img_arr = 255 - img_arr
也可以使用复杂的方法二,方法二将所有像素转换为“纯黑0”或“纯白255”,这样做可以降低图片中的噪声带来的影响。
神经网络搭建六步法扩展相关推荐
- 【Tensorflow学习三】神经网络搭建八股“六步法”编写手写数字识别训练模型
神经网络搭建八股"六步法"编写手写数字识别训练模型 Sequential用法 model.compile(optimizer=优化器,loss=损失函数,metrics=[&quo ...
- Excel数据分析学习笔记(一)数据分析六步法和重要分析模型
第一章 数据分析类型: 描述性统计分析,概括表述事物关系 探索性数据分析,发现数据的新特征 验证性数据分析,对假设进行证伪或证实 数据分析六步法 1.确定目标(这个很重要,在实践当中是比技术还要重要的 ...
- 以在线教育销售CRM为例,谈谈业务大盘拆解优化的六步法
前言 对于一名企业内的B端产品经理,核心工作之一,是利用技术赋能业务,帮助业务改进,实现企业的商业价值. 如何识别业务问题?如何发现改进机会点?如何分析业务大盘?如何从产品视角给出解决思路?如何基于数 ...
- dfmea文件_DFMEA和PFMEA的“六步法”其实很不一样
新版FMEA手册发布在即,"六步法"将成为其最核心的方法论,DFMEA和PFMEA的分析过程都需要遵照该方法."六步法"简述如下图: 但具体在应用"六 ...
- Kotlin学习笔记 第二章 类与对象 第五节 可见性 第六节 扩展
参考链接 Kotlin官方文档 https://kotlinlang.org/docs/home.html 中文网站 https://www.kotlincn.net/docs/reference/p ...
- 安全生产六步法是什么_安全生产六步法
卓越绩效管理模式.管理架构及实施六步法共121页_财务管理_经管营销_专业资料.... 1.8 9.3 73.4 气象指数 演 (横向) 练 0.3 修 0.2 正值 2.6 7.7 8.5 84.5 ...
- 优思学院|六西格玛管理的经典六步法
优思学院按:六西格玛管理的目标是消除变异,六西格管理的理论基础是标准差,用标准差来衡量变异的大小.所谓的六西格玛管理的经典六步法,就是六西格玛的六个基本的管理哲学. 一.以顾客为关注重心是六西格玛的灵 ...
- 基于神经网络的语音频带扩展
本博客转载自顾宇的<基于神经网络的语音频带扩展方法研究>,大家可从知网获取. 摘要 语音频带扩展旨在从频带受限的窄带语音信号中恢复宽带语音信号.由于受到语音采集设备以及信道条件的限制,传输 ...
- stm32捕获占空比_【电机控制】六步法驱动BLDC电机,使用硬件COM事件,STM32+CUBEMX(HAL库)配置...
现在我也只能说是电机入门,但是想要把电机作为终身事业,从有霍尔到无霍尔,从方波到正弦波,现在把其中的一些知识点分享出来,因为电机控制其实的资料比较难找的,前人栽树,后人乘凉,如果我有什么错误,在知乎上 ...
- 安全生产六步法是什么_海孜煤矿安全生产管理“六步法”实施办法.doc
海孜煤矿安全生产管理"六步法"实施办法.doc 海孜煤矿安全生产管理"六步法"实施办法为强化基层科(区)安全自主管理,有效防范与控制安全风险,实现矿井安全生产, ...
最新文章
- 人工智能在音频链中找到自己的声音
- Jmeter(一)http接口添加header和cookie --转载
- ConcurrentHashMap源码解析(1)
- 使用SHA256证书进行微软数字签名代码签名
- 初步了解win32界面库DuiLib
- MySQL时间戳(毫秒/秒)与日期格式的相互转换
- 为使节构建控制平面的指南第3部分-特定于域的配置API
- easyui学习笔记一:主要结构
- xForm应用开发手册
- HDU2019 数列有序!【入门】
- spring事务源码执行过程分析
- G-sensor 介绍
- 电脑硬盘双击打不开,提示格式化怎么办?
- OKR与KPI有什么区别
- Windows网络共享或共享打印机无法访问连接的简单终极解决方法
- ArcGIS必会的几个工具的应用
- 全国代收货款平台-快递鸟、菜鸟
- Java加密技术(一)—— HMACSHA1 加密算法
- 免费下载学术论文的网站
- Struts——开源MVC框架
热门文章
- C#利用NOPI导出到Excel
- DigitalFilmTools Rays 2.1.2汉化版|丁达尔光束耶稣光滤镜插件
- SQL-实现excel向下填充的功能
- B样条曲线的一些基本性质
- 主流数据库以及适用场景思维导图
- 一个500人天的BI项目实施记录
- HICE第四天笔记 12月8日
- vue如何集成阿里云视频服务组件(aliplayer)视频功能是使用el-dialog 弹出aliplayer播放
- switch怎么一个账号绑定各种服务器,NS怎么一个账号两台机器使用_Nintendo Switch 新旧机器同使用教程_尼萌手游网...
- NB-LOT 常用AT指令集简介