模型构建


class Encoder(layers.Layer):def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):super(Encoder, self).__init__(name=name, **kwargs)'''w_init = tf.random_normal_initializer()self.w = tf.Variable(initial_value=w_init(shape=(input_dim, units), dtype="float32"),trainable=True)b_init = tf.zeros_initializer()self.b = tf.Variable(initial_value=b_init(shape=(units,), dtype="float32"), trainable=True)'''# 简洁写法self.w = self.add_weight(shape=(input_dim, units), initializer="random_normal", trainable=True)self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)# 可具有不可训练权重self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)# 可以延迟权重创建在得知输出形状后:https://www.tensorflow.org/guide/keras/custom_layers_and_modelsdef call(self, inputs):# ...class Decoder(layers.Layer):def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):super(Decoder, self).__init__(name=name, **kwargs)self.dense_proj = layers.Dense(intermediate_dim, activation="relu")self.dense_output = layers.Dense(original_dim, activation="sigmoid")def call(self, inputs):x = self.dense_proj(inputs)return self.dense_output(x)class VariationalAutoEncoder(keras.Model):def __init__(self,original_dim,intermediate_dim=64,latent_dim=32,name="autoencoder",**kwargs):super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)self.original_dim = original_dimself.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)def call(self, inputs):z_mean, z_log_var, z = self.encoder(inputs)reconstructed = self.decoder(z)# Add KL divergence regularization loss.kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)self.add_loss(kl_loss)return reconstructed

模型训练

# 数据集加载
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)# 模型初始化
model = VariationalAutoEncoder(784, 64, 32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()# 模型训练
for epoch in range(3):for x_batch_train in train_dataset:with tf.GradientTape() as tape:reconstructed = model(x_batch_train)loss = loss_fn(x_batch_train, reconstructed) # Compute reconstruction lossloss += sum(model.losses)  # Add KLD regularization lossgrads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))print("step %d: mean loss = %.4f" % (epoch, loss.numpy()))# 由于模型是 Model 子类化的结果,它具有内置的训练循环。因此,您也可以用以下方式训练它:
model.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
model.fit(x_train, x_train, epochs=2, batch_size=64)

模型保存和加载

# 模型保存
model.save('path/to/location')# 模型加载
model = keras.models.load_model('path/to/location')# 其他详细内容:https://www.tensorflow.org/guide/keras/save_and_serialize

案例二

# 自定义一个Layer
class Linear(keras.layers.Layer):def __init__(self, units=32, input_dim=32):super(Linear, self).__init__()# ...def call(self, inputs):# ...# 层递归组合
class MLPBlock(keras.layers.Model):def __init__(self):super(MLPBlock, self).__init__()self.linear_1 = Linear(64, 32)self.linear_2 = Linear(32, 16)self.linear_3 = Linear(16, 1)def call(self, inputs):x = self.linear_1(inputs)x = tf.nn.relu(x)x = self.linear_2(x)x = tf.nn.relu(x)return self.linear_3(x)# 自定义损失函数和评估方法 add_loss()/add_metric():https://www.tensorflow.org/guide/keras/custom_layers_and_modelsd_optimizer = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
model = MLPBlock()@tf.function
def train_step(x, y):with tf.GradientTape() as tape:predictions = model(x, training=True)loss_value = loss_fn(y, predictions)grads = tape.gradient(loss_value, model.trainable_weights)d_optimizer.apply_gradients(zip(grads, model.trainable_weights))@tf.function
def test_step(x, y):predictions = model(x, training=False)val_acc_metric.update_state(y, predictions)

Tensorflow2.0模型构建与训练相关推荐

  1. 【金融】【pytorch】使用深度学习预测期货收盘价涨跌——LSTM模型构建与训练

    [金融][pytorch]使用深度学习预测期货收盘价涨跌--LSTM模型构建与训练 LSTM 创建模型 模型训练 查看指标 LSTM 创建模型 指标函数参考<如何用keras/tf/pytorc ...

  2. TensorFlow2快速模型构建及tensorboard初体验

    学习目标: TensorFlow由谷歌开源的机器学习框架,其对常见机器学习算法的包装性好,"开箱即用",让开发者能够轻松地构建和部署各类机器学习模型,并可直接用于生产系统.Tens ...

  3. TensorFlow2.0 —— 模型保存与加载

    目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...

  4. 真正的秘笈!授人鱼不如授人渔,如何用pytorch编写一个完美又不失自由的数据准备、模型构建、训练、评估、测试流程?看完本文大呼过瘾!

    前言 之前的教程中,有同学要求将讲解的代码开源,以方便使用.本文将会用最精简的框架去介绍来自顶级公司的pytorch模型的整个框架及流程,并整理开源为通用的模型框架,供研究使用.(如果你还没有阅读过之 ...

  5. 看图说话实战教程 | 第三节 | 模型构建及训练

    欢迎来到<看图说话实战教程>系列第三节.在这一节中,我们正式进入看图说话深度模型的构建与训练. 文章目录 1. 加载数据 2. 构建模型 3. 拟合模型 4. 完整代码 5. 结束语 1. ...

  6. 3、tensorflow2.0 实现MTCNN、训练O_net网络,并进行测试图片

    训练O_net网络,并测试图片 上一篇,我们已经知道如何生成O_net训练集,这次我们开始训练Onet网络. 训练完成后,保存权重,我们随机抽取一张照片,测试一下效果. 代码: train_Onet. ...

  7. TensorFlow2.0:模型的保存与加载

    ** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...

  8. 生成式对抗网络GAN之实现手写字体的生成(基于keras Tensorflow2.0实现)详细分析训练过程和代码

  9. TensorFlow2.0(三)--Keras构建神经网络回归模型

    Keras构建神经网络回归模型 1. 前言 1. 导入相应的库 2. 数据导入与处理 2.1 加载数据集 2.2 划分数据集 2.3 数据归一化 3. 模型构建与训练 3.1 神经网络回归模型的构建 ...

最新文章

  1. 软件测试2019:第四次作业—— 性能测试(含JMeter实验)
  2. Sql字符串操作函数
  3. 你所接触的计算机网络,学计算机网络的进
  4. 腾讯应用宝采集数据分析
  5. Linux16.04和Windows 10双系统下,解决时间不一致问题
  6. HDU 1085 Holding Bin-Laden Captive!
  7. 电脑c盘怎么清理_电脑C盘内存不足?三分钟教你彻底清理C盘空间,瞬间多出10个G...
  8. gephi java教程_gephi生成图(java版)
  9. 计算机网络的abc类地址,abc类私有ip地址范围
  10. PS中套索工具的使用
  11. 隐蔽的大数据杀熟,“百亿补贴”会员竟比非会员贴得少?
  12. ubuntu 我喜欢的快捷键
  13. 利用2阶分数阶微分掩模的边缘检测(Matlab代码实现)
  14. Unity3D空战游戏模板 Air Warfare Pro
  15. MIUI系统语音识别引擎识别错误的解决方式
  16. 【前端】HTML标签基础复习笔记
  17. 敏捷-细说敏捷建模思想与实践(转)
  18. SQLyog通过excel导入数据
  19. 绯闻女孩传八卦也能作为区块链协议?10分钟告诉你为啥
  20. Go十大常见错误第7篇:不使用-race选项做并发竞争检测

热门文章

  1. boost中bind的使用
  2. 排序算法之快速排序(Java)
  3. 8X25Q充电部分软件梳理(CP侧)
  4. 2017年12月计算机一级c,2017年12月计算机二级C语言考试操作题冲刺卷(2)
  5. python os renames_Python3 os.renames() 方法
  6. 【Pytorch神经网络理论篇】 35 GaitSet模型:步态识别思路+水平金字塔池化+三元损失
  7. 64位处理器_电脑操作系统的32位和64位有什么区别
  8. LeetCode 1863. 找出所有子集的异或总和再求和(DFS)
  9. 程序员面试金典 - 面试题 04.05. 合法二叉搜索树(中序遍历)
  10. LeetCode 第 17 场双周赛(469/897,前52.3%)