深度学习3-tensorflow2.0模型训练-自定义模型训练
文章目录
- 1.神经网络的正向传播与反向传播
- 2.自动求导机制
- 3.案例1:模型自动求导
- 4.案例2:使用GradientTape自定义训练模型
- 5.案例3:使用GradientTape自定义训练模型(加入评估函数)
掌握神经网络正向传播与反向传播,tf.GradientTape求导机制以及自定义模型训练操作
1.神经网络的正向传播与反向传播
需要具体了解
2.自动求导机制
梯度求解利器:tf.GradientTape
GradientTape是eager模式下计算梯度用的GradientTape是eager模式下计算梯度用的watch(tensor)作用:确保某个tensor被tape追踪参数:tensor: 一个Tensor或者一个Tensor列表gradient(target, sources)作用:根据tape上面的上下文来计算某个或者某些tensor的梯度参数target: 被微分的Tensor或者Tensor列表,你可以理解为经过某个函数之后的值sources: Tensors 或者Variables列表(当然可以只有一个值). 你可以理解为函数的某个变量返回:一个列表表示各个变量的梯度值,和source中的变量列表一一对应,表明这个变量的梯度。上面的例子中的梯度计算部分可以更直观的理解这个函数的用法。
举例:计算y= x2在x=3时的导数**
import tensorflow as tfx = tf.constant(3.0)
with tf.GradientTape() as g:# watch作用:确保每个tensor被tape追踪g.watch(x)y = x*x
# 求导
# gradient作用是:根据tape来计算某个或某些tensor的梯度,即y导 = 2*x = 2*3 =6
dy_dx = g.gradient(y,x)
print(dy_dx)
# tf.Tensor(6.0, shape=(), dtype=float32)
loss_object = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()with tf.GradientTape() as tape:predictions = model(data) # 正向传播loss = loss_object(labels,predictions)gradients = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(gradients,model.trainable_variables))# 一般在网络中使用时,不需要显示调用watch函数,使用默认设置,GradientTape会监控可训练变量# apply_gradients(grads_and_vars,name=None)
# 作用:把计算出来的梯度更新到变量上面
# 参数含义:
# grads_and_vars:(gradient,variable)对的列表
# name:操作名
3.案例1:模型自动求导
4.案例2:使用GradientTape自定义训练模型
运行结果:
5.案例3:使用GradientTape自定义训练模型(加入评估函数)
让我们将metric添加到组合中。下面可以在从头开始编写的训练循环中随时使用内置指标
(或编写的自定义指标)。流程如下:- 在循环开始时初始化metrics- metric.update_state():每batch之后更新- metric.result():需要显示metrics的当前值时调用- metric.reset_states():需要清除metrics状态时重置(通常在每个epoch的结尾)
进行几个epoch运行训练循环
model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):print('Start of epoch %d'%(epoch,))# 遍历数据集的batch_sizefor step,(x_batch_train,y_batch_train) in enumerate(train_dataset):# 一个batchwith tf.GradientTape() as tape:logits = model(x_batch_train)loss_value = loss_fn(y_batch_train,logits)grads = tape.gradient(loss_value,model.trainable_weights)optimizer.apply_gradients(zip(grads,model.trainable_weights))# 更新训练集的metricstrain_acc_metric(y_batch_train,logits)# 每200 batches打印一次loss值if step % 200 == 0:print('Training loss (for onr batch) at step %s:%s'%(step,float(loss_value)))print('Seen so far :%s sample'%((step+1) * 64))# 在每一个epoch结束时显示metricstrain_acc = train_acc_metric.result()print('Training acc over epoch: %s'%(float(train_acc),))# 在每个epoch结束时重置训练指标train_acc_metric.reset_states()# 在每个epoch结束时运行一个验证集for x_batch_val,y_batch_val in val_dataset:val_logits = model(x_batch_val)# 更新验证集metricsval_acc_metric(y_batch_val,val_logits)val_acc = val_acc_metric.result()print('Validation acc : %s'%(float(val_acc),))# 重置val_acc_metric.reset_states()# 在每个epoch结束时运行一个测试集for x_batch_test,y_batch_test in test_dataset:test_logits = model(x_batch_test)# 更新测试集metricstest_acc_metric(y_batch_test,test_logits)test_acc = test_acc_metric.result()print('Test acc : %s'%(float(test_acc)))# 重置test_acc_metric.reset_states()
深度学习3-tensorflow2.0模型训练-自定义模型训练相关推荐
- 吴恩达深度学习之tensorflow2.0 课程
课链接 吴恩达深度学习之tensorflow2.0入门到实战 2019年最新课程 最佳配合吴恩达实战的教程 代码资料 自己取 链接:https://pan.baidu.com/s/1QrTV3KvKv ...
- 第3章(3.11~3.16节)模型细节/Kaggle实战【深度学习基础】--动手学深度学习【Tensorflow2.0版本】
项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...
- 第0章【序】--动手学深度学习【Tensorflow2.0版本】
项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 这个项目将<动手学深度学习> 原书中MXNet代码实现改为Tenso ...
- 第1章【深度学习简介】--动手学深度学习【Tensorflow2.0版本】
项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...
- 《自然语言处理实战入门》 深度学习组件TensorFlow2.0 ---- 文本数据建模流程
文章大纲 一,准备数据 二,定义模型 三,训练模型 四,评估模型 五,使用模型 六,保存模型 参考文献 文本处理的建模流程,使用清华发布的新闻分类数据集: 中文文本分类数据集THUCNews THUC ...
- 【深度学习与tensorflow2.0实战】(网易云课堂)13-GAN
本文目录 GAN原理 纳什均衡-D.G EM距离 GAN实战 **gan.py** dataset.py GAN原理 Having Fun ▪ https://reiinakano.github.io ...
- 【深度学习】深入浅出数字图像处理基础(模型训练的先修课)
[深度学习]深入浅出数字图像处理基础(模型训练的先修课) 文章目录 1 图像的表示 2 图像像素运算 3 采样与量化3.1 采样3.2 量化3.3 图像上采样与下采样 4 插值算法分类 5 什么是池化 ...
- 【深度学习】Keras加载权重更新模型训练的教程(MobileNet)
[深度学习]Keras加载权重更新模型训练的教程(MobileNet) 文章目录 1 重新训练 2 keras常用模块的简单介绍 3 使用预训练模型提取特征(口罩检测) 4 总结 1 重新训练 重新建 ...
- Opencv4.x深度学习之Tensorflow2.3框架训练模型
Opencv4.x深度学习之Tensorflow2.3框架训练模型 第一部分:开发环境 1.Win10 x64 2.Opencv-Python 3.Tensorflow 2.3.0 CPU 第二部分: ...
- win10安装yolox,训练自定义模型,使用tensorrt部署全流程
系统环境:win10.cuda10.2.cudnn8.2 一.采集数据 有2段视频,先使用ffmpeg对视频进行抽帧,由于视频比较长,所以每隔5秒抽取1张图片. ffmpeg -i light000. ...
最新文章
- 科学:揭示自由意志的生物学本质
- js经典试题之ES6
- day11 - 15(装饰器、生成器、迭代器、内置函数、推导式)
- 图片→矩阵→空间→坍缩-→质点--用神经网络将空间坍缩成粒子的实验数据汇总
- ChartDirector资料小结
- spring编程式事务控制
- iOS中安全结束 子线程 的方法
- ASP.NET MVC框架(第一部分)
- 七秘诀工作效率与薪水翻番
- 页面内部DIV让点击外部DIV 事件不发生(阻止冒泡事件)
- BUU BRUTE 1
- IAP固件升级原理及实现详解
- GitHub 的 Pull Request 是指什么意思?
- PayPay migrated the core payment database from Aurora to TiDB
- 外贸软件如何提升进出口公司业绩 实现降本增效
- 【智能驾驶】最全、最强的无人驾驶技术学习路线
- SpringBoot+Vue项目线上教学平台
- 既生Java,何生Groovy?
- 华为DHCP Snooping原理及其实验配置
- matlab二重指针,VC++中函数返回数组指针或者带指针的结构体的编译方式是否可取? - 程序语言 - 小木虫 - 学术 科研 互动社区...