介绍

本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model。

函数式模型

官网上给出的调用一个训练好模型,并输出任意层的feature。
model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool’).output)

但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不同呢?比如我想建立一个输入是600x600x3的新model,但是训练好的model输入是200x200x3,而这时我又想调用训练好模型的卷积核参数,这时该怎么办呢?

其实想一下,用训练好的模型参数,即使输入的尺寸不同,但是这些模型参数仍然可以处理计算,只是输出的feature map大小不同。那到底怎么赋值呢?其实很简单

在定义新的model时,新的model层在定义时,需要加上名字,而这个名字就是训练好的模型的每层名字。如下代码所示:

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs)
X=BatchNormalization(name=“batch_normalization_1”)(X)
X=Activation(‘relu’,name=“activation_1”)(X)

最后通过以下代码即可建立一个新的模型并拥有训练好模型的参数:

model=Model(inputs=inputs, outputs=X)
model.load_weights(‘model_halcon_resenet.h5’, by_name=True)

源代码

from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from keras.layers import BatchNormalization,Activation
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
from keras.models import Model
import tensorflow as tf
from PIL import Image
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses#ReadMe 该代码是参考fast rcnn系列,先对整幅图像提取特征feature map,然后从原图对应位置上映射到feature map,并对feature map进行
# 切片,从而提取对应某个位置上的特征,并把该特征送进后面的识别网络进行分类识别。
keras.losses.focal_loss = focal_loss#这句代码是为了引入定义的loss
base_model=load_model('model_halcon_resenet.h5')
base_model.summary()inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)
#第一个残差模块
X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)
X_1=BatchNormalization(name="batch_normalization_2")(X_1)
X_1= Activation('relu',name="activation_2")(X_1)
X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)
X_1 = BatchNormalization(name="batch_normalization_3")(X_1)
merge_data = merge([X_1, X], mode='sum',name="merge_1")
X = Activation('relu',name="activation_3")(merge_data)
#第一个残差模块结束
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)
X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)
X=BatchNormalization(name="batch_normalization_4")(X)
X=Activation('relu',name="activation_4")(X)
#第二个残差模块
X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)
X_2=BatchNormalization(name="batch_normalization_5")(X_2)
X_2= Activation('relu',name="activation_5")(X_2)
X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)
X_2 = BatchNormalization(name="batch_normalization_6")(X_2)
merge_data = merge([X_2, X], mode='sum',name="merge_2")
X = Activation('relu',name="activation_6")(merge_data)
#第二个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)
X=Conv2D(64, (3, 3),name="conv2d_7")(X)
X=BatchNormalization(name="batch_normalization_7")(X)
X=Activation('relu',name="activation_7")(X)
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)
#第三个残差模块开始
X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)
X_3=BatchNormalization(name="batch_normalization_8")(X_3)
X_3= Activation('relu',name="activation_8")(X_3)
X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)
X_3 = BatchNormalization(name="batch_normalization_9")(X_3)
merge_data = merge([X_3, X], mode='sum',name="merge_3")
X = Activation('relu',name="activation_9")(merge_data)
#第三个残差模块结束
X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)
X=BatchNormalization(name="batch_normalization_10")(X)
X=Activation('relu',name="activation_10")(X)
#第四个残差模块开始
X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)
X_4=BatchNormalization(name="batch_normalization_11")(X_4)
X_4= Activation('relu',name="activation_11")(X_4)
X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)
X_4 = BatchNormalization(name="batch_normalization_12")(X_4)
merge_data = merge([X_4, X], mode='sum',name="merge_4")
X = Activation('relu',name="activation_12")(merge_data)
#第四个残差模块结束
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)
X = Conv2D(64, (3, 3),name="conv2d_13")(X)
X = BatchNormalization(name="batch_normalization_13")(X)
X = Activation('relu',name="activation_13")(X)
#第五个残差模块开始
X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)
X_5=BatchNormalization(name="batch_normalization_14")(X_5)
X_5= Activation('relu',name="activation_14")(X_5)
X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)
X_5 = BatchNormalization(name="batch_normalization_15")(X_5)
merge_data = merge([X_5, X], mode='sum',name="merge_5")
X = Activation('relu',name="activation_15")(merge_data)
#第五个残差模块结束
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
#读取指定图像数据
image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#利用第一个模型预测出特征数据,并对特征数据进行切片
feature_map=model.predict(x)
T=np.array(feature_map)
f_1=T[:,16:21,0:10,:]
print(f_1.shape)
print(feature_map.shape)
#第一个模型没有问题
#定义第二个模型
inputs_sec=Input(shape=(1,5,10,64))
X_= Flatten(name="flatten_1")(inputs_sec)
X_ = Dense(256, activation='relu',name="dense_1")(X_)
X_ = Dropout(0.5,name="dropout_1")(X_)
predictions = Dense(6, activation='softmax',name="dense_2")(X_)
model_sec=Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
#第二个模型定义结束
model_sec.summary()
#开始对整幅图像进行切片,并记录坐标位置
pic=cv2.imread(image_dir)
cor_list=[]
name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):for j in range(5):if(i==2):cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]data = np.expand_dims(cut_feature, axis=0)result = model_sec.predict(data)print(result)result_data=result[0].tolist()#如果置信度过低,则舍弃# if(max(result_data)<=0.7):#     continueindex_num = result_data.index(max(result_data))name=name_list[index_num]cor_list = [i * 160 + 6, j * 80]  # 每个切片数据,映射到原图上,检测框对应的左上角坐标x=cor_list[0]y=cor_list[1]cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2)cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)else:cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]data = np.expand_dims(cut_feature, axis=0)result = model_sec.predict(data)print(result)result_data = result[0].tolist()#如果置信度过低,则舍弃# if (max(result_data) <= 0.7):#     continueindex_num = result_data.index(max(result_data))name = name_list[index_num]cor_list = [i * 160 + 6, j * 80]  # 每个切片数据,映射到原图上,检测框对应的左上角坐标x = cor_list[0]y = cor_list[1]cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)cv2.imshow('pic',pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#第二个模型可以完全预测,没有问题

keras读取训练好的模型参数并把参数赋值给其它模型相关推荐

  1. python训练模型函数参数_keras读取训练好的模型参数并把参数赋值给其它模型详解...

    介绍 本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model. 函数式模型 官网上给出的调用一个训练好模型,并输出任意层 ...

  2. CV:基于Keras利用训练好的hdf5模型进行目标检测实现输出模型中的脸部表情或性别的gradcam(可视化)

    CV:基于Keras利用训练好的hdf5模型进行目标检测实现输出模型中的脸部表情或性别的gradcam(可视化) 目录 设计思路 核心代码 设计思路 核心代码 #CV:基于keras利用训练好的hdf ...

  3. python实现人脸检测及识别(2)---- 利用keras库训练人脸识别模型

    前面已经采集好数据集boss文件夹存放需要识别的对象照片,other存放其他人的训练集照片,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的tensorflo ...

  4. Keras训练神经网络进行分类并使用GridSearchCV进行参数寻优

    Keras训练神经网络进行分类并使用GridSearchCV进行参数寻优 在机器学习模型中,需要人工选择的参数称为超参数.比如随机森林中决策树的个数,人工神经网络模型中隐藏层层数和每层的节点个数,正则 ...

  5. R语言使用caret包构建岭回归模型(Ridge Regression )构建回归模型、通过method参数指定算法名称、通过trainControl函数控制训练过程

    R语言使用caret包构建岭回归模型(Ridge Regression )构建回归模型.通过method参数指定算法名称.通过trainControl函数控制训练过程 目录

  6. 花书+吴恩达深度学习(二十)构建模型策略(超参数调试、监督预训练、无监督预训练)

    目录 0. 前言 1. 学习率衰减 2. 调参策略 3. 贪心监督预训练 4. 贪心逐层无监督预训练 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书+吴恩达深度学习(十 ...

  7. Keras : 训练minst数据集并加载模型对本地手写图片进行预测

    我是本期目录酱 引入 minst数据集介绍 训练模型与测试的py代码分析 训练及测试的py代码(全) 训练及测试结果分析 加载模型并预测本地图片结果 加载模型并预测本地图片py代码(不全) 加载模型并 ...

  8. 2020-6-9 吴恩达-改善深层NN-w3 超参数调试(3.3 超参数训练的实践:Pandas(资源少,一次一个模型) VS Caviar(资源多,一次多个模型))

    1.视频网站:mooc慕课https://mooc.study.163.com/university/deeplearning_ai#/c 2.详细笔记网站(中文):http://www.ai-sta ...

  9. CS231n 卷积神经网络与计算机视觉 7 神经网络训练技巧汇总 梯度检验 参数更新 超参数优化 模型融合 等

    前面几章已经介绍了神经网络的结构.数据初始化.激活函数.损失函数等问题,现在我们该讨论如何让神经网络模型进行学习了. 1 梯度检验 权重的更新梯度是否正确决定着函数是否想着正确的方向迭代,在UFLDL ...

  10. LLM-大模型训练-步骤(二)-预训练/Pre-Training(1):全参数预训练(Full-Param Pre-Training)【对LLaMA等模型进一步全量参数预训练】【中文无监督学习语料】

    GitHub项目:KnowLM 一.全参数预训练(Full-Param Pre-training) 使用中文语料对LLaMA等模型进行进一步全量预训练,在尽可能保留原来的英文和代码能力的前提下,进一步 ...

最新文章

  1. 人脸检测识别文献阅读总结
  2. php服务器估算,使用zabbix API估算服务器磁盘空间可用天数
  3. [Git] 删除远程仓库的文件
  4. 刘强东深夜写信诉苦;华为不排斥卖给苹果 5G 芯片;Facebook 再宕机 | 极客头条...
  5. 大数据学习笔记02-HDFS-常用命令
  6. 【Oracle】闪回表
  7. Idea不能显示类的继承关系,pom文件的右键属性中也没有Diagrams选项(已解决)
  8. 推荐几个精致的前端Web UI框架
  9. presto Slice入门
  10. python多元线性回归_多元线性回归模型精度提升 虚拟变量
  11. 蓝牙耳机连接电脑,提示无法安装驱动程序
  12. 产品经理常用的19类50+工具软件盘点
  13. 参加2010年磨房《在路上 - 十年》百公里徒步活动小记
  14. SUMO交通仿真-核心概念和基础知识速览
  15. #clickid#CID#全新小程序链路CID/clickid解决方案,合规、完美防阿里封禁
  16. heu oj 1011 square
  17. 苹果iOS/iPadOS 15.2 Beta 1发布 app隐私报告?
  18. matlab离群值处理,数据平滑和离群值检测
  19. 高级映射(一):一对一、一对多,多对多查询总结
  20. mysql desc hcy.t1_mysql主从同步出错故障处理总结[数据库技术]

热门文章

  1. CodeForces615A-Bulbs-模拟
  2. 2015年总结与2016年目标
  3. 浅谈C++设计模式之抽象工厂(Abstract Factory)
  4. webapp之路--之query media
  5. 修改hadoop配饰文件文件后导致hive无法找到原有的dfs文件
  6. 手动修改Sublime Text2 边栏Sidebar的样式
  7. 【IBM Tivoli Identity Manager 学习文档】14 TIM组织结构设计
  8. JDK 8.0 新特性——接口默认方法与静态方法
  9. bootstrap-table初始化配置
  10. plsql设置代码提示和自动补全