介绍

本博文中的代码,实现的是加载训练好的模型model_halcon_resenet.h5,并把该模型的参数赋值给两个不同的新的model。

函数式模型

官网上给出的调用一个训练好模型,并输出任意层的feature。

model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool').output)

但是这有一个问题,就是新的model,如果输入inputs和训练好的model的inputs大小不同呢?比如我想建立一个输入是600x600x3的新model,但是训练好的model输入是200x200x3,而这时我又想调用训练好模型的卷积核参数,这时该怎么办呢?

其实想一下,用训练好的模型参数,即使输入的尺寸不同,但是这些模型参数仍然可以处理计算,只是输出的feature map大小不同。那到底怎么赋值呢?其实很简单

在定义新的model时,新的model层在定义时,需要加上名字,而这个名字就是训练好的模型的每层名字。如下代码所示:

inputs=Input(shape=(400,500,3))

X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs)

X=BatchNormalization(name=“batch_normalization_1”)(X)

X=Activation(‘relu',name=“activation_1”)(X)

最后通过以下代码即可建立一个新的模型并拥有训练好模型的参数:

model=Model(inputs=inputs, outputs=X)

model.load_weights(‘model_halcon_resenet.h5', by_name=True)

源代码

from keras.models import load_model

from keras.preprocessing import image

from keras.applications.vgg19 import preprocess_input

from keras.models import Model

import numpy as np

from keras.layers import Conv2D, MaxPooling2D,merge

from keras.layers import BatchNormalization,Activation

from keras.layers import Input, Dense

from PIL import Image

import numpy as np

import keras

from keras.models import Sequential

from keras.layers import Dense, Dropout, Flatten,Input

from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D

from keras.layers import BatchNormalization,Activation

from sklearn.model_selection import train_test_split

from keras.applications.densenet import DenseNet169, DenseNet121

from keras.applications.inception_resnet_v2 import InceptionResNetV2

from keras.applications.inception_v3 import InceptionV3

from keras.optimizers import SGD

from keras import regularizers

from keras.models import Model

import tensorflow as tf

from PIL import Image

from keras.callbacks import TensorBoard

import os

import cv2

from keras import backend as K

from model import focal_loss

import keras.losses

#ReadMe 该代码是参考fast rcnn系列,先对整幅图像提取特征feature map,然后从原图对应位置上映射到feature map,并对feature map进行

# 切片,从而提取对应某个位置上的特征,并把该特征送进后面的识别网络进行分类识别。

keras.losses.focal_loss = focal_loss#这句代码是为了引入定义的loss

base_model=load_model('model_halcon_resenet.h5')

base_model.summary()

inputs=Input(shape=(400,500,3))

X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)

X=BatchNormalization(name="batch_normalization_1")(X)

X=Activation('relu',name="activation_1")(X)

#第一个残差模块

X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)

X_1=BatchNormalization(name="batch_normalization_2")(X_1)

X_1= Activation('relu',name="activation_2")(X_1)

X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)

X_1 = BatchNormalization(name="batch_normalization_3")(X_1)

merge_data = merge([X_1, X], mode='sum',name="merge_1")

X = Activation('relu',name="activation_3")(merge_data)

#第一个残差模块结束

X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)

X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)

X=BatchNormalization(name="batch_normalization_4")(X)

X=Activation('relu',name="activation_4")(X)

#第二个残差模块

X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)

X_2=BatchNormalization(name="batch_normalization_5")(X_2)

X_2= Activation('relu',name="activation_5")(X_2)

X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)

X_2 = BatchNormalization(name="batch_normalization_6")(X_2)

merge_data = merge([X_2, X], mode='sum',name="merge_2")

X = Activation('relu',name="activation_6")(merge_data)

#第二个残差模块结束

X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)

X=Conv2D(64, (3, 3),name="conv2d_7")(X)

X=BatchNormalization(name="batch_normalization_7")(X)

X=Activation('relu',name="activation_7")(X)

X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)

#第三个残差模块开始

X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)

X_3=BatchNormalization(name="batch_normalization_8")(X_3)

X_3= Activation('relu',name="activation_8")(X_3)

X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)

X_3 = BatchNormalization(name="batch_normalization_9")(X_3)

merge_data = merge([X_3, X], mode='sum',name="merge_3")

X = Activation('relu',name="activation_9")(merge_data)

#第三个残差模块结束

X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)

X=BatchNormalization(name="batch_normalization_10")(X)

X=Activation('relu',name="activation_10")(X)

#第四个残差模块开始

X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)

X_4=BatchNormalization(name="batch_normalization_11")(X_4)

X_4= Activation('relu',name="activation_11")(X_4)

X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)

X_4 = BatchNormalization(name="batch_normalization_12")(X_4)

merge_data = merge([X_4, X], mode='sum',name="merge_4")

X = Activation('relu',name="activation_12")(merge_data)

#第四个残差模块结束

X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)

X = Conv2D(64, (3, 3),name="conv2d_13")(X)

X = BatchNormalization(name="batch_normalization_13")(X)

X = Activation('relu',name="activation_13")(X)

#第五个残差模块开始

X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)

X_5=BatchNormalization(name="batch_normalization_14")(X_5)

X_5= Activation('relu',name="activation_14")(X_5)

X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)

X_5 = BatchNormalization(name="batch_normalization_15")(X_5)

merge_data = merge([X_5, X], mode='sum',name="merge_5")

X = Activation('relu',name="activation_15")(merge_data)

#第五个残差模块结束

model=Model(inputs=inputs, outputs=X)

model.load_weights('model_halcon_resenet.h5', by_name=True)

#读取指定图像数据

image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'

img = image.load_img(image_dir, target_size=(400, 500))

x = image.img_to_array(img)

x = np.expand_dims(x, axis=0)

x = preprocess_input(x)

#利用第一个模型预测出特征数据,并对特征数据进行切片

feature_map=model.predict(x)

T=np.array(feature_map)

f_1=T[:,16:21,0:10,:]

print(f_1.shape)

print(feature_map.shape)

#第一个模型没有问题

#定义第二个模型

inputs_sec=Input(shape=(1,5,10,64))

X_= Flatten(name="flatten_1")(inputs_sec)

X_ = Dense(256, activation='relu',name="dense_1")(X_)

X_ = Dropout(0.5,name="dropout_1")(X_)

predictions = Dense(6, activation='softmax',name="dense_2")(X_)

model_sec=Model(inputs=inputs_sec, outputs=predictions)

model_sec.load_weights('model_halcon_resenet.h5', by_name=True)

#第二个模型定义结束

model_sec.summary()

#开始对整幅图像进行切片,并记录坐标位置

pic=cv2.imread(image_dir)

cor_list=[]

name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']

font = cv2.FONT_HERSHEY_SIMPLEX

for i in range(3):

for j in range(5):

if(i==2):

cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]

data = np.expand_dims(cut_feature, axis=0)

result = model_sec.predict(data)

print(result)

result_data=result[0].tolist()

#如果置信度过低,则舍弃

# if(max(result_data)<=0.7):

# continue

index_num = result_data.index(max(result_data))

name=name_list[index_num]

cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标

x=cor_list[0]

y=cor_list[1]

cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2)

cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)

else:

cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]

data = np.expand_dims(cut_feature, axis=0)

result = model_sec.predict(data)

print(result)

result_data = result[0].tolist()

#如果置信度过低,则舍弃

# if (max(result_data) <= 0.7):

# continue

index_num = result_data.index(max(result_data))

name = name_list[index_num]

cor_list = [i * 160 + 6, j * 80] # 每个切片数据,映射到原图上,检测框对应的左上角坐标

x = cor_list[0]

y = cor_list[1]

cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)

cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)

cv2.imshow('pic',pic)

cv2.waitKey(0)

cv2.destroyAllWindows()

# data= np.expand_dims(f_1, axis=0)

# result=model_sec.predict(data)

# print(result)

#第二个模型可以完全预测,没有问题

补充知识:加载训练好的模型参数,但是权重一直变化

变量初始化会导致权重发生变化,去掉就好了。

以上这篇keras读取训练好的模型参数并把参数赋值给其它模型详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

python训练模型函数参数_keras读取训练好的模型参数并把参数赋值给其它模型详解...相关推荐

  1. [Python从零到壹] 十一.数据分析之Numpy、Pandas、Matplotlib和Sklearn入门知识万字详解(1)

    欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...

  2. [Python从零到壹] 十五.文本挖掘之数据预处理、Jieba工具和文本聚类万字详解

    欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...

  3. [Python从零到壹] 五十一.图像增强及运算篇之图像灰度直方图对比分析万字详解

    欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...

  4. python训练模型函数参数_一步步亲手用python实现Logistic Regression

    前面的[DL笔记1]Logistic回归:最基础的神经网络和[DL笔记2]神经网络编程原则&Logistic Regression的算法解析讲解了Logistic regression的基本原 ...

  5. python卷积神经网络cnn的训练算法_【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理...

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  6. python元组读取到列表_python中读入二维csv格式的表格方法详解(以元组/列表形式表示)...

    如何去读取一个没有表头的二维csv文件(如下图所示)? 并以元组的形式表现数据: ((1.0, 0.0, 3.0, 180.0), (2.0, 0.0, 2.0, 180.0), (3.0, 0.0, ...

  7. python可以画动态图吗_matplotlib 画动态图以及plt.ion()和plt.ioff()的使用详解

    学习python的道路是漫长的,今天又遇到一个问题,所以想写下来自己的理解方便以后查看. 在使用matplotlib的过程中,常常会需要画很多图,但是好像并不能同时展示许多图.这是因为python可视 ...

  8. Python开发之:Django基于Docker实现Mysql数据库读写分离、集群、主从同步详解 | 原力计划...

    作者 | Pythonicc 责编 | 王晓曼 出品 | CSDN博客 简介 1.什么是数据库读写分离 读写分离,基本的原理是让主数据库处理事务性增.改.删操作(INSERT.UPDATE.DELET ...

  9. python scrapy框架 抓取的图片路径打不开图片_Python中Scrapy爬虫图片处理详解

    下载图片 下载图片有两种方式,一种是通过 Requests 模块发送 get 请求下载,另一种是使用 Scrapy 的 ImagesPipeline 图片管道类,这里主要讲后者. 安装 Scrapy ...

最新文章

  1. ​炸了!程序员现在没有这点技能都还不能就业了?
  2. django安装初步使用命令整理
  3. Zookeeper分布式一致性原理(五):Zookeeper-Java-API
  4. 谈谈关于个人提升的一些思考
  5. 为什么设置行高文字就能居中
  6. 记录一次 自建网盘程序 cloudreve被攻击
  7. docker MySQL-错误:2059-Authentication plugin ‘caching_sha2_password‘ cannot be loaded
  8. Selenium免密码登录学习的方法
  9. Google搜索从入门到精通 v4.0
  10. 中国医院评审/评级标准及区别和特点
  11. 软件无线电实验 matlab,基于MATLAB和ModelSim的软件无线电课程实验设计
  12. mongodb mongoose 的使用
  13. 句法结构解析和Transition_based方法
  14. 部分软件可以正常打开,但图标无法正常显示
  15. python取系统日期前一天_python 获取前一天或前N天的日期
  16. bzoj 3772: 精神污染 (主席树+dfs序)
  17. cordova 调起拨打电话
  18. 【牛客网华为机试】HJ32 密码截取
  19. 如何显示文件夹的后缀和隐藏文件
  20. E3ZG_D62传感器 STM32C8T6

热门文章

  1. svn co 默认密钥' GNOME keyring
  2. WPF应用基础篇---TreeView
  3. Android 完整地操作数据库--日记本实例
  4. html css拖拽设计,css绘制三角形 和 HTML拖拽事件
  5. Java基础练习之流程控制(二)
  6. 当我以为这是最后一个Bug,改完就能提交了的时候
  7. WSL2之kali安装界面kex
  8. Lisp入门(好文)
  9. 细数AVPlayer的那些坑
  10. J2EE技术-Hibernate