1.前言

Overfitting 也被称为过度学习,过度拟合。我们总是希望在机器学习训练时,机器学习模型能在新样本上很好的表现。过拟合时,通常是因为模型过于复杂,学习器把训练样本学得“太好了”,很可能把一些训练样本自身的特性当成了所有潜在样本的共性了,这样一来模型的泛化性能就下降了。我们形象的打个比方吧,你考试复习,复习题都搞懂了,但是一到考试就不会了,那是过拟合。

2.对比drop前后的loss

2.1.导入必要模块

import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer  #处理标签为二进制

2.2.加载数据

digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)   #转化标签为二进制形式
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)

2.3.定义添加层函数

def add_layer(inputs, in_size, out_size, layer_name, activation_function=None, ):Weights = tf.Variable(tf.random_normal([in_size, out_size]))   #系数biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, ) #偏置Wx_plus_b = tf.matmul(inputs, Weights) + biases# here to dropoutWx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)if activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b, )tf.summary.histogram(layer_name + '/outputs', outputs)return outputs

2.4.损失函数与优化器

keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32, [None, 64])  # 8x8
ys = tf.placeholder(tf.float32, [None, 10])

这里的keep_prob是保留概率,即我们要保留的结果所占比例,它作为一个placeholder,在run时传入, 当keep_prob=1的时候,相当于100%保留,也就是dropout没有起作用。

添加隐含层和输出层:

l1 = add_layer(xs, 64, 50, 'l1', activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function=tf.nn.softmax)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))  # 交叉熵函数损失函数
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #优化函数sess = tf.Session()
merged = tf.summary.merge_all()
# summary writer goes in here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph)

2.5.训练

if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:init = tf.initialize_all_variables()
else:init = tf.global_variables_initializer()
sess.run(init)
for i in range(500):# here to determine the keeping probabilitysess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})if i % 50 == 0:# record losstrain_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_prob: 1})train_writer.add_summary(train_result, i)test_writer.add_summary(test_result, i)

Tensorflow——Dropout(解决过拟合问题)相关推荐

  1. Dropout解决过拟合问题

    Dropout解决过拟合问题 晓雷 6 个月前 这篇也属于 <神经网络与深度学习总结系列>,最近看论文对Dropout这个知识点有点疑惑,就先总结以下.(没有一些基础可能看不懂,以后还会继 ...

  2. Dropout抑制过拟合

    dropout 可以看出,网络中的的一层中的某些神经元被丢弃,网络变得简单了一些. Dropout解决过拟合的原因 (1)取平均的作用 (2)减少神经元之间复杂的共适应关系: 因为dropout程序导 ...

  3. 深度学习-Tensorflow2.2-深度学习基础和tf.keras{1}-优化函数,学习速率,反向传播,网络优化与超参数选择,Dropout 抑制过拟合概述-07

    多层感知器: 优化使用梯度下降算法 学习速率 学习速率选取原则 反向传播 SGD RMSprop Adam learning_rate=0.01 # -*- coding: utf-8 -*- # - ...

  4. TensorFlow Dropout

    TensorFlow Dropout 图 1:来自论文 "Dropout: A Simple Way to Prevent Neural Networks from Overfitting& ...

  5. TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线

    TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线 目录 输出结果 设计代码 输出结果 设计代码 ...

  6. 过拟合解决方法python_《python深度学习》笔记---4.4、过拟合与欠拟合(解决过拟合常见方法)...

    <python深度学习>笔记---4.4.过拟合与欠拟合(解决过拟合常见方法) 一.总结 一句话总结: 减小网络大小 添加权重正则化 添加 dropout 正则化 1.机器学习的根本问题? ...

  7. 【机器学习】L1正则化与L2正则化详解及解决过拟合的方法

    在详细介绍L1与L2之前,先讲讲正则化的应用场景. 正则化方法:防止过拟合,提高泛化能力 所谓过拟合(over-fitting)其实就是所建的机器学习模型或者是深度学习模型在训练样本中表现得过于优越, ...

  8. 使用学习曲线(Learning curve),判断机器学习模型过拟合、欠拟合,与解决过拟合、欠拟合的问题

    文章目录 1.基本概念 过拟合与欠拟合 根据学习曲线判断过拟合.欠拟合 2.示例代码:绘制学习曲线 3.解决过拟合.欠拟合 解决过拟合 解决欠拟合 4. 过拟合.欠拟合的深层理解 1.基本概念 过拟合 ...

  9. Python 机器学习——解决过拟合的方法

    四种常用的解决过拟合(tackle overfitting)的方法,以思维导图的方式展示如下: 对神经网络而言,"choose a simpler model with fewer para ...

最新文章

  1. 在vue项目中引入高德地图及其UI组件的方法
  2. 将python编程为c_使用Cython为Python编写更快的C扩展
  3. tool 之gvim 64位安装流程
  4. 处理字符串_4_计算某个字符出现的次数
  5. 机器学习实战-神经网络-21
  6. 8数据提供什么掩膜产品_工业轨式1-8路RS485数据(MODBUS RTU协议)厂家产品说明...
  7. Redis Java调用
  8. 怎么在win7链接无线网络连接服务器,Win7系统网络连接一直显示正在获取网络地址但是连不上网解决方法...
  9. java形状函数_java基础:10.4 Java FX之形状
  10. 一步步学习SPD2010--附录A--SPD工作流条件和操作(4)--列表操作
  11. 国内优秀论坛之大汇集
  12. Python 量化投资实战教程(5) — A股回测KDJ 策略
  13. H264解码之PES流解析
  14. 考华为认证需要准备什么
  15. 01改变世界:没有计算器的日子怎么过——手动时期的计算工具
  16. python爬虫qq音乐歌词_10、 在QQ音乐中爬取某首歌曲的歌词
  17. WiFi模块种类二:单WiFi功能双频WiFi模块
  18. 关于虚拟机非正常关机的解决方案
  19. 学习java随堂练习-20220614
  20. 《C++ Concurrency in Action》笔记28 无锁并行数据结构

热门文章

  1. 无法将mysql服务器连接到_无法从java连接到mysql服务器
  2. 解决NSTextContainer分页时文本截断问题
  3. 常用cmd命令(持续更新)
  4. 2018福建省考c语言成绩查询,福建省公务员考试录用网成绩查询系统:2019福建省考分数查询入口...
  5. SFS2X 例子(java 扩展加as 客户端)
  6. Android程序对不同手机屏幕分辨率自适应的总结
  7. ConcurrentHashMap与HashTable的区别
  8. ORM框架之Spring Data JPA(二)spring data jpa方式的基础增删改查
  9. java case用法_Go语言 | goroutine不只有基础的用法,还有这些你不知道的操作
  10. Qt5 for linux离线安装工具下载地址