课程来源:人工智能实践:Tensorflow笔记2

文章目录

  • 前言
  • 断点续训主要步骤
  • 参数提取主要步骤
  • 总结

前言

本讲目标:断点续训,存取最优模型;保存可训练参数至文本


断点续训主要步骤

读取模型:

先定义出存放模型的路径和文件名,命名为.ckpt文件。
生成ckpt文件的时候会同步生成索引表,所以通过判断是否存在索引表来知晓是不是已经保存过模型参数。
如果有了索引表就利用load_weights函数读取已经保存的模型参数。

code:


checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)

保存模型:

保存模型参数可以使用TensorFlow给出的回调函数,直接保存训练出来的模型参数
tf.keras.callbacks.ModelCheckpoint( filepath=路径文件名(文件存储路径),
save_weights_only=True/False,(是否只保留参数模型)
save_best_only=True/False(是否只保留最优结果)) 执行训练过程中时,加入callbacks选项:
history=model.fit(callbacks=[cp_callback])

code:

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])

第一次运行:

第二次运行:可以发现模型并不是从初始训练,而是在基于保存的模型开始训练的(这一点可以从准确率和损失看出):

全部代码:

import tensorflow as tf
import osfashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()

参数提取主要步骤

设置打印的格式,使所有参数都打印出来

np.set_printoptions(threshold=np.inf)
print(model.trainable_variables)

将所有可训练参数存入文本:

file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

完整代码:

import tensorflow as tf
import os
import numpy as npnp.set_printoptions(threshold=np.inf)fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

效果:

总结

课程链接:MOOC人工智能实践:TensorFlow笔记2

【神经网络扩展】:断点续训和参数提取相关推荐

  1. 断点续训 Pytorch 和 Tensorflow 框架 VGG16 模型 猫狗大战 鸢尾花分类

    神经网络训练模型的过程中,如果程序突然中断,竹篮打水一场空? >>>断点续训来解决! 目录 (1)Pytorch框架的断点续训(猫狗大战) (2)Tensorflow框架的断点续训( ...

  2. pytorch学习(一)pytorch中的断点续训

    1. 设置断点续训的目的 在遇到停电宕机,设备内存不足导致实验还没有跑完的情况下,如果没有使用断点续训,就需要从头开始训练,耗时费力. 断点续训主要保存的是网络模型的参数以及优化器optimizer的 ...

  3. tensorflow1运用模型断点续训、恢复图和进行预测

    前言 本文是代码根据吴恩达深度学习第四课程第一周第二节作业图像分类识别修改而成,会简单介绍一下项目流程,然后介绍tensorflow1保存模型的两种方法,以及如何用模型预测. 项目流程简单介绍 这里直 ...

  4. kera TensorBoard的可视化和断点续训同时处理

    一.实现可视化的步骤 ① 从keras.callbacks中导入Tensorboard类 from keras.callbacks import TensorBoard ② 在model.fit中添加 ...

  5. Python机器学习-搭建神经网络以及数据集引入和断点续存

    前言 本文旨在通过Python编程角度进行机器学习神经网络的引导,需要掌握基础的全连接神经网络基础,这包括了神经网络全连接层的结构,权重模板与偏置的作用,节点的处理方法.在掌握这些知识之后,本文将从代 ...

  6. Scrapy_redis框架原理分析并实现断点续爬以及分布式爬虫

    1. 下载github的demo代码 1.1 clone github scrapy-redis源码文件 git clone https://github.com/rolando/scrapy-red ...

  7. Android视频编辑器(一)通过OpenGL预览、录制视频以及断点续录等

    前言 如今的视频类app可谓是如日中天,火的不行.比如美拍.快手.VUE.火山小视频.抖音小视频等等.而这类视频的最基础和核心的功能就是视频录制和视频编辑功能.包括了手机视频录制.美白.加滤镜.加水印 ...

  8. linux 参数扩展,Shell Bash 中的参数扩展

    对于访问 $9 之后的位置参数也同样需要使用大括号,比如: echo "Argument 1 is $1" echo "Argument 10 is ${10}" ...

  9. 网络编程学习(11)/ FTP项目(5) ——文件上传和上传断点续存功能

    网络编程学习(11)/ FTP项目(5) --文件上传和上传断点续存功能 `服务端 lib 文件夹下的 main.py 状态码的变化` 文件上传功能 `服务端 lib 文件夹下的 main.py` ` ...

最新文章

  1. linux系统中指定端口连接数限制
  2. 开源网络备份软件bacula学习笔记
  3. html表单验证js代码,JavaScript表单验证实现代码
  4. 深入理解ElasticSearch(七):执行分布式检索
  5. 华三实现vlan通过
  6. Android7.0 emui主题,全新EMUI5.0基于Android7.0 天生快,一生快!
  7. 面试官:备战年终,这些面试考点,请你牢牢记住
  8. Ubuntu 星际译王StarDict
  9. mysql快速部署主从复制
  10. 关于iostream.h与iostream的区别
  11. Android PackageManagerService(三)pm命令安装流程详解
  12. 图解通信原理与案例分析-16:2G GSM基站的工作原理--时分多址与无线资源管理RRM
  13. java vo 什么意思_在Java中VO , PO , BO , QO, DAO ,POJO是什么意思
  14. windows文字转语音示例
  15. 【清华大学】深入理解操作系统(陈渝) 第一章
  16. logiscope系列-使用说明书
  17. 微信小程序:云开发开通
  18. 软件测试职业发展三步曲之一
  19. 计算机主机重启键,重启(计算机术语)_百度百科
  20. Locust简单使用

热门文章

  1. Flutter快速构建集美观与⾼性能于⼀体的APP
  2. c语言100以内奇数的和为多少,编写C#程序,计算100以内所有奇数的和。谢谢了,大神帮忙啊...
  3. ant vue 兼容性问题_ant design for vue 关于table的一些问题
  4. 参考文献中会议名称怎么缩写_期刊缩写查询总结
  5. EasyUI加zTree使用解析 easyui修改操作的表单回显方法 验证框提交表单前验证 datagrid的load方法
  6. 关于使用JQ scrollTop方法进行滚动定位
  7. 深入理解JavaScript之Event Loop
  8. display转块状化
  9. webpack 引入jquery和第三方jquery插件
  10. 移动优先的响应式布局