摘要:这篇文章将讲解TensorFlow如何保存变量和神经网络参数,通过Saver保存神经网络,再通过Restore调用训练好的神经网络。

本文分享自华为云社区《[Python人工智能] 十一.Tensorflow如何保存神经网络参数 丨【百变AI秀】》,作者: eastmount。

一.保存变量

通过tf.Variable()定义权重和偏置变量,然后调用tf.train.Saver()存储变量,将数据保存至本地“my_net/save_net.ckpt”文件中。

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  2 20:04:57 2020
@author: xiuzhang Eastmount CSDN
"""
import tensorflow as tf
import numpy as np#---------------------------------------保存文件---------------------------------------
W = tf.Variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')# 初始化
init = tf.initialize_all_variables()# 定义saver 存储各种变量
saver = tf.train.Saver()# 使用Session运行初始化
with tf.Session() as sess:sess.run(init)# 保存 官方保存格式为ckptsave_path = saver.save(sess, "my_net/save_net.ckpt")print("Save to path:", save_path)

“Save to path: my_net/save_net.ckpt”保存成功如下图所示:

打开内容如下图所示:

接着定义标记变量train,通过Restore操作使用我们保存好的变量。注意,在Restore时需要定义相同的dtype和shape,不需要再定义init。最后直接通过 saver.restore(sess, “my_net/save_net.ckpt”) 提取保存的变量并输出即可。

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  2 20:04:57 2020
@author: xiuzhang Eastmount CSDN
"""
import tensorflow as tf
import numpy as np# 标记变量
train = False#---------------------------------------保存文件---------------------------------------
# Save
if train==True:# 定义变量W = tf.Variable([[1,2,3], [3,4,5]], dtype=tf.float32, name='weights') #2行3列的数据b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')# 初始化init = tf.global_variables_initializer()# 定义saver 存储各种变量saver = tf.train.Saver()# 使用Session运行初始化with tf.Session() as sess:sess.run(init)# 保存 官方保存格式为ckptsave_path = saver.save(sess, "my_net/save_net.ckpt")print("Save to path:", save_path)
#---------------------------------------Restore变量-------------------------------------
# Restore
if train==False:# 记住在Restore时定义相同的dtype和shape# redefine the same shape and same type for your variablesW = tf.Variable(np.arange(6).reshape((2,3)), dtype=tf.float32, name='weights') #空变量b = tf.Variable(np.arange(3).reshape((1,3)), dtype=tf.float32, name='biases') #空变量# Restore不需要定义initsaver = tf.train.Saver()with tf.Session() as sess:# 提取保存的变量saver.restore(sess, "my_net/save_net.ckpt")# 寻找相同名字和标识的变量并存储在W和b中print("weights", sess.run(W))print("biases", sess.run(b))

运行代码,如果报错“NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. ”,则需要重置Spyder即可。

最后输出之前所保存的变量,weights为 [[1,2,3], [3,4,5]],偏置为 [[1,2,3]]。

二.保存神经网络

那么,TensorFlow如何保存我们的神经网络框架呢?我们需要把整个网络训练好再进行保存,其方法和上面类似,完整代码如下:

"""
Created on Sun Dec 29 19:21:08 2019
@author: xiuzhang Eastmount CSDN
"""
import os
import glob
import cv2
import numpy as np
import tensorflow as tf# 定义图片路径
path = 'photo/'#---------------------------------第一步 读取图像-----------------------------------
def read_img(path):cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]imgs = []labels = []fpath = []for idx, folder in enumerate(cate):# 遍历整个目录判断每个文件是不是符合for im in glob.glob(folder + '/*.jpg'):#print('reading the images:%s' % (im))img = cv2.imread(im)             #调用opencv库读取像素点img = cv2.resize(img, (32, 32))  #图像像素大小一致imgs.append(img)                 #图像数据labels.append(idx)               #图像类标fpath.append(path+im)            #图像路径名#print(path+im, idx)return np.asarray(fpath, np.string_), np.asarray(imgs, np.float32), np.asarray(labels, np.int32)# 读取图像
fpaths, data, label = read_img(path)
print(data.shape)  # (1000, 256, 256, 3)
# 计算有多少类图片
num_classes = len(set(label))
print(num_classes)# 生成等差数列随机调整图像顺序
num_example = data.shape[0]
arr = np.arange(num_example)
np.random.shuffle(arr)
data = data[arr]
label = label[arr]
fpaths = fpaths[arr]# 拆分训练集和测试集 80%训练集 20%测试集
ratio = 0.8
s = np.int(num_example * ratio)
x_train = data[:s]
y_train = label[:s]
fpaths_train = fpaths[:s]
x_val = data[s:]
y_val = label[s:]
fpaths_test = fpaths[s:]
print(len(x_train),len(y_train),len(x_val),len(y_val)) #800 800 200 200
print(y_val)
#---------------------------------第二步 建立神经网络-----------------------------------
# 定义Placeholder
xs = tf.placeholder(tf.float32, [None, 32, 32, 3])  #每张图片32*32*3个点
ys = tf.placeholder(tf.int32, [None])               #每个样本有1个输出
# 存放DropOut参数的容器
drop = tf.placeholder(tf.float32)                   #训练时为0.25 测试时为0# 定义卷积层 conv0
conv0 = tf.layers.conv2d(xs, 20, 5, activation=tf.nn.relu)    #20个卷积核 卷积核大小为5 Relu激活
# 定义max-pooling层 pool0
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("Layer0:\n", conv0, pool0)# 定义卷积层 conv1
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu) #40个卷积核 卷积核大小为4 Relu激活
# 定义max-pooling层 pool1
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])        #pooling窗口为2x2 步长为2x2
print("Layer1:\n", conv1, pool1)# 将3维特征转换为1维向量
flatten = tf.layers.flatten(pool1)# 全连接层 转换为长度为400的特征向量
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
print("Layer2:\n", fc)# 加上DropOut防止过拟合
dropout_fc = tf.layers.dropout(fc, drop)# 未激活的输出层
logits = tf.layers.dense(dropout_fc, num_classes)
print("Output:\n", logits)# 定义输出结果
predicted_labels = tf.arg_max(logits, 1)
#---------------------------------第三步 定义损失函数和优化器---------------------------------# 利用交叉熵定义损失
losses = tf.nn.softmax_cross_entropy_with_logits(labels = tf.one_hot(ys, num_classes),       #将input转化为one-hot类型数据输出logits = logits)# 平均损失
mean_loss = tf.reduce_mean(losses)# 定义优化器 学习效率设置为0.0001
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(losses)
#------------------------------------第四步 模型训练和预测-----------------------------------
# 用于保存和载入模型
saver = tf.train.Saver()
# 训练或预测
train = False
# 模型文件路径
model_path = "model/image_model"with tf.Session() as sess:if train:print("训练模式")# 训练初始化参数sess.run(tf.global_variables_initializer())# 定义输入和Label以填充容器 训练时dropout为0.25train_feed_dict = {xs: x_train,ys: y_train,drop: 0.25}# 训练学习1000次for step in range(1000):_, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)if step % 50 == 0:  #每隔50次输出一次结果print("step = {}\t mean loss = {}".format(step, mean_loss_val))# 保存模型saver.save(sess, model_path)print("训练结束,保存模型到{}".format(model_path))else:print("测试模式")# 测试载入参数saver.restore(sess, model_path)print("从{}载入模型".format(model_path))# label和名称的对照关系label_name_dict = {0: "人类",1: "沙滩",2: "建筑",3: "公交",4: "恐龙",5: "大象",6: "花朵",7: "野马",8: "雪山",9: "美食"}# 定义输入和Label以填充容器 测试时dropout为0test_feed_dict = {xs: x_val,ys: y_val,drop: 0}# 真实label与模型预测labelpredicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):# 将label id转换为label名real_label_name = label_name_dict[real_label]predicted_label_name = label_name_dict[predicted_label]print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))# 评价结果print("正确预测个数:", sum(y_val==predicted_labels_val))print("准确度为:", 1.0*sum(y_val==predicted_labels_val) / len(y_val))

核心步骤为:

saver = tf.train.Saver()
model_path = "model/image_model"
with tf.Session() as sess:if train:#保存神经网络sess.run(tf.global_variables_initializer())for step in range(1000):_, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)if step % 50 == 0:print("step = {}\t mean loss = {}".format(step, mean_loss_val))saver.save(sess, model_path)else:#载入神经网络saver.restore(sess, model_path)predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)for fpath, real_label, predicted_label in zip(fpaths_test, y_val, predicted_labels_val):real_label_name = label_name_dict[real_label]predicted_label_name = label_name_dict[predicted_label]print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))  

预测输出结果如下图所示,最终预测正确181张图片,准确度为0.905。相比之前机器学习KNN的0.500有非常高的提升。

测试模式

INFO:tensorflow:Restoring parameters from model/image_model
从model/image_model载入模型
b'photo/photo/3\\335.jpg'       公交 => 公交
b'photo/photo/1\\129.jpg'       沙滩 => 沙滩
b'photo/photo/7\\740.jpg'       野马 => 野马
b'photo/photo/5\\564.jpg'       大象 => 大象
...
b'photo/photo/9\\974.jpg'       美食 => 美食
b'photo/photo/2\\220.jpg'       建筑 => 公交
b'photo/photo/9\\912.jpg'       美食 => 美食
b'photo/photo/4\\459.jpg'       恐龙 => 恐龙
b'photo/photo/5\\525.jpg'       大象 => 大象
b'photo/photo/0\\44.jpg'        人类 => 人类正确预测个数: 181
准确度为: 0.905

点击关注,第一时间了解华为云新鲜技术~

Tensorflow保存神经网络参数有妙招:Saver和Restore相关推荐

  1. tensorflow保存模型参数文件pb查看

    查看pb文件的节点参数: with tf.Session() as sess: with open(model, 'rb') as model_file: graph_def = tf.GraphDe ...

  2. 不知道电脑压缩图片怎么压缩?这有3个压缩妙招推荐给你

    我们想要在手机上实现图片压缩非常简单,很多美图类的APP里都有图片压缩的功能,那你知道电脑压缩图片怎么压缩吗?今天我带来了3个电脑图片压缩的妙招,感兴趣的小伙伴往下看吧. 妙招一:在"图片转 ...

  3. tensorflow没有这个参数_TensorFlow入门笔记(五) : 神经网络参数与TensorFlow变量

    神经网络参数简介 在TensorFlow中,变量(tf.Variable)的作用就是保存和更新神经网络中的参数.和其他编程语言类似,在TensorFlow中的变量也需要初始值.因为在TensorFlo ...

  4. Tensorflow训练神经网络保存*.pb模型及载入*.pb模型

    1 神经网络结构 1.0 保存*.pb模型 import tensorflow as tf from tensorflow.python.framework import graph_util fro ...

  5. 贝叶斯优化神经网络参数_贝叶斯超参数优化:神经网络,TensorFlow,相预测示例

    贝叶斯优化神经网络参数 The purpose of this work is to optimize the neural network model hyper-parameters to est ...

  6. java根据入参不同调不同方法_java根据传入参数不同调用不同的方法,求高手支妙招!...

    java根据传入参数不同调用不同的方法,求高手支妙招! 关注:138  答案:5  mip版 解决时间 2021-02-02 20:33 提问者我微笑着泪滴 2021-02-02 07:00 比如in ...

  7. TensorFlow实现超参数调整

    TensorFlow实现超参数调整 正如你目前所看到的,神经网络的性能非常依赖超参数.因此,了解这些参数如何影响网络变得至关重要. 常见的超参数是学习率.正则化器.正则化系数.隐藏层的维数.初始权重值 ...

  8. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  9. TensorFlow保存或加载训练的模型

    什么是Tensorflow的模型 模型部分主要参考了这篇文章和这篇博客:另外,官方文档也给出了很多指导. Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络 ...

最新文章

  1. 225.用队列实现栈
  2. Bootstrap简介--目前最受欢迎的前端框架(一)
  3. ajax 在新选卡打开,JavaScript在新窗口中打开,而不是选项卡
  4. (三)html5的结构
  5. ssl1776-游乐场【图论,深搜】
  6. C算法编程题(七)购物
  7. 【grpc】[Python] A file with this name is already in the pool
  8. Hosts 文件作用及如何修改
  9. 去除Word文档中的页眉横线
  10. UVA11349 Symmetric Matrix【数学】
  11. 软件开发工作过程中的一些总结
  12. OpenFOAM编程基础(2) -数据读取与保存
  13. 报表技术2(百万数据导入导出,POI操作word)
  14. 【题解】P2678 [NOIP2015 提高组] 跳石头
  15. SQL企业管理器打不开
  16. 【2017百度之星程序设计大赛 - 初赛(B)】度度熊的交易计划
  17. Linux之(27)networkctl命令
  18. 方配网站服务器,方配网站服务器
  19. 【DS with Python】 re模块与正则表达式
  20. 今日小坑:Vue-Router之路径routes拼写错误

热门文章

  1. (25)HTML5之<canvas>和<svg>标签
  2. (23)css3文字阴影text-shadow
  3. c语言数字排列和算法思路,冒泡排序、快速排序算法理解及C程序实现
  4. python3调用c代码_在Python3.6中调用C代码
  5. echarts 柱状图圆柱_Echarts 柱状图配置详解
  6. 在 Vs2013中查看类的内部布局
  7. YIi 设置 ajax 验证
  8. redis-cluster
  9. Luogu P1311 选择客栈(前缀和)
  10. hdu 2586 How far away? (LCA模板)