本篇介绍如何在matlab中调用python训练好的网络模型和权重。

系统环境:win10,matlab2018b,python3.6,tensorflow1.1

代码如下:

tf = py.importlib.import_module('tensorflow');
np = py.importlib.import_module('numpy');
plt = py.importlib.import_module('matplotlib.pyplot');
sio = py.importlib.import_module('scipy.io');
run = py.importlib.import_module('run');
py.importlib.reload(run);
% 载入数据
loaddata = run.load_data();
% train_images = double(loaddata{1}); % 载入训练数据
% train_labels = double(loaddata{2});
test_images = double(loaddata{3}); % 载入测试数据
test_labels  = double(loaddata{4});
% 载入训练好的模型
model = run.create_model();
model.load_weights("training_1/cp.ckpt");
%% 使用
idx_sel = [1:39];
image_use = cell([1,length(idx_sel)]);
labels_use = zeros(1,length(idx_sel));
for i = 1:20% 测试用例:如果从matlab传参到直接调用keras模型中的fit方法会报错,image_use = np.array(reshape(test_images(idx_sel(i),:,:),[1,28,28]));label_true = test_labels(idx_sel(i));% 执行预测label_pred = int64(model.predict_classes(image_use));% 显示结果subplot(4,5,i)image_use_matlab = reshape(double(image_use),[28,28]); % 转换成matlab格式的imageimshow(image_use_matlab)if label_pred == label_truetitle(['预测',num2str(label_pred),',真实',num2str(label_true)],'color','g')    elsetitle(['预测',num2str(label_pred),',真实',num2str(label_true)],'color','r')    end
end

说明:

1)run是当前文件夹下的python文件,是tensorflow官网例子,实现手写数字识别,训练完成后对模型进行保存,以便在matlab中调用。run.py内容如下:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import scipy.io as sio
#import matplotlib.pyplot as plt
import osdef create_model():model = keras.Sequential([keras.layers.Flatten(input_shape=(28, 28)),keras.layers.Dense(128, activation=tf.nn.relu),keras.layers.Dense(10, activation=tf.nn.softmax)])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])    return modeldef load_data():path='./mnist.npz'      f = np.load(path)  with np.load(path) as f:train_images, train_labels = f['x_train'], f['y_train']  test_images, test_labels = f['x_test'], f['y_test']      return train_images, train_labels, test_images, test_labelsdef train():train_images, train_labels, test_images, test_labels = load_data()    train_images = np.array(train_images)train_labels = np.array(train_labels)test_images = np.array(test_images)test_labels = np.array(test_labels)train_images = train_images / 255.0test_images = test_images / 255.0#plt.figure(figsize=(10,10))#for i in range(25):#    plt.subplot(5,5,i+1)#    plt.xticks([])#    plt.yticks([])#    plt.grid(False)#    plt.imshow(train_images[i], cmap=plt.cm.binary)#    plt.xlabel(class_names[train_labels[i]])#plt.show()checkpoint_path = "training_1/cp.ckpt"checkpoint_dir = os.path.dirname(checkpoint_path)# Create checkpoint callbackcp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,save_weights_only=True,verbose=1)model = create_model()model.fit(train_images, train_labels, epochs=5,callbacks=[cp_callback])test_loss, test_acc = model.evaluate(test_images, test_labels)print('Test accuracy:', test_acc)predictions = model.predict(test_images)return modelif __name__ == '__main__':train_images, train_labels, test_images, test_labels = load_data()model = train()

2)使用流程:

  • 下载mnist数据集
  • 运行run.py进行训练并保存训练结果
  • 运行matlab函数,在matlab中调用python中训练的模型进行推理,结果如下

3)matlab中调用python sys模块中的方法属性会出问题,而keras的fit方法会调用到sys.stdout.flush等方法,因此不能由matlab调用含相关方法的python函数,例如model.fit,model.evaluate等,但model.predict等能够直接在matlab中使用

4)本篇内容仅是对matlab中调用python的一个说明,该方式并不是实现诸如手写数字识别的最佳方式。如果要在matlab中进行相关功能开发,使用matlab深度学习工具箱更加方便;此外matlab也支持onnx转换的模型。

evaluate函数使用无效_在Matlab中使用tensorflow (2)相关推荐

  1. tensorflow numpy版本匹配_在Matlab中使用tensorflow (1)

    为了在matlab中利用丰富的python开源资源,探索了如下内容: 1)在matlab中直接调用tensorflow函数: 2)在matlab中调用tensorflow的python程序,结合mat ...

  2. evaluate函数使用无效_我用这个Excel函数,秀了同事一脸!很多人却连它名字都没听过...

    最近收到在某快递上班的周同学问题求助,主要是在计算包裹的体积时遇到了些麻烦事. 下表是周同学近期整理的快递包裹尺寸数据,其中重要一项工作就是通过长*宽*高来计算出包裹的体积. 周同学表示其实自己也能做 ...

  3. evaluate函数使用无效_[Python实战]使用栈实现简易计算器

    我们这次实现的命令行计算器,支持加减乘除.括号.浮点数.负数,以及查看历史和退出功能. 主要的思路:read - parse - print - loop. read 阶段是指读取用户在提示符(cal ...

  4. evaluate函数使用无效_使用Keras和Pytorch处理RNN变长序列输入的方法总结

    最近在使用Keras和Pytorch处理时间序列数据,在变长数据的输入处理上踩了很多坑.一般的通用做法都需要先将一个batch中的所有序列padding到同一长度,然后需要在网络训练时屏蔽掉paddi ...

  5. matlab gpu deep learning_在Matlab中使用tensorflow (1)

    为了在matlab中利用丰富的python开源资源,探索了如下内容: 1)在matlab中直接调用tensorflow函数: 2)在matlab中调用tensorflow的python程序,结合mat ...

  6. matlab pdepe函数边界,科学网-使用MATLAB中pdepe函数求解一维偏微分方程-邓浩鑫的博文...

    由于自己科研水平较低,记录的各种体会更多的是给自己做个小结,错误之处,欢迎大家指正. 使用MATLAB求解偏微分方程或者方程组,大致有三类方法.第一种是使用MATLAB中的PDE Toolbox,PD ...

  7. zeros什么意思_matlab中zeros函数是什么含义?MATLAB中zeros表示表示什么意思

    matlab中zeros函数是什么含义?MATLAB中zeros表示表示什么意思 发表时间:2019-12-26 10:20:18 小编:4326手游网 阅读: 在手机上看 手机扫描阅读 MATLAB ...

  8. 在MATLAB中使用tensorflow

    在MATLAB中使用tensorflow_m0_47218095的博客-CSDN博客 在Matlab中调用tensorflow或keras_jch_wang的博客-CSDN博客 在Matlab中设置导 ...

  9. matlab中读文件的行数_[转载]MATLAB中获取大型文本文件行数方法研究(转)

    在工作中会有很多特殊的需要,比如我现在就遇到一个需要将大型的文本格式数据文件(比如5G)读取到MATLAB中,同时进行一定的处理.由于XP的内存是绝对没有办法将5G的数据一次性加载到工作空间的,此时一 ...

最新文章

  1. Anaconda 安装 opencv3(Win10)
  2. neo4j springboot 日志_Springboot2.3集成neo4j的过程和踩坑记
  3. ASP.Net MVC开发基础学习笔记(5):区域、模板页与WebAPI初步
  4. 02:输出最高分数的学生姓名
  5. sqlserver 查询中文查询不到 查询英文可以查到_估值数据和财报数据查询方法
  6. [转] 外企面试官最爱提的问题 TOP10
  7. 如何在Python中串联两个列表?
  8. winsock 收发广播包
  9. CString.Format详解【摘录】
  10. 修改android的avd路径方法
  11. 单片机软件反破解 Hex反破解 破解后的hex不能量产
  12. localhost无法连接mysql_详细解说MySQL通过localhost无法连接数据库的问题解决
  13. 一步步教你装超强插件~油猴插件管理器Tampermonkey
  14. 一种新型分割图像中人物的方法,基于人物动作辨认
  15. ks live room danmu
  16. MAB建模规范-Naming Conventions命名规范
  17. 《C++ Templates》笔记 Chapter 12 Fundamentals in Depth-Chapter 13 Names in Templates
  18. 【java笔记】字符流,Properties,序列化,打印流
  19. Java实现对数据库的查操作
  20. 西电“可展开天线”项目获2013年度国家科学技术进步二等奖

热门文章

  1. [trustzone]-ARM Trustzone架构下的软件框图
  2. HBNIS-crypto
  3. (55)_KPCR, _NT_TIB, _KPRCB
  4. 【Win32汇编】字符串逆序
  5. aliyun服务器安装git,g++
  6. 4、MySQL使用二进制日志还原数据库
  7. Quartz框架架构
  8. hive hql文档_30分钟入门 Hive SQL(HQL 入门篇)
  9. Lombok 介绍和使用详情
  10. Java 集合系列11: Hashtable深入解析(1)