文章目录

  • 1.神经网络的正向传播与反向传播
  • 2.自动求导机制
  • 3.案例1:模型自动求导
  • 4.案例2:使用GradientTape自定义训练模型
  • 5.案例3:使用GradientTape自定义训练模型(加入评估函数)

掌握神经网络正向传播与反向传播,tf.GradientTape求导机制以及自定义模型训练操作

1.神经网络的正向传播与反向传播

需要具体了解

2.自动求导机制

梯度求解利器:tf.GradientTape
GradientTape是eager模式下计算梯度用的GradientTape是eager模式下计算梯度用的watch(tensor)作用:确保某个tensor被tape追踪参数:tensor: 一个Tensor或者一个Tensor列表gradient(target, sources)作用:根据tape上面的上下文来计算某个或者某些tensor的梯度参数target: 被微分的Tensor或者Tensor列表,你可以理解为经过某个函数之后的值sources: Tensors 或者Variables列表(当然可以只有一个值). 你可以理解为函数的某个变量返回:一个列表表示各个变量的梯度值,和source中的变量列表一一对应,表明这个变量的梯度。上面的例子中的梯度计算部分可以更直观的理解这个函数的用法。

举例:计算y= x2在x=3时的导数**

import tensorflow as tfx = tf.constant(3.0)
with tf.GradientTape() as g:# watch作用:确保每个tensor被tape追踪g.watch(x)y = x*x
# 求导
# gradient作用是:根据tape来计算某个或某些tensor的梯度,即y导 = 2*x = 2*3 =6
dy_dx = g.gradient(y,x)
print(dy_dx)
#  tf.Tensor(6.0, shape=(), dtype=float32)
loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()with tf.GradientTape() as tape:predictions = model(data)      # 正向传播loss = loss_object(labels,predictions)gradients = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradients,model.trainable_variables))# 一般在网络中使用时,不需要显示调用watch函数,使用默认设置,GradientTape会监控可训练变量# apply_gradients(grads_and_vars,name=None)
# 作用:把计算出来的梯度更新到变量上面
# 参数含义:
# grads_and_vars:(gradient,variable)对的列表
# name:操作名

3.案例1:模型自动求导



4.案例2:使用GradientTape自定义训练模型



运行结果:

5.案例3:使用GradientTape自定义训练模型(加入评估函数)

让我们将metric添加到组合中。下面可以在从头开始编写的训练循环中随时使用内置指标
(或编写的自定义指标)。流程如下:- 在循环开始时初始化metrics- metric.update_state():每batch之后更新- metric.result():需要显示metrics的当前值时调用- metric.reset_states():需要清除metrics状态时重置(通常在每个epoch的结尾)



进行几个epoch运行训练循环

model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):print('Start of epoch %d'%(epoch,))# 遍历数据集的batch_sizefor step,(x_batch_train,y_batch_train) in enumerate(train_dataset):# 一个batchwith tf.GradientTape() as tape:logits = model(x_batch_train)loss_value = loss_fn(y_batch_train,logits)grads = tape.gradient(loss_value,model.trainable_weights)optimizer.apply_gradients(zip(grads,model.trainable_weights))# 更新训练集的metricstrain_acc_metric(y_batch_train,logits)# 每200 batches打印一次loss值if step % 200 == 0:print('Training loss (for onr batch) at step %s:%s'%(step,float(loss_value)))print('Seen so far :%s sample'%((step+1) * 64))# 在每一个epoch结束时显示metricstrain_acc = train_acc_metric.result()print('Training acc over epoch: %s'%(float(train_acc),))# 在每个epoch结束时重置训练指标train_acc_metric.reset_states()# 在每个epoch结束时运行一个验证集for x_batch_val,y_batch_val in val_dataset:val_logits = model(x_batch_val)# 更新验证集metricsval_acc_metric(y_batch_val,val_logits)val_acc = val_acc_metric.result()print('Validation acc : %s'%(float(val_acc),))# 重置val_acc_metric.reset_states()# 在每个epoch结束时运行一个测试集for x_batch_test,y_batch_test in test_dataset:test_logits = model(x_batch_test)# 更新测试集metricstest_acc_metric(y_batch_test,test_logits)test_acc = test_acc_metric.result()print('Test acc : %s'%(float(test_acc)))# 重置test_acc_metric.reset_states()


深度学习3-tensorflow2.0模型训练-自定义模型训练相关推荐

  1. 吴恩达深度学习之tensorflow2.0 课程

    课链接 吴恩达深度学习之tensorflow2.0入门到实战 2019年最新课程 最佳配合吴恩达实战的教程 代码资料 自己取 链接:https://pan.baidu.com/s/1QrTV3KvKv ...

  2. 第3章(3.11~3.16节)模型细节/Kaggle实战【深度学习基础】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...

  3. 第0章【序】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 这个项目将<动手学深度学习> 原书中MXNet代码实现改为Tenso ...

  4. 第1章【深度学习简介】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...

  5. 《自然语言处理实战入门》 深度学习组件TensorFlow2.0 ---- 文本数据建模流程

    文章大纲 一,准备数据 二,定义模型 三,训练模型 四,评估模型 五,使用模型 六,保存模型 参考文献 文本处理的建模流程,使用清华发布的新闻分类数据集: 中文文本分类数据集THUCNews THUC ...

  6. 【深度学习与tensorflow2.0实战】(网易云课堂)13-GAN

    本文目录 GAN原理 纳什均衡-D.G EM距离 GAN实战 **gan.py** dataset.py GAN原理 Having Fun ▪ https://reiinakano.github.io ...

  7. 【深度学习】深入浅出数字图像处理基础(模型训练的先修课)

    [深度学习]深入浅出数字图像处理基础(模型训练的先修课) 文章目录 1 图像的表示 2 图像像素运算 3 采样与量化3.1 采样3.2 量化3.3 图像上采样与下采样 4 插值算法分类 5 什么是池化 ...

  8. 【深度学习】Keras加载权重更新模型训练的教程(MobileNet)

    [深度学习]Keras加载权重更新模型训练的教程(MobileNet) 文章目录 1 重新训练 2 keras常用模块的简单介绍 3 使用预训练模型提取特征(口罩检测) 4 总结 1 重新训练 重新建 ...

  9. Opencv4.x深度学习之Tensorflow2.3框架训练模型

    Opencv4.x深度学习之Tensorflow2.3框架训练模型 第一部分:开发环境 1.Win10 x64 2.Opencv-Python 3.Tensorflow 2.3.0 CPU 第二部分: ...

  10. win10安装yolox,训练自定义模型,使用tensorrt部署全流程

    系统环境:win10.cuda10.2.cudnn8.2 一.采集数据 有2段视频,先使用ffmpeg对视频进行抽帧,由于视频比较长,所以每隔5秒抽取1张图片. ffmpeg -i light000. ...

最新文章

  1. 科学:揭示自由意志的生物学本质
  2. js经典试题之ES6
  3. day11 - 15(装饰器、生成器、迭代器、内置函数、推导式)
  4. 图片→矩阵→空间→坍缩-→质点--用神经网络将空间坍缩成粒子的实验数据汇总
  5. ChartDirector资料小结
  6. spring编程式事务控制
  7. iOS中安全结束 子线程 的方法
  8. ASP.NET MVC框架(第一部分)
  9. 七秘诀工作效率与薪水翻番
  10. 页面内部DIV让点击外部DIV 事件不发生(阻止冒泡事件)
  11. BUU BRUTE 1
  12. IAP固件升级原理及实现详解
  13. GitHub 的 Pull Request 是指什么意思?
  14. PayPay migrated the core payment database from Aurora to TiDB
  15. 外贸软件如何提升进出口公司业绩 实现降本增效
  16. 【智能驾驶】最全、最强的无人驾驶技术学习路线
  17. SpringBoot+Vue项目线上教学平台
  18. 既生Java,何生Groovy?
  19. 华为DHCP Snooping原理及其实验配置
  20. matlab二重指针,VC++中函数返回数组指针或者带指针的结构体的编译方式是否可取? - 程序语言 - 小木虫 - 学术 科研 互动社区...

热门文章

  1. [转]JAVA与.NET DES加密解密
  2. PPC丢失后,手机信息如何保护?(C#)
  3. 库表操作 - 存储引擎
  4. Java 数组类型转字符串类型
  5. Java基础---数组
  6. python-time、datetimme模块
  7. 最常用的css垂直居中方法
  8. 如何养成周回顾习惯的回复
  9. HDU-1069 Monkey and Banana 动态规划
  10. 机器学习概念笔记(1)——混淆矩阵、Precision、Recall、F-score