在使用机器学习训练一个模型的时候,训练出来的模型很可能会出现两个常见的问题,过拟合和欠拟合从而导致模型对未知数据的预测能力会有所下降。欠拟合可能是由于模型的设计过于的简单,提取出来的特征不足以来刻画问题的趋势,比如说,使用一个模型来预测房价,影响房价的因素有房子所处位置、房子大小、卧室的数量、是否靠近医院、是否靠近学校等等。如果,我们在训练模型的时候仅仅只是用了房子的大小、房子所处的位置两个特征来训练一个预测房价的模型,这就会导致欠拟合的发生,导致模型预测的结果不准确。过拟合指的是,模型对于训练集的拟合程度非常的好,甚至可以达到对于训练数据的损失值为0,而对于未知数据的预测无法做出可靠的判断,原因是因为过拟合使得模型过度拟合训练数据中的噪音而忽视了问题的整体规律。

一、正则化

正则化:是一种常用的为了避免过度拟合而采用的一种算法。正则化的主要思想是通过在损失函数中加入刻画模型复杂程度的指标,假设模型的损失函数为J(θ),那么我们在使用优化算法来优化损失函数的时候,不是直接优化J(θ),而是优化J(θ)+ λ *R(w)。其中R(w)是指模型的复杂程度,λ表示模型复杂损失在总损失中的比例。需要注意的是这里的θ表示的是一个神经网络中的所有参数,它包括权重和偏置。一般来说,模型复杂度只由其权重(w)来决定。而常用的来刻画模型复杂度的函数R(w)有两种,一种是L1正则化,计算公式如下,以下公式中w都是从w1到wn:

还有一种L2正则化,计算公式如下:

无论是L1正则化还是L2正则化,两者的目的都是通过限制权重的大小,来使得模型不能任意的拟合训练数据中的随机噪音。两者正则化也有很大的区别。L1正则化会让参数变得稀疏,而L2正则化则不会产生这个问题。参数稀疏是指会有更多的参数变为0,其实就相当于达到了类似于特征选取的功能。L2正则化不会产生稀疏的原因在于,当参数很小的时候,参数的平方基本上就趋近于0,可以忽略了,而模型不会进一步将这个参数调整为0。L1正则化的计算公式不可导,L2正则化可导。而在优化模型的时候需要计算损失函数的偏导数,所以对于这种情况L2比较合适。在实践中,也可以将L1正则化和L2正则化同时使用:

二、TensorFlow中的正则化

在简单的神经网络中,加入正则化来计算损失函数还是比较容易的。当,神经网络变得非常复杂(层数很多)的时候,那么在损失函数中加入正则化的就会变得非常的复杂,使得损失函数的定义变得很长,从而还会导致程序的可读性变差。而且还有可能,当神经网络变得复杂的时候,定义网络结构的部分和计算损失函数的部分不在同一个函数中,这样就会使得计算损失函数不方便。TensorFlow提供了集合的方式,通过在计算图中保存一组实体,来解决这一类问题。

import tensorflow as tf
from numpy.random import RandomState#获取一层神经网络的权重,并将权重的L2正则化损失加入到集合中
def get_weight(shape,lamda):#定义变量var = tf.Variable(tf.random_normal(shape=shape),dtype=tf.float32)#将变量的L2正则化损失添加到集合中tf.add_to_collection("losses",tf.contrib.layers.l2_regularizer(lamda)(var))return varif __name__=="__main__":#定义输入节点x = tf.placeholder(tf.float32,shape=(None,2))#定义输出节点y_ = tf.placeholder(tf.float32,shape=(None,1))#定义每次迭代数据的大小batch_size = 8#定义五层神经网络,并设置每一层神经网络的节点数目layer_dimension = [2,10,10,10,1]#获取神经网络的层数n_layers = len(layer_dimension)#定义神经网络第一层的输入cur_layer = x#当前层的节点个数in_dimension = layer_dimension[0]#通过循环来生成5层全连接的神经网络结构for i in range(1,n_layers):#定义神经网络上一层的输出,下一层的输入out_dimension = layer_dimension[i]#定义当前层中权重的变量,并将变量的L2损失添加到计算图的集合中weight = get_weight([in_dimension,out_dimension],0.001)#定义偏置项bias = tf.Variable(tf.constant(0.1,shape=[out_dimension]))#使用RELU激活函数cur_layer = tf.nn.relu(tf.matmul(cur_layer,weight) + bias)#定义下一层神经网络的输入节点数in_dimension = layer_dimension[i]#定义均方差的损失函数mse_loss = tf.reduce_mean(tf.square(y_ - cur_layer))#将均方差孙函数添加到集合tf.add_to_collection("losses",mse_loss)#获取整个模型的损失函数,tf.get_collection("losses")返回集合中定义的损失#将整个集合中的损失相加得到整个模型的损失函数loss = tf.add_n(tf.get_collection("losses"))

TensorFlow优化模型之正则化相关推荐

  1. 因果模型五:用因果的思想优化风控模型——因果正则化评分卡模型

    因果模型五:用因果的思想优化风控模型--因果正则化评分卡模型 一.模型中的因果和相关 二.不可知样本选择偏差 三.因果推断 四.因果与评分卡的融合 五.模型效果评估 5.1 人工合成数据效果测试 5. ...

  2. TensorFlow pb模型修改和优化

    TensorFlow 模型训练完成后,通常会通过frozen过程保存一个最终的pb模型.保存的pb模型是以GraphDef数据结构保存的,可以序列化保存为二进制pb模型或者文本pbtxt模型.Grap ...

  3. Tensorflow【实战Google深度学习框架】—使用 TensorFlow 实现模型

    文章目录 1.建立模型(Model) 2.使用 TensorFlow 实现模型 3.使用 TensorFlow 训练模型 1.建立模型(Model) 如下为我们进行某项实验获得的一些实验数据: 我们将 ...

  4. tensorflow机器学习模型的跨平台上线

    在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法 ...

  5. 谷歌I/O走进TensorFlow开源模型世界:从图像识别到语义理解

    谷歌I/O走进TensorFlow开源模型世界:从图像识别到语义理解 2017-05-23 16:13:11    TensorFlow    2 0 0 一年一度的谷歌开发者大会 Google I/ ...

  6. 用浏览器训练Tensorflow.js模型的18个技巧(上)

    摘要: 送你18个训练Tensorflow.js模型的小技巧! 在移植现有模型(除tensorflow.js)进行物体检测.人脸检测.人脸识别后,我发现一些模型不能以最佳性能发挥.而tensorflo ...

  7. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  8. tensorflow打印模型结构_社区分享 | 详解 TensorFlow 中 Placement 的最后一道防线 — Placer 算法...

    本文作者王思宇,阿里巴巴算法专家,从事深度学习算法平台建设,TensorFlow 分布式架构设计与大规模分布式性能优化工作,开源 TensorFlow 项目 contributor. 本文转自:互联网 ...

  9. [tensorflow] 线性回归模型实现

    在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...

  10. 【TVM帮助文档学习】使用TVMC编译和优化模型

    本文翻译自Compiling and Optimizing a Model with TVMC - tvm 0.9.dev0 documentation 在本节中,我们将使用TVM命令行驱动程序TVM ...

最新文章

  1. SAP PP 为工单确认时自动做收货的设置
  2. 某程序员面试支付宝P7,面试已通过,却因为背调没过!再进阿里失败!阿里背调,到底调啥?...
  3. 在eclipse中使用JDBC连接MySQL5.7.24
  4. 代码统计工具1.1版本技术文档
  5. 安卓 spinner下拉框 做模糊查询_如何用一张图来做全年/去年的部门离职率动态对比...
  6. rabbitmq简单收发服务搭建
  7. 带着灵魂去旅行的骑者-重新认识自我
  8. 猪行天下之Python基础——1.1 Python开发环境搭建
  9. GPU Gems1 - 24 高质量的过滤
  10. OpenCV:分离图像和视频的RGB通道
  11. java -D參数简化增加多个jar【简化设置classpath】
  12. 注意力机制可解释吗?这篇ACL 2019论文说……
  13. css position relative absolute fixed
  14. “一云多Region”究竟能为企业解决什么问题?
  15. 适用于低配机器,从USB摄像头拉H264流的Qt播放器
  16. 罗翔老师转谈记录,不同认知出发//心之所向,素履以往,生如逆旅,一苇以航。
  17. 【JavaScript算法】---希尔排序(转载自我的老师 Alley-巷子)
  18. 卓越电脑定时关机软件
  19. Ubuntu因为内存问题卡死解决方案
  20. 2345流氓软件让浏览器打开跳转到它的导航,并且自动下载安装2345浏览器

热门文章

  1. 显示日历信息的命令 cal 和 ncal
  2. 怎么计算crc16校验数据的校验码
  3. 专为存储设计的LRC编码
  4. 机械振动的傅里叶变化分析技术
  5. MATLAB机械动力分析,用MATLAB实现机械动力学
  6. 嵌入式系统应用开发—FPGA开发板—一位全加器仿真测试
  7. 星际译王,金山词霸,有道词典,词库下载 1
  8. oracle服务商前几名,oracle厂商服务有哪几种
  9. 计算机高程知识点,数字测图原理与方法知识点
  10. win10计算机百度云盘,windows10系统中怎样安装百度云盘?