神经网络搭建六步法扩展

1.自制数据集,应对特定应用

之前我们使用“MNIST数据集”,可以直接导入使用“tf.keras.datasets.mnist.load_data()”,并非所有的数据都有制作玩呗可以直接导入的数据集,因此我么你需要学会自己制作数据集。

以“MNIST”中的数据为例,我们给出:

60000张训练图片、训练图片的“图片名字 数字”形式的txt文件、10000张测试数据、测试图片的“图片名字 数字”形式的txt文件

代码实现

训练数据和测试数据的制作步骤是一样的,因此我们可以写一个函数来处理。

def generateds(path, txt):######return x, y_

对于制作流程,我们的思路是;

  1. 读取出txt文件中的内容(contents)
  2. 用x和y_两个list存放特征数据和标签
  3. 遍历每一行的内容
  4. 对每一行,根据图片名字和图片文件夹路径读入图片,转换成合适的数据结构,放入x中;将标签放入y_中
  5. 返回x, y_
def generateds(path, txt):f = open(txt, 'r')contents = f.readlines()f.close() # 记得关闭文件哦x, y_ = [], []for content in contents:value = content.split() # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表img_path = path + value[0] # 图片完整路径img = Image.open(img_path) # 读入图片img = np.array(img.convert('L'))  # 图片变为8位宽灰度值的np.array格式img = img / 255.  # 归一化x.append(img)y_.append(value[1])print('loading : ' + content)x = np.array(x)  # 变为np.array格式y_ = np.array(y_)  # 变为np.array格式y_ = y_.astype(np.int64)  # 变为64位整型return x, y_  # 返回输入特征x,返回标签y_

我们可以把上面制作数据集的代码和“六步法”的代码何在一起,先检查是否存在数据集,如果存在直接load,如果不存在,先制作,然后把数据集存储到磁盘以备之后使用。(制作数据集非常耗时间)

import tensorflow as tf
from PIL import Image
import numpy as np
import ostrain_path = './mnist_image_label/mnist_train_jpg_60000/'
train_txt = './mnist_image_label/mnist_train_jpg_60000.txt'
x_train_savepath = './mnist_image_label/mnist_x_train.npy'
y_train_savepath = './mnist_image_label/mnist_y_train.npy'test_path = './mnist_image_label/mnist_test_jpg_10000/'
test_txt = './mnist_image_label/mnist_test_jpg_10000.txt'
x_test_savepath = './mnist_image_label/mnist_x_test.npy'
y_test_savepath = './mnist_image_label/mnist_y_test.npy'def generateds(path, txt):f = open(txt, 'r')contents = f.readlines()f.close()x, y_ = [], [for content in contents:value = content.split(img_path = path + value[0]img = Image.open(img_path)img = np.array(img.convert('L'))img = img / 255.  x.append(img) y_.append(value[1]) print('loading : ' + content) x = np.array(x) y_ = np.array(y_) y_ = y_.astype(np.int64) return x, y_  # 返回输入特征x,返回标签y_# 如果数据集存在就直接使用,如果不存在就 先生成 再存储
if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(x_test_savepath) and os.path.exists(y_test_savepath):print('-------------Load Datasets-----------------')x_train_save = np.load(x_train_savepath)y_train = np.load(y_train_savepath)x_test_save = np.load(x_test_savepath)y_test = np.load(y_test_savepath)x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))
else:print('-------------Generate Datasets-----------------')x_train, y_train = generateds(train_path, train_txt)x_test, y_test = generateds(test_path, test_txt)print('-------------Save Datasets-----------------')x_train_save = np.reshape(x_train, (len(x_train), -1))x_test_save = np.reshape(x_test, (len(x_test), -1))np.save(x_train_savepath, x_train_save)np.save(y_train_savepath, y_train)np.save(x_test_savepath, x_test_save)np.save(y_test_savepath, y_test)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)model.summary()

2.数据增强,增大数据量

有些时候,我们的数据量太小,需要使用数据增强来丰富我们的数据集,数据增强就是把原有的数据做一些变化(如图片的旋转,镜像等)来扩充数据。

步骤

image_gen_train=tf.keras.preprocessing.image.ImageDataGenerator( 增 强 方法)

image_gen_train.fit(x_train)

常用增强方法

  • 缩放系数:rescale=所有数据将乘以提供的值
  • 随机旋转:rotation_range=随机旋转角度数范围
  • 宽度偏移:width_shift_range=随机宽度偏移量
  • 高度偏移:height_shift_range=随机高度偏移量
  • 水平翻转:horizontal_flip=是否水平随机翻转
  • 随机缩放:zoom_range=随机缩放的范围 [1-n,1+n]
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGeneratormnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 1. !!!!!!!!!!!!!!!!!!
# 给数据增加一个维度,从(60000, 28, 28)reshape为(60000, 28, 28, 1)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  image_gen_train=ImageDataGenerator(rescale=1./1.,rotation_range=45,width_shift_range=.15,height_shift_range=.15,horizontal_flip=False,zoom_range=0.5
)
image_gen_train.fit(x_train)model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])# 2.!!!!!!!!!!!!!!!!!
model.fit(image_gen_train.flow(x_train,y_train,batch_size=32),epochs=5,validation_data=(x_test,y_test),validation_freq=1)model.summary()

注意:

  1. 数据增强函数的输入要求是4维,需要通过reshape调整
  2. model.fit输入训练数据和batch时要是用“image_gen_train.flow(x_train,y_train,batch_size=)”

3.断点续训,存取模型

3.1读取模型

读取模型使用model.load_weights()

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)

3.2保存模型

借助tensorflow给出的回调函数,直接保存参数和网络

具体用法如下:

#上接“六步法”中的前四部checkpoint_save_path = "./checkpoint/mnist.ckpt"
# 如果模型已经被保存过 导入模型继续训练
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)# 保存模型
cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history=model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test),                                 validation_freq=1,callbacks=[cp_callback])#下面是“六步法”的最后一步

4.参数提取,写至文本

4.1提取可训练参数

model.trainable_variables返回模型中可训练的参数。

4.2设置 print 输出格式

如果数据量太大,print输出时默认会将中间的数据省略,用“省略号”表示。

np.set_printoptions(

precision=小数点后按四舍五入保留几位,

threshold=数组元素数量少于或等于门槛值,打印全部元素;否则打印门槛值+1 个元素,中间用省略号补充

)

把threshold设置为np.inf即可输出全部数据元素。

代码如下,需要关注的地方我用!标注:

import tensorflow as tf
import os
import numpy as np# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
np.set_printoptions(threshold=np.inf)mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])model.summary()# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
print(model.trainable_variables)
file=open('./weights.txt','w')
for v in model.trainable_variables:file.write(str(v.name)+'\n')file.write(str(v.shape)+'\n')file.write(str(v.numpy())+'\n')
file.close()

weights.txt内容:

5.acc/loss 可视化,查看效果

使用变量history接收model.fit返回值,之后history.history获取acc/loss值。

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])acc=history.history['sparse_categorical_accuracy'] # 训练数据正确率
val_acc=history.history['val_sparse_categorical_accuracy']  # 测试数据正确率
loss=history.history['loss']  # 训练数据损失值
val_loss=history.history['val_loss']  # 测试数据损失值
# 画图
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

6、应用程序,给图识物

这里会给神经网络输入一条特征数据(手写数字图),让学习后的神经网络来预测这个数字是多少。

eg.

6.1前向传播执行应用

使用model.predict(输入数据, batch_size=整数),返回前向传播计算结果。

因为我们这里是使用已经训练好的模型,所以需要:

  1. 构建一样的网络结构
  2. 加载保存的模型
  3. model.predict(x_predict)进行预测
from PIL import Image
import numpy as np
import tensorflow as tf# 模型保存地址
model_save_path = './checkpoint/mnist.ckpt'# 1. 构建一样的网络结构
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')])# 2.  加载保存的模型
model.load_weights(model_save_path)preNum = int(input("input the number of test pictures:"))for i in range(preNum):image_path = input("the path of test picture:")img = Image.open(image_path)# 预处理 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!# 变为28*28img = img.resize((28, 28), Image.ANTIALIAS)# 变为灰度图img_arr = np.array(img.convert('L'))# 变为黑底白字# 方法一#img_arr = 255-img_arr# 方法二for i in range(28):for j in range(28):if img_arr[i][j] < 200:img_arr[i][j] = 255else:img_arr[i][j] = 0# 归一化img_arr = img_arr / 255.0# 增加一个维度,变为1*28*28 为了和训练时使用的数据的维度一致x_predict = img_arr[tf.newaxis, ...]# 3.  预测result = model.predict(x_predict)# pred是预测认为是0-9各个数字的概率pred = tf.argmax(result, axis=1)print('\n')# pred是张量,需要用tf.printtf.print(pred)

这里需要注意,因为我们训练神经网络时使用的是28*28的灰度图,黑底白字,所以需要把我们这里需要预测的图片(上图)转换为“28*28的灰度图,黑底白字”格式。

其中,变为黑底白字可以使用简单的方法一:img_arr = 255 - img_arr

也可以使用复杂的方法二,方法二将所有像素转换为“纯黑0”或“纯白255”,这样做可以降低图片中的噪声带来的影响。

神经网络搭建六步法扩展相关推荐

  1. 【Tensorflow学习三】神经网络搭建八股“六步法”编写手写数字识别训练模型

    神经网络搭建八股"六步法"编写手写数字识别训练模型 Sequential用法 model.compile(optimizer=优化器,loss=损失函数,metrics=[&quo ...

  2. Excel数据分析学习笔记(一)数据分析六步法和重要分析模型

    第一章 数据分析类型: 描述性统计分析,概括表述事物关系 探索性数据分析,发现数据的新特征 验证性数据分析,对假设进行证伪或证实 数据分析六步法 1.确定目标(这个很重要,在实践当中是比技术还要重要的 ...

  3. 以在线教育销售CRM为例,谈谈业务大盘拆解优化的六步法

    前言 对于一名企业内的B端产品经理,核心工作之一,是利用技术赋能业务,帮助业务改进,实现企业的商业价值. 如何识别业务问题?如何发现改进机会点?如何分析业务大盘?如何从产品视角给出解决思路?如何基于数 ...

  4. dfmea文件_DFMEA和PFMEA的“六步法”其实很不一样

    新版FMEA手册发布在即,"六步法"将成为其最核心的方法论,DFMEA和PFMEA的分析过程都需要遵照该方法."六步法"简述如下图: 但具体在应用"六 ...

  5. Kotlin学习笔记 第二章 类与对象 第五节 可见性 第六节 扩展

    参考链接 Kotlin官方文档 https://kotlinlang.org/docs/home.html 中文网站 https://www.kotlincn.net/docs/reference/p ...

  6. 安全生产六步法是什么_安全生产六步法

    卓越绩效管理模式.管理架构及实施六步法共121页_财务管理_经管营销_专业资料.... 1.8 9.3 73.4 气象指数 演 (横向) 练 0.3 修 0.2 正值 2.6 7.7 8.5 84.5 ...

  7. 优思学院|六西格玛管理的经典六步法

    优思学院按:六西格玛管理的目标是消除变异,六西格管理的理论基础是标准差,用标准差来衡量变异的大小.所谓的六西格玛管理的经典六步法,就是六西格玛的六个基本的管理哲学. 一.以顾客为关注重心是六西格玛的灵 ...

  8. 基于神经网络的语音频带扩展

    本博客转载自顾宇的<基于神经网络的语音频带扩展方法研究>,大家可从知网获取. 摘要 语音频带扩展旨在从频带受限的窄带语音信号中恢复宽带语音信号.由于受到语音采集设备以及信道条件的限制,传输 ...

  9. stm32捕获占空比_【电机控制】六步法驱动BLDC电机,使用硬件COM事件,STM32+CUBEMX(HAL库)配置...

    现在我也只能说是电机入门,但是想要把电机作为终身事业,从有霍尔到无霍尔,从方波到正弦波,现在把其中的一些知识点分享出来,因为电机控制其实的资料比较难找的,前人栽树,后人乘凉,如果我有什么错误,在知乎上 ...

  10. 安全生产六步法是什么_海孜煤矿安全生产管理“六步法”实施办法.doc

    海孜煤矿安全生产管理"六步法"实施办法.doc 海孜煤矿安全生产管理"六步法"实施办法为强化基层科(区)安全自主管理,有效防范与控制安全风险,实现矿井安全生产, ...

最新文章

  1. 人工智能在音频链中找到自己的声音
  2. Jmeter(一)http接口添加header和cookie --转载
  3. ConcurrentHashMap源码解析(1)
  4. 使用SHA256证书进行微软数字签名代码签名
  5. 初步了解win32界面库DuiLib
  6. MySQL时间戳(毫秒/秒)与日期格式的相互转换
  7. 为使节构建控制平面的指南第3部分-特定于域的配置API
  8. easyui学习笔记一:主要结构
  9. xForm应用开发手册
  10. HDU2019 数列有序!【入门】
  11. spring事务源码执行过程分析
  12. G-sensor 介绍
  13. 电脑硬盘双击打不开,提示格式化怎么办?
  14. OKR与KPI有什么区别
  15. Windows网络共享或共享打印机无法访问连接的简单终极解决方法
  16. ArcGIS必会的几个工具的应用
  17. 全国代收货款平台-快递鸟、菜鸟
  18. Java加密技术(一)—— HMACSHA1 加密算法
  19. 免费下载学术论文的网站
  20. Struts——开源MVC框架

热门文章

  1. C#利用NOPI导出到Excel
  2. DigitalFilmTools Rays 2.1.2汉化版|丁达尔光束耶稣光滤镜插件
  3. SQL-实现excel向下填充的功能
  4. B样条曲线的一些基本性质
  5. 主流数据库以及适用场景思维导图
  6. 一个500人天的BI项目实施记录
  7. HICE第四天笔记 12月8日
  8. vue如何集成阿里云视频服务组件(aliplayer)视频功能是使用el-dialog 弹出aliplayer播放
  9. switch怎么一个账号绑定各种服务器,NS怎么一个账号两台机器使用_Nintendo Switch 新旧机器同使用教程_尼萌手游网...
  10. NB-LOT 常用AT指令集简介