一:不使用梯度裁剪

#网络搭建和模型训练
import tensorflow as tf
from tensorflow.keras import layers,optimizers,datasets,Sequential
#学习率
learning_rate=0.01
#利用Sequential容器封装三个网络层,前网络层的输出默认作为下一层的输入
#三个非线性层的嵌套模型
model=Sequential([layers.Dense(256,activation='relu'),#隐藏层1layers.Dense(128,activation='relu'),#隐藏层2layers.Dense(10,activation='relu')])#输出层,输出节点数为10
#加载MNIST数据集
(x,y),(x_test,y_test)=datasets.mnist.load_data()
x=2*tf.convert_to_tensor(x,dtype=tf.float32)/255.-1
y=tf.convert_to_tensor(y,dtype=tf.int32)
print('x:',x.shape,'y:',y.shape,'x_test:',x_test.shape,'y_test:',y_test.shape)
#x: (60000, 28, 28) y: (60000,) x_test: (10000, 28, 28) y_test: (10000,)
#tf.keras.optimizers.SGD(learning_rate=0.01) 声明了一个梯度下降 优化器 (Optimizer),其学习率为 0.01。
#优化器可以帮助我们根据计算出的求导结果更新模型参数,从而最小化某个特定的损失函数,具体使用方式是调用其 apply_gradients() 方法。
optimizer = tf.keras.optimizers.SGD(learning_rate)
#构建梯度记录环境
with tf.GradientTape() as tape:#打平操作[b,28,28]->[b,784]x=tf.reshape(x,[-1,28*28])#step1,得到模型输出[b,784]=>[b,10]out=model(x)y_onehot=tf.one_hot(y,depth=10)#计算差的平方和,[b,10]loss=tf.square(out-y_onehot)#计算每个样本的平均误差loss=tf.reduce_sum(loss)/x.shape[0]print("loss: ",loss)
#自动求导函数tape.gradient(loss,model.trainable_variables),计算参数的梯度[w1,b1,w2,b2,w3,b3]
grades=tape.gradient(loss,model.trainable_variables)
#w'=w-lr*grades
optimizer.apply_gradients(zip(grades,model.trainable_variables))
print('w:{}\n b:{}'.format(grades[0],grades[1]))

二:使用梯度裁剪:
通过梯度剪裁可以较大程度上抑制梯度爆炸现象。

import tensorflow as tf
import numpy as np
import tensorflow as keras
from tensorflow.keras import layers,Sequential,losses,optimizers,datasets
batchsz=128
learn_rate=0.01(x,y),(x_test,y_test)=datasets.mnist.load_data()
x=2*tf.convert_to_tensor(x,dtype=tf.float32)/255.-1
y=tf.convert_to_tensor(y,dtype=tf.int32)
model=Sequential([layers.Dense(256,activation='relu'),layers.Dense(128,activation='relu'),layers.Dense(10,activation='relu')
])
optimizer=tf.keras.optimizers.SGD(learn_rate)
criteon=losses.CategoricalCrossentropy(from_logits=True)#创建损失函数的类,在实际计算时直接调用类实例即可
#创建梯度记录器
with tf.GradientTape() as tape:x=tf.reshape(x,[-1,28*28])logits=model(x)#前向传播y_onehot=tf.one_hot(y,depth=10)loss=criteon(logits,y_onehot)#print("loss: ",loss)#计算梯度值
grades=tape.gradient(loss,model.trainable_variables)
#使用全局梯度裁剪
grades,_=tf.clip_by_global_norm(grades,25)
optimizer.apply_gradients(zip(grades,model.trainable_variables))
print('w:{}\n b:{}'.format(grades[0],grades[1]))


从图片中可以看出,使用了梯度裁剪和不使用梯度裁剪区别还是挺大的!注意,观察两个图片的数据w和b的值之间的差异,使用了梯度裁剪之后的w和b值是比不使用梯度裁剪的值要小的。(当然这个每次的运行结果得到的loss结果可能是不同的)

不使用梯度裁剪和使用梯度裁剪的对比(tensorflow)相关推荐

  1. 解决 “梯度爆炸” 的方法 - 梯度裁剪

    梯度裁剪 一.什么是梯度爆炸 二.梯度裁剪 三.如何选择超参数c 四.框架中的实现 梯度裁剪是解决梯度爆炸的一种简单高效的方法,并且梯度裁剪可以应用于所有神经网络的训练中(任何可能发生梯度爆炸的训练过 ...

  2. 随机梯度下降、批量梯度下降、小批量梯度下降分类是什么?有什么区别?batch_size的选择如何实施、有什么影响?

    随机梯度下降.批量梯度下降.小批量梯度下降分类是什么?有什么区别?batch_size的选择如何实施.有什么影响? 目录

  3. 图像水平梯度和竖直梯度代码_20行代码发一篇NeurIPS:梯度共享已经不安全了

    整理 | 夕颜,Jane 出品 | AI科技大本营(ID:rgznai100) [导读]12 月 8 日-14 日,NeurIPS 2019 在加拿大温哥华举行,和往常一样,今年大会吸引了数万名专家参 ...

  4. 大白话5分钟带你走进人工智能-第十一节梯度下降之手动实现梯度下降和随机梯度下降的代码(6)...

                                第十一节梯度下降之手动实现梯度下降和随机梯度下降的代码(6) 我们回忆一下,之前咱们讲什么了?梯度下降,那么梯度下降是一种什么算法呢?函数最优化 ...

  5. 【数据挖掘】神经网络 后向传播算法 ( 梯度下降过程 | 梯度方向说明 | 梯度下降原理 | 损失函数 | 损失函数求导 | 批量梯度下降法 | 随机梯度下降法 | 小批量梯度下降法 )

    文章目录 I . 梯度下降 Gradient Descent 简介 ( 梯度下降过程 | 梯度下降方向 ) II . 梯度下降 示例说明 ( 单个参数 ) III . 梯度下降 示例说明 ( 多个参数 ...

  6. 梯度下降和随机梯度下降为什么能下降?

     梯度下降和随机梯度下降为什么能下降? 标签: 深度学习梯度下降SGD 2016-02-22 19:19 663人阅读 评论(1) 收藏 举报 本文章已收录于: 分类: Deep Learning ...

  7. 梯度、梯度下降,随机梯度下降

    一.梯度gradient http://zh.wikipedia.org/wiki/%E6%A2%AF%E5%BA%A6 在标量场f中的一点处存在一个矢量G,该矢量方向为f在该点处变化率最大的方向,其 ...

  8. 机器学习算法(优化)之一:梯度下降算法、随机梯度下降(应用于线性回归、Logistic回归等等)...

    本文介绍了机器学习中基本的优化算法-梯度下降算法和随机梯度下降算法,以及实际应用到线性回归.Logistic回归.矩阵分解推荐算法等ML中. 梯度下降算法基本公式 常见的符号说明和损失函数 X :所有 ...

  9. 梯度下降法和随机梯度下降法的区别

    这几天在看<统计学习方法>这本书,发现 梯度下降法 在 感知机 等机器学习算法中有很重要的应用,所以就特别查了些资料.  一.介绍 梯度下降法(gradient descent)是求解无约 ...

最新文章

  1. JWT(JSON Web Token) Java与.Net简单编码实现
  2. Eclipse搭建Android5.0应用开发环境 “ndk-build”:launchingfailed问题解决
  3. python如何复制一个变量_Python中变量、赋值、浅拷贝、深拷贝
  4. java 左边补0_java 数字左补齐0
  5. JEECG v2与v3两个版本的区别说明
  6. 内推|百度2020春实习-计算机视觉算法研发工程师-北京
  7. 50道pmp历年真题
  8. Linux安装微信、QQ
  9. cxf调用webservice
  10. 测绘专硕要学计算机吗,测绘工程专硕专业介绍_测绘工程非全日制研究生(专业硕士)_125在职研究生...
  11. Figma#1: 图形绘制
  12. 海龟绘图小案例(内含源码)
  13. 抖音:资本、梦想与躁动荷尔蒙裹挟的世界
  14. python斗地主出牌算法_python模拟斗地主发牌
  15. 最新微信小程序反编译破解过程记录
  16. seo优化教程-免费SEO优化详细教程
  17. Spring Security系列教程11--Spring Security核心API讲解
  18. 5分钟理解Iass Pass SasS三种云服务区别
  19. 网络安全工程师的职业前景如何?
  20. CodeCombat进军中国市场,中美少儿编程教育有何差距

热门文章

  1. python之⾯向对象-继承
  2. MySQL数据库分组和聚合函数组合使用
  3. ida 调试中call stack如何打开|修改数值
  4. 防火墙产品原理与应用:NAT支持的特殊协议
  5. 理解卷积神经网络中的输入与输出形状 | 视觉入门
  6. 计算机视觉系统中图像究竟经历了哪些“折磨”
  7. 利用OpenCV建立视差图像
  8. 基于光流的3D速度检测
  9. 数数正方形(ACM/ICPC World Finals)
  10. 雇佣和留住开发人员,打造优秀的团队