YOLO:Darknet框架的.weights模型转keras框架的.h5模型

  • 前言
    • 所需材料
    • 实现转换的代码
    • 应用举例
    • .weights转.pt格式
    • 参考信息

前言

由于存在多种框架下的YOLO算法,例如模型压缩大部分直接使用Darknet53,而计算mAP等可能需要使用keras等框架,这个时候需要将训练所得的.weights模型类型转换为.h5。本文给出了转换代码。

所需材料

1、待转换模型的算法框架,往往是.cfg文件;
2、待转换模型.weights文件;
3、转换代码 convert.py;
4、python对应的包。

实现转换的代码

import argparse
import configparser
import io
import os
from collections import defaultdictimport numpy as np
from keras import backend as K
from keras.layers import (Conv2D, Input, ZeroPadding2D, Add,UpSampling2D, MaxPooling2D, Concatenate)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.regularizers import l2
from keras.utils.vis_utils import plot_model as plotparser = argparse.ArgumentParser(description='Darknet To Keras Converter.')
parser.add_argument('config_path', help='Path to Darknet cfg file.')
parser.add_argument('weights_path', help='Path to Darknet weights file.')
parser.add_argument('output_path', help='Path to output Keras model file.')
parser.add_argument('-p','--plot_model',help='Plot generated Keras model and save as image.',action='store_true')
parser.add_argument('-w','--weights_only',help='Save as Keras weights file instead of model file.',action='store_true')def unique_config_sections(config_file):"""Convert all config sections to have unique names.Adds unique suffixes to config sections for compability with configparser."""section_counters = defaultdict(int)output_stream = io.StringIO()with open(config_file) as fin:for line in fin:if line.startswith('['):section = line.strip().strip('[]')_section = section + '_' + str(section_counters[section])section_counters[section] += 1line = line.replace(section, _section)output_stream.write(line)output_stream.seek(0)return output_stream# %%
def _main(args):config_path = os.path.expanduser(args.config_path)weights_path = os.path.expanduser(args.weights_path)assert config_path.endswith('.cfg'), '{} is not a .cfg file'.format(config_path)assert weights_path.endswith('.weights'), '{} is not a .weights file'.format(weights_path)output_path = os.path.expanduser(args.output_path)assert output_path.endswith('.h5'), 'output path {} is not a .h5 file'.format(output_path)output_root = os.path.splitext(output_path)[0]# Load weights and config.print('Loading weights.')weights_file = open(weights_path, 'rb')major, minor, revision = np.ndarray(shape=(3, ), dtype='int32', buffer=weights_file.read(12))if (major*10+minor)>=2 and major<1000 and minor<1000:seen = np.ndarray(shape=(1,), dtype='int64', buffer=weights_file.read(8))else:seen = np.ndarray(shape=(1,), dtype='int32', buffer=weights_file.read(4))print('Weights Header: ', major, minor, revision, seen)print('Parsing Darknet config.')unique_config_file = unique_config_sections(config_path)cfg_parser = configparser.ConfigParser()cfg_parser.read_file(unique_config_file)print('Creating Keras model.')input_layer = Input(shape=(None, None, 3))prev_layer = input_layerall_layers = []weight_decay = float(cfg_parser['net_0']['decay']) if 'net_0' in cfg_parser.sections() else 5e-4count = 0out_index = []for section in cfg_parser.sections():print('Parsing section {}'.format(section))if section.startswith('convolutional'):filters = int(cfg_parser[section]['filters'])size = int(cfg_parser[section]['size'])stride = int(cfg_parser[section]['stride'])pad = int(cfg_parser[section]['pad'])activation = cfg_parser[section]['activation']batch_normalize = 'batch_normalize' in cfg_parser[section]padding = 'same' if pad == 1 and stride == 1 else 'valid'# Setting weights.# Darknet serializes convolutional weights as:# [bias/beta, [gamma, mean, variance], conv_weights]prev_layer_shape = K.int_shape(prev_layer)weights_shape = (size, size, prev_layer_shape[-1], filters)darknet_w_shape = (filters, weights_shape[2], size, size)weights_size = np.product(weights_shape)print('conv2d', 'bn'if batch_normalize else '  ', activation, weights_shape)conv_bias = np.ndarray(shape=(filters, ),dtype='float32',buffer=weights_file.read(filters * 4))count += filtersif batch_normalize:bn_weights = np.ndarray(shape=(3, filters),dtype='float32',buffer=weights_file.read(filters * 12))count += 3 * filtersbn_weight_list = [bn_weights[0],  # scale gammaconv_bias,  # shift betabn_weights[1],  # running meanbn_weights[2]  # running var]conv_weights = np.ndarray(shape=darknet_w_shape,dtype='float32',buffer=weights_file.read(weights_size * 4))count += weights_size# DarkNet conv_weights are serialized Caffe-style:# (out_dim, in_dim, height, width)# We would like to set these to Tensorflow order:# (height, width, in_dim, out_dim)conv_weights = np.transpose(conv_weights, [2, 3, 1, 0])conv_weights = [conv_weights] if batch_normalize else [conv_weights, conv_bias]# Handle activation.act_fn = Noneif activation == 'leaky':pass  # Add advanced activation later.elif activation != 'linear':raise ValueError('Unknown activation function `{}` in section {}'.format(activation, section))# Create Conv2D layerif stride>1:# Darknet uses left and top padding instead of 'same' modeprev_layer = ZeroPadding2D(((1,0),(1,0)))(prev_layer)conv_layer = (Conv2D(filters, (size, size),strides=(stride, stride),kernel_regularizer=l2(weight_decay),use_bias=not batch_normalize,weights=conv_weights,activation=act_fn,padding=padding))(prev_layer)if batch_normalize:conv_layer = (BatchNormalization(weights=bn_weight_list))(conv_layer)prev_layer = conv_layerif activation == 'linear':all_layers.append(prev_layer)elif activation == 'leaky':act_layer = LeakyReLU(alpha=0.1)(prev_layer)prev_layer = act_layerall_layers.append(act_layer)elif section.startswith('route'):ids = [int(i) for i in cfg_parser[section]['layers'].split(',')]layers = [all_layers[i] for i in ids]if len(layers) > 1:print('Concatenating route layers:', layers)concatenate_layer = Concatenate()(layers)all_layers.append(concatenate_layer)prev_layer = concatenate_layerelse:skip_layer = layers[0]  # only one layer to routeall_layers.append(skip_layer)prev_layer = skip_layerelif section.startswith('maxpool'):size = int(cfg_parser[section]['size'])stride = int(cfg_parser[section]['stride'])all_layers.append(MaxPooling2D(pool_size=(size, size),strides=(stride, stride),padding='same')(prev_layer))prev_layer = all_layers[-1]elif section.startswith('shortcut'):index = int(cfg_parser[section]['from'])activation = cfg_parser[section]['activation']assert activation == 'linear', 'Only linear activation supported.'all_layers.append(Add()([all_layers[index], prev_layer]))prev_layer = all_layers[-1]elif section.startswith('upsample'):stride = int(cfg_parser[section]['stride'])assert stride == 2, 'Only stride=2 supported.'all_layers.append(UpSampling2D(stride)(prev_layer))prev_layer = all_layers[-1]elif section.startswith('yolo'):out_index.append(len(all_layers)-1)all_layers.append(None)prev_layer = all_layers[-1]elif section.startswith('net'):passelse:raise ValueError('Unsupported section header type: {}'.format(section))# Create and save model.if len(out_index)==0: out_index.append(len(all_layers)-1)model = Model(inputs=input_layer, outputs=[all_layers[i] for i in out_index])print(model.summary())if args.weights_only:model.save_weights('{}'.format(output_path))print('Saved Keras weights to {}'.format(output_path))else:model.save('{}'.format(output_path))print('Saved Keras model to {}'.format(output_path))# Check to see if all weights have been read.remaining_weights = len(weights_file.read()) / 4weights_file.close()print('Read {} of {} from Darknet weights.'.format(count, count +remaining_weights))if remaining_weights > 0:print('Warning: {} unused weights'.format(remaining_weights))if args.plot_model:plot(model, to_file='{}.png'.format(output_root), show_shapes=True)print('Saved model plot to {}.png'.format(output_root))if __name__ == '__main__':_main(parser.parse_args())

应用举例

例如,需要将进行模型压缩所得的权重文件.weights转换为keras可以使用的.h5类型的权重文件。假设进行剪枝、压缩后的模型为yolov3.cfg,权重文件为yolov3.weights,需要转换为yolov3.h5则在命令行输入:

python convert.py yolov3.cfg yolov3.weights model_data/yolov3.ht

后面三个参数分别就是,需要转换的模型文件,权重文件,输出的路径。

.weights转.pt格式

类似的,可以通过下面的代码实现weights格式转为pt格式:

python -c "from models import *; convert('cfg/yolov3.cfg', 'weights/last.pt')"

参考信息

https://github.com/huangbinz/yolov3-weights2h5

Darknet框架的权重文件.weights类型转换为keras框架的权重文件类型.h5相关推荐

  1. Windows10 将 YOLOX模型转换为OpenVINO需要的IR文件

    环境 Windows:10 Anaconda:2.0.4 Python 3.7.10 torch:1.7.0 torchvision:0.8.0 YOLOX:0.1.0 OpenVINO 工具包 20 ...

  2. python3.5将list类型转换为矩阵类型

    在python3中,取消了mat()函数,转而用matrix()代替 这个函数的作用是将list类型转换为numpy库中的矩阵类型

  3. DCMTK:将DICOM文件的内容转换为XML格式

    DCMTK:将DICOM文件的内容转换为XML格式 将DICOM文件的内容转换为XML格式 将DICOM文件的内容转换为XML格式 #include "dcmtk/config/osconf ...

  4. DCMTK:将DICOM文件的内容转换为JSON格式

    DCMTK:将DICOM文件的内容转换为JSON格式 将DICOM文件的内容转换为JSON格式 将DICOM文件的内容转换为JSON格式 #include "dcmtk/config/osc ...

  5. DL之Keras:基于Keras框架建立模型实现【预测】功能的简介、设计思路、案例分析、代码实现之详细攻略(经典,建议收藏)

    DL之Keras:基于Keras框架建立模型实现[预测]功能的简介.设计思路.案例分析.代码实现之详细攻略(经典,建议收藏) 目录 Keras框架使用分析 Keras框架设计思路 案例分析 代码实现 ...

  6. TF之TFOD-API:基于tensorflow框架利用TFOD-API脚本文件将YoloV3训练好的.ckpt模型文件转换为推理时采用的.pb文件

    TF之TFOD-API:基于tensorflow框架利用TFOD-API脚本文件将YoloV3训练好的.ckpt模型文件转换为推理时采用的frozen_inference_graph.pb文件 目录 ...

  7. YOLO:将yolo的.weights文件转换为keras、tensorflow等模型所需的.h5文件的图文教程

    YOLO:将yolo的.weights文件转换为keras.tensorflow等模型所需的.h5文件的图文教程 目录 解决问题 操作过程 结果输出 解决问题 将yolo的.weights文件转换为k ...

  8. java doubke类型转换为String_Java基础知识面试题大集合

    本文整理自作者:ThinkWon  链接: blog.csdn.net/ThinkWon/article/details/104390612 本文知识点目录 Java概述 何为编程 什么是Java j ...

  9. 如何将一组列表(三个以上,数值类型不一)保存为txt文件

    点击上方"Python爬虫与数据挖掘",进行关注 回复"书籍"即可获赠Python从入门到进阶共10本电子书 今 日 鸡 汤 郡邑浮前浦,波澜动远空. 大家好, ...

最新文章

  1. 武汉网络推广教大家如何编辑出更高质量的文章TDK?
  2. python的none是什么-Python中的None与Null(空字符)的区别
  3. ubuntu16.04无法连接WiFi搜索不到网络网卡驱动
  4. northwind中文 for mysql_学习心得 | PHP与mysql通信的若干问题
  5. .NET多线程编程(7)——C#多线程编程传递参数解决方案
  6. 带有Angular JS的Java EE 7 –第1部分
  7. 四步创建TCP客户端
  8. 【编译原理】如何编写BNF?
  9. linux修改栈指针x86,x86-堆栈指针未填充16时libc的system()导致分段...
  10. 多项式输出(洛谷-P1067)
  11. linux usb驱动u盘启动不了,Linux环境下USB的原理、驱动和配置(4)
  12. java 微信开发收到乱码,微信公众号发送模板消息中文乱码(java)
  13. 新冠肺炎疫情数学模型的一点想法
  14. 2018.08.17 洛谷P3135 [USACO16JAN]堡哞(前缀和处理)
  15. 计算机辅助三维设计大纲,《电脑辅助三维设计》课程教学大纲.doc
  16. JScrollBar().setValue(0)设置滚动条位置失效问题
  17. 你为什么要去博物馆? 我的理由比较另类
  18. html实现图片裁剪,【前端】图片裁剪(二)Jcrop实现裁剪
  19. Eclipse Neno版本 安装插件开发JavaEE
  20. 《蜡烛人》制作人高鸣:如何原汁原味的将主机游戏移植到手机平台

热门文章

  1. 巴比特 | 元宇宙每日必读:蒂芙尼宣布推出限量版 CryptoPunk 定制吊坠
  2. POJ 1877 Flooded! G++
  3. CocosStudio(八)AtlasLabel数字标签、BitmapLabel自定义字体、Label文本框
  4. 智能路由器要成功 该怎样修炼穿墙术?
  5. table thead tr设置表头背景色未完全覆盖的问题
  6. Number of Operations to Decrement Target to Zero - 滑动窗口
  7. 微信打开链接提示用浏览器打开
  8. 灰色预测(MATLAB)
  9. vmware vmbox 使用虚拟机安装Windows11提示电脑不符合最低系统要求的解决方案
  10. 学习Java编程入门书籍