各位同学好,今天和大家分享一下TensorFlow2.0深度学习中的交叉验证法正则化方法,最后展示一下自定义网络的小案例


1. 交叉验证

交叉验证主要防止模型过于复杂而引起的过拟合找到使模型泛化能力最优的参数。我们将数据划分为训练集、验证集、测试集。训练集用于输入网络模型作为样本进行学习。验证集是在迭代过程中对模型进行评估,寻找最优解。测试集是在整个网络训练完成后进行评估。

K折交叉验证,就是将训练集数据等比例划分成K份,以其中的1份作为验证数据其他的K-1份数据作为训练数据。每次迭代从都是从K个部分选取一份不同的数据部分作为测试数据,剩下的K-1个当作训练数据,最后把得到的K个实验结果进行平分。


划分方法 

(1)构造数据集时划分

首先导入训练集(x,y)和测试集(x_test, y_test),K折交叉验证是对测试集的划分,指定迭代500次,每次迭代都从训练集中选出一部分作为验证数据ds_val,剩下的作为训练数据ds_train。使用tf.random.shuffle() 随机打乱索引顺序,不影响x和y之间的对应关系。tf.gather()根据索引来选取值。

# 以手写数字为例,获取训练集和测试集
(x,y),(x_test,y_test) = datasets.mnist.load_data()# 预处理函数
def processing(x,y): # 从[0,255]=>[-1,1]x = 2 * tf.cast(x, dtype=tf.float32) / 255.0 - 1y = tf.cast(y, dtype=tf.int32)return(x,y)# 交叉验证K=500
for epoch in range(500):idx = tf.range(60000) # 假设training数据一共有60k张图象,生成索引idx = tf.random.shuffle(idx) # 随机打乱索引# 利用随机打散的索引来收集数据,不改变xy之间的关联x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])x_val, y_val = tf.ga,ther(x, idx[-10000:]), tf.gather(y, idx[-10000:])# 构建训练集ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))  # 自动将输入的xy转变成tenosr类型ds_train = ds_train.map(processing).shuffle(10000).batch(128) # 对数据集中的所有数据使用预处理函数# 构建验证集ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))  ds_val = ds_test.map(processing).batch(128) # 每次迭代取128组数据,验证不需要打乱数据

(2)使用训练函数fit()中的参数划分

如果嫌使用上面的方法构造数据集太麻烦的话,可以在模型训练函数fit()中指定划分方式validation_split=0.1,每次迭代取0.1倍的训练数据作为验证集,剩下的作为训练集。ds_train_val 要求是没有被划分过的训练集数据。这样的话就不需要再指定validation_data验证集数据了,在划分时自动生成。

# ds_train_val指没有划分过的train和val数据集,validation_split=0.1动态切割,0.1比例的数据分给val
network.fit(ds_train_val, epochs=6, validation_split=0.1, validation_freq=2)
# 不需要再指定validation_data,已经在被包含在validation_split中了

在模型迭代过程中使用验证集来查看什么时候模型效果最优,找到最优的就跳出循环。验证集在挑选模型参数的时候,先保存误差极小值对应的权重,如果后面检测到的误差都大于它,就使用当前这个权重。


2. 正则化

当采用比较复杂的模型,去拟合数据时,很容易出现过拟合现象,这会导致模型的泛化能力下降,对模型添加正则化项可以限制模型的复杂度,使得模型在复杂度和性能达到平衡。

原理: 【通俗易懂】机器学习中 L1 和 L2 正则化的直观解释

L1正则化是在原来的损失函数基础上加上权重参数的绝对值。L1可以产生0解,L1获得稀疏解。

L2正则化是在原来的损失函数基础上加上权重参数的平方和。L2可以产生趋近0的解,L2获得非零稠密解。


在构建网络层时指定正则化参数kernel_regularizer,使用二范数的方法keras.regularizers.l2,惩罚系数0.01。

# 使用二范数正则化,loss = loss + 0.001*regularizer,指定正则化的权重
model = keras.Sequential([keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), activation=tf.nn.relu),keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001), activation=tf.nn.relu),keras.layers.Dense(1, activation=tf.nn.sigmoid)])

3. 自定义网络

3.1 数据获取

首先导入我们需要的库文件,从系统中导入图片数据,划分测试集和训练集。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 输出框只输出有意义的信息#(1)数据获取
(x,y),(x_test,y_test) = datasets.cifar10.load_data() #获取图像分类数据
# 查看数据信息
print(f'x.shape: {x.shape}, y.shape: {y.shape}')  #查看训练集的维度信息
print(f'x_test.shape: {x_test.shape}, y_test.shape: {y_test.shape}')  #测试集未读信息
print(f'y[:5]: {y[:5]}')  #查看训练集目标的前5项
# 绘图展示
import matplotlib.pyplot as plt
for i in range(10): # 展示前10张图片plt.subplot(2,5,i+1)  # 2行5列第i+1个位置plt.imshow(x[i])plt.xticks([]) # 不显示x和y轴坐标刻度plt.yticks([])# 输入的图像形状
# x.shape: (50000, 32, 32, 3), y.shape: (50000, 1)
# x_test.shape: (10000, 32, 32, 3), y_test.shape: (10000, 1)

需要训练的图片如下,图片本身不清晰,这里只说一下基本的自定义网络的构造,最多只有80%准确率,模型优化到卷积神经网络章节再谈。


3.2 数据预处理

由于导入的目标值y的shape时二维[50k,1],需要将axis=1的轴压缩掉,变成一个一维的向量[50k],使用tf.squeeze()压缩指定轴,对目标值one-hot编码对应索引的值变为1,其他索引对应的值变为0,shape变为[b,10]。把特征值x的范围映射到[-1,1]之间。

#(2)数据预处理
# 定义预处理函数
def processing(x,y): # 由于目标数据是而二维的,把shape=1的轴删除,从向量变成标量y = tf.squeeze(y)  # 默认压缩所有维度为1的轴,shape为[50k]y = tf.one_hot(y, depth=10) # one-hot编码,分成10个类别,shape为[50k,10],对应下标所在的值为1# 每个像素值的范围在[-1,1]之间,从[0,255]=>[-1,1]x = 2 * tf.cast(x, dtype=tf.float32) / 255.0 - 1y = tf.cast(y, dtype=tf.int32)return(x,y)# 构建训练集数据集
ds_train = tf.data.Dataset.from_tensor_slices((x, y))  # 自动将输入的xy转变成tenosr类型
ds_train = ds_train.map(processing).batch(128).shuffle(10000)  # 对数据集中的所有数据使用预处理函数# 构建测试集数据集
ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
ds_test = ds_test.map(processing).batch(128) # 每次迭代取128组数据,测试不需要打乱数据# 构造迭代器,查看数据集是否正确
sample = next(iter(ds_train))  # 每次运行从训练数据集中取出一组xy
print('x_batch.shape', sample[0].shape, 'y_batch.shape', sample[1].shape)
# x_batch.shape (128, 32, 32, 3)   y_batch.shape (128, 10)

3.3 自定义网络

#(3)构造网络
class MyDense(layers.Layer): #必须继承layers.Layer层,放到sequential容器中# 代替layers.Dense层def __init__(self, input_dim, output_dim):super(MyDense, self).__init__()   # 调用母类初始化,必须# 自己发挥'w''b'指定名字没什么用,创建shape为[input_dim, output_dim的权重# 使用add_variable创建变量        self.kernel = self.add_variable('w',[input_dim, output_dim])self.bias = self.add_variable('b', [output_dim])# call方法,training来指示现在是训练还是测试         def call(self, inputs, training=None):x = inputs @ self.kernel + self.biasreturn x# 自定义网络层
class MyNetwork(keras.Model):  # 必须继承keras.Model大类,才能使用complie、fit等功能def __init__(self):super(MyNetwork, self).__init__()  # 调用父类Mymodel# 新建五个层次self.fc1 = MyDense(32*32*3, 256)  #input_dim=784,output_dim=256self.fc2 = MyDense(256, 128)self.fc3 = MyDense(128, 64)self.fc4 = MyDense(64, 32)        self.fc5 = MyDense(32, 10)def call(self, inputs, training=None):# 前向传播,可以接收四维的tensorx = tf.reshape(inputs, [-1,32*32*3]) # 改变输入特征的形状x = self.fc1(x) #第一层[b,32*32*3]==>[b,256]x = tf.nn.relu(x) #激活函数x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)  #logits层return x

3.4 网络配置

#(4)网络配置
network = MyNetwork()
network.compile(optimizer = optimizers.Adam(lr=0.001),  # 指定优化器loss = tf.losses.CategoricalCrossentropy(from_logits=True), #交叉熵损失metrics = ['accuracy'])  # 测试指标     #(5)网络训练,输入训练数据,循环5次,验证集为ds_test,每一次大循环做一次测试
network.fit(ds_train, epochs=5, validation_data=ds_test, validation_freq=1)# 循环5次后的结果为
Epoch 5/5
391/391 [==============================] - 3s 8ms/step - loss: 1.2197 - accuracy: 0.5707 - val_loss: 1.3929 - val_accuracy: 0.5182

优化方法到卷积神经网络再展示

【深度学习】(7) 交叉验证、正则化,自定义网络案例:图片分类,附python完整代码相关推荐

  1. 动手深度学习13——计算机视觉:数据增广、图片分类

    文章目录 一.数据增广 1.1 为何进行数据增广? 1.2 常见图片增广方式 1.2.1 翻转 1.2.2 切割(裁剪) 1.2.3 改变颜色 1.2.4 综合使用 1.3 使用图像增广进行训练 1. ...

  2. 【深度学习】(2) 数据加载,前向传播2,附python完整代码

    生成数据集: tf.data.Dataset.from_tensor_slices(tensor变量) 创建一个数据集,其元素是给定张量的切片 生成迭代器: next(iter()) next() 返 ...

  3. 【动手教你学故障诊断:Python实现Tensorflow+CNN深度学习的轴承故障诊断(西储大学数据集)(含完整代码)】

    项目名称 动手教你学故障诊断:Python实现基于Tensorflow+CNN深度学习的轴承故障诊断(西储大学数据集)(含完整代码) 项目介绍 该项目使用tensorflow和keras搭建深度学习C ...

  4. 【机器学习入门】(8) 线性回归算法:正则化、岭回归、实例应用(房价预测)附python完整代码和数据集

    各位同学好,今天我和大家分享一下python机器学习中线性回归算法的实例应用,并介绍正则化.岭回归方法.在上一篇文章中我介绍了线性回归算法的原理及推导过程:[机器学习](7) 线性回归算法:原理.公式 ...

  5. 深度学习:交叉验证(Cross Validation)

    首先,交叉验证的目的是为了让被评估的模型达到最优的泛化性能,找到使得模型泛化性能最优的超参值.在全部训练集上重新训练模型,并使用独立测试集对模型性能做出最终评价. 目前在一些论文里倒是没有特别强调这样 ...

  6. 【深度学习】(5) 简单网络,案例:服装图片分类,附python完整代码

    1. 数据获取 使用系统内部的服装数据集构建神经网络.首先导入需要的库文件,x和y中保存训练集的图像和目标.x_test和y_test中保存测试集需要的图像和目标.(x, y)及(x_test, y_ ...

  7. 【深度学习】(1) 前向传播,附python完整代码

    各位同学大家好,今天和大家分享一下TensorFlow2.0深度学习中前向传播的推导过程,使用系统自带的mnist数据集. 1. 数据获取 首先,我们导入需要用到的库文件和数据集.导入的x和y数据是数 ...

  8. 深度学习(二十)基于Overfeat的图片分类、定位、检测

    基于Overfeat的图片分类.定位.检测 原文地址:http://blog.csdn.net/hjimce/article/details/50187881 作者:hjimce 一.相关理论 本篇博 ...

  9. 【神经网络】(1) 简单网络,实例:气温预测,附python完整代码和数据集

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中的一个小案例.案例内容:现有348个气温样本数据,每个样本有8项特征值和1项目标值,进行回归预测,构建神经网络模型. 数据集免费:神经 ...

最新文章

  1. 涨点技巧!汇集13个Kaggle图像分类项目的性能提升指南
  2. memcached原理详述及配置
  3. .netcore多语言解决方案
  4. vbs删除非空文件夹
  5. Docker之Dockerfile详解
  6. mysql gtid基础_MySQL 基础知识梳理学习(四)----GTID
  7. 爬虫实战学习笔记_1 爬虫基础+HTTP原理
  8. prototype.js ajax.request,javascript – Prototype和Ajax.Request范围
  9. 使用程序修改域帐户直接领导时遇到的错误
  10. python的装饰器迭代器与生成器_详解python中的生成器、迭代器、闭包、装饰器
  11. python闯红灯检测斑马线检测红绿灯检测车速检测车流量统计车牌识别智慧交通系统
  12. 关于华为S27000交换机在局域网中的一些简单配置
  13. Laravel 下使用 FFmpeg 处理多媒体文件
  14. 如何免费使用office软件?
  15. 暴雪不管的国服 链游要插手
  16. 传奇开区发布广告和选择广告投放网站的那些事
  17. DNS域名解析服务正向解析和反向解析
  18. Kruskal算法:将森林合并成树
  19. [视频]K8飞刀 一键免杀 IE神洞网马教程
  20. An Industrial-Strength Audio Search Algorithm

热门文章

  1. Android 相对布局别自己快遗忘的属性layout_alignRight,layout_alignBottom,layout_alignTop,layout_alignLeft
  2. Error: module pages/utils/util is not defined
  3. 使用docker Hub
  4. UVA10212 【The Last Non-zero Digit.】
  5. time datetime
  6. [Luogu] 选学霸
  7. Maven的setting.xml配置文件详解(中文)
  8. RDIFramework.NET ━ .NET快速信息化系统开发框架 V3.2-新增模块管理界面导出功能(可按条件导出)...
  9. try-catch-finally对返回值的影响
  10. eclipse设置保护色非原创