文章目录

  • 1. 自定义模型
  • 2. 学习流程

学习于:简单粗暴 TensorFlow 2

1. 自定义模型

  • 重载 call() 方法,pytorch 是重载 forward() 方法
import tensorflow as tf
X = tf.constant([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
y = tf.constant([[10.0],[20.0]])class Linear(tf.keras.Model):def __init__(self):super().__init__()self.dense = tf.keras.layers.Dense(units=1,activation=None,kernel_initializer=tf.zeros_initializer(),bias_initializer=tf.zeros_initializer())def call(self, input): # 重载 call 方法output = self.dense(input)return outputmodel = Linear()# 优化器
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)for i in range(100):with tf.GradientTape() as tape: # 梯度记录器y_pred = model(X)loss = tf.reduce_mean(tf.square(y_pred-y)) # 损失grads = tape.gradient(loss, model.variables) # 求导# 更新参数optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))

2. 学习流程

  • 加载手写数字数据集
class MNistLoader():def __init__(self):data = tf.keras.datasets.mnist# 加载数据(self.train_data, self.train_label),(self.test_data, self.test_label) = data.load_data()# 扩展维度,灰度图1通道 [batch_size, 28, 28, chanels=1]self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)self.train_label = self.train_label.astype(np.int32)self.test_label = self.test_label.astype(np.int32)# 样本个数self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从训练集里随机取出 batch_size 个样本idx = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[idx, :], self.train_label[idx]
  • 定义模型
# 自定义多层感知机模型
class MLPmodel(tf.keras.Model):def __init__(self):super().__init__()# 除第一维以外的维度展平self.flatten = tf.keras.layers.Flatten()self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')self.dense2 = tf.keras.layers.Dense(units=10)def call(self, input):x = self.flatten(input)x = self.dense1(x)x = self.dense2(x)output = tf.nn.softmax(x)return output
  • 训练
# 参数
num_epochs = 5
batch_size = 50
learning_rate = 1e-4# 模型实例
mymodel = MLPmodel()
# 数据加载
data_loader = MNistLoader()
# adam 优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)num_batches = int(data_loader.num_train_data//batch_size * num_epochs)
# 训练
for idx in range(num_batches):# 取出数据X,y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape: # 梯度记录y_pred = mymodel(X) # 预测# 计算交叉熵损失loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch {}, loss {}".format(idx, loss.numpy()))# 计算梯度grads = tape.gradient(loss, mymodel.variables)# 更新参数optimizer.apply_gradients(grads_and_vars=zip(grads, mymodel.variables))
  • 预测
# 评估标准
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
# 预测
for idx in range(num_batches):# 数据区间start, end = idx*batch_size, (idx+1)*batch_size# 放入模型,预测y_pred = mymodel.predict(data_loader.test_data[start : end])# 统计更新 预测信息    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start:end],y_pred=y_pred)
print("test 准确率:{}".format(sparse_categorical_accuracy.result()))
# test 准确率:0.9455000162124634

TensorFlow 2.0 - 自定义模型、训练过程相关推荐

  1. 再战FGM!Tensorflow2.0 自定义模型训练实现NLP中的FGM对抗训练 代码实现

    TF版本2.2及以上 def creat_FGM(epsilon=1.0):@tf.function def train_step(self, data):'''计算在embedding上的gradi ...

  2. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  3. TF学习——TF之TFOD:基于TFOD AP训练ssd_mobilenet预模型+faster_rcnn_inception_resnet_v2_模型训练过程(TensorBoard监控)全记录

    TF学习--TF之TFOD:基于TFOD AP训练ssd_mobilenet预模型+faster_rcnn_inception_resnet_v2_模型训练过程(TensorBoard监控)全记录 目 ...

  4. 深度学习模型训练过程

    深度学习模型训练过程 一.数据准备 基本原则: 1)数据标注前的标签体系设定要合理 2)用于标注的数据集需要无偏.全面.尽可能均衡 3)标注过程要审核 整理数据集 1)将各个标签的数据放于不同的文件夹 ...

  5. 【深度学习】Tensorboard可视化模型训练过程和Colab使用

    [深度学习]Tensorboard可视化模型训练过程和Colab使用 文章目录 1 概述 2 手撸代码实现 3 Colab使用3.1 详细步骤3.2 Demo 4 总结 1 概述 在利用TensorF ...

  6. 【深度学习】模型训练过程可视化思路(可视化工具TensorBoard)

    [深度学习]模型训练过程可视化思路(可视化工具TensorBoard) 文章目录 1 TensorBoard的工作原理 2 TensorFlow中生成log文件 3 启动TensorBoard,读取l ...

  7. Keras深度学习实战(1)——神经网络基础与模型训练过程详解

    Keras深度学习实战(1)--神经网络基础与模型训练过程详解 0. 前言 1. 神经网络基础 1.1 简单神经网络的架构 1.2 神经网络的训练 1.3 神经网络的应用 2. 从零开始构建前向传播 ...

  8. 模型训练过程中产生NAN的原因分析

    模型训练过程中产生NAN的原因分析 在模型的训练过程中发现,有时在经过多轮训练后loss会突然变为nan.loss变为nan也就使权重更新后的网络里的参数变为了nan,这样就使整个训练无法再进行下去了 ...

  9. tensorflow 1.14 ssd_mobilenet_v1 模型训练

    tensorflow 1.14 ssd_mobilenet_v1 模型训练 1 工具版本 序号 软件名称 版本 安装命令 1 操作系统 ubuntu 18.04 2 python 3.6.9 3 te ...

最新文章

  1. 感受 lambda 之美!
  2. linux php不支持crypt,(PHP)如何在CRYPT_BLOWFISH中使用crypt()?
  3. 微信小程序正确的异步request请求,根据经纬度获取地理位置信息
  4. java去掉重复字符_Java实现去掉字符串重复字母的方法示例
  5. 一位准程序员对软件行业的8个问题
  6. javaScript技巧表:单提交验证类[转载]
  7. linux hdfs授予文件夹权限,修修改hdfs上的文件所属用户、所属组等读写执行控制权限...
  8. ai中如何插入签名_如何在PDF文件中插入手写签名?手把手教会你,轻松设置签名...
  9. T3.2是什么级别?
  10. windows自带桌面远程控制
  11. wordpress网站被黑后怎么解决
  12. 桌面支持--skype登陆不上
  13. scratch成语接龙 电子学会图形化编程scratch等级考试四级真题和答案解析2021-9
  14. 使用obi fluid进行洪水模拟,持续更新~
  15. 微信朋友圈十周年,你设置了三天可见吗?
  16. Java大端字节和小端字节
  17. vue全套教程(实操)
  18. 查看自己的GitHub地址
  19. Excel VBA工程密码破解程序 (绝对可以破解)
  20. linux系统编程:系统函数system

热门文章

  1. 机器学习算法之支持向量机 SVM
  2. tutte定理证明hall定理_深入浅出|中心极限定理(Central Limit Theorem)及证明
  3. 8数据提供什么掩膜产品_博硕能为你提供什么产品?
  4. iOS当中的设计模式
  5. 第七季1:MP4文件格式解析
  6. bootstrap-daterangepicker插件运用
  7. [BZOJ 4025] 二分图
  8. .NET面试题系列(七)IIS
  9. html总结:文本框填满表格
  10. postgresql数据库安装及简单操作