cfg是配置文件,一般为了代码的可读性,把一层层的神经网络用cfg格式文件保存,用的时候可以直接读取调用,简单轻便。以下用一个例子来了解。

下面cfg文件是yolov3的网络层次:

[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=2
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1# 0
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky# 1
[maxpool]
size=2
stride=2# 2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky# 3
[maxpool]
size=2
stride=2# 4
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky# 5
[maxpool]
size=2
stride=2# 6
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky# 7
[maxpool]
size=2
stride=2# 8
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky# 9
[maxpool]
size=2
stride=2# 10
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky# 11
[maxpool]
size=2
stride=1# 12
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky############ 13
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky# 14
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky# 15
[convolutional]
size=1
stride=1
pad=1
filters=18
activation=linear# 16
[yolo]
mask = 3,4,5
anchors = 10,14,  23,27,  37,58,  81,82,  135,169,  344,319
classes=1
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1# 17
[route]
layers = -4# 18
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky# 19
[upsample]
stride=2# 20
[route]
layers = -1, 8# 21
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky# 22
[convolutional]
size=1
stride=1
pad=1
filters=18
activation=linear# 23
[yolo]
mask = 1,2,3
anchors = 10,14,  23,27,  37,58,  81,82,  135,169,  344,319
classes=1
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1

由上图配置文件可知,YOLOv3分为23个模块,每个模块都设置参数,下面用pytorch来调用他并创建神经网络。

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npdef create_modules(module_defs):"""Constructs module list of layer blocks from module configuration in module_defs"""hyperparams = module_defs.pop(0)output_filters = [int(hyperparams["channels"])]  # [3]module_list = nn.ModuleList()for module_i, module_def in enumerate(module_defs):modules = nn.Sequential()if module_def["type"] == "convolutional":bn = int(module_def["batch_normalize"])filters = int(module_def["filters"])  #模型定义的滤波器个数、大小等kernel_size = int(module_def["size"])pad = (kernel_size - 1) // 2modules.add_module(f"conv_{module_i}",nn.Conv2d(in_channels=output_filters[-1],out_channels=filters,kernel_size=kernel_size,stride=int(module_def["stride"]),padding=pad,bias=not bn,),)if bn:modules.add_module(f"batch_norm_{module_i}", nn.BatchNorm2d(filters, momentum=0.9, eps=1e-5))if module_def["activation"] == "leaky":modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))elif module_def["type"] == "maxpool":kernel_size = int(module_def["size"])stride = int(module_def["stride"])if kernel_size == 2 and stride == 1:modules.add_module(f"_debug_padding_{module_i}", nn.ZeroPad2d((0, 1, 0, 1)))maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=int((kernel_size - 1) // 2))modules.add_module(f"maxpool_{module_i}", maxpool)elif module_def["type"] == "upsample":upsample = Upsample(scale_factor=int(module_def["stride"]), mode="nearest")modules.add_module(f"upsample_{module_i}", upsample)elif module_def["type"] == "route":layers = [int(x) for x in module_def["layers"].split(",")]filters = sum([output_filters[1:][i] for i in layers])modules.add_module(f"route_{module_i}", EmptyLayer())elif module_def["type"] == "shortcut":filters = output_filters[1:][int(module_def["from"])]modules.add_module(f"shortcut_{module_i}", EmptyLayer())elif module_def["type"] == "yolo":anchor_idxs = [int(x) for x in module_def["mask"].split(",")]# Extract anchorsanchors = [int(x) for x in module_def["anchors"].split(",")]anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]anchors = [anchors[i] for i in anchor_idxs]num_classes = int(module_def["classes"])img_size = int(hyperparams["height"])# Define detection layeryolo_layer = YOLOLayer(anchors, num_classes, img_size)modules.add_module(f"yolo_{module_i}", yolo_layer)# Register module list and number of output filtersmodule_list.append(modules)output_filters.append(filters)return hyperparams, module_list

由上面代码可知,每个模块都编写对应的代码来读取它,然后相应的用pytorch来创建神经网络。

最后可以在创建整个神经网络框架的时候直接调用上面定义的函数,如下代码所示:

self.hyperparams, self.module_list = create_modules(self.module_defs) #根据cfg文件创建模型

代码返回神经网络的参数和模块的list。比如我创建一个DarkNet的检测模型,代码如下所示:

def parse_model_config(path):"""Parses the yolo-v3 layer configuration file and returns module definitions"""file = open(path, 'r')lines = file.read().split('\n')lines = [x for x in lines if x and not x.startswith('#')]lines = [x.rstrip().lstrip() for x in lines]  # get rid of fringe whitespacesmodule_defs = []for line in lines:if line.startswith('['):  # This marks the start of a new blockmodule_defs.append({})module_defs[-1]['type'] = line[1:-1].rstrip()if module_defs[-1]['type'] == 'convolutional':module_defs[-1]['batch_normalize'] = 0else:key, value = line.split("=")value = value.strip()module_defs[-1][key.rstrip()] = value.strip()return module_defsclass Darknet(nn.Module):"""YOLOv3 object detection model"""def __init__(self, config_path, img_size=416):super(Darknet, self).__init__()self.module_defs = parse_model_config(config_path)self.hyperparams, self.module_list = create_modules(self.module_defs) #根据cfg文件创建模型self.yolo_layers = [layer[0] for layer in self.module_list if hasattr(layer[0], "metrics")]self.img_size = img_sizeself.seen = 0self.header_info = np.array([0, 0, 0, self.seen, 0], dtype=np.int32)def forward(self, x, targets=None):img_dim = x.shape[2]loss = 0layer_outputs, yolo_outputs = [], []for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):if module_def["type"] in ["convolutional", "upsample", "maxpool"]:x = module(x)elif module_def["type"] == "route":x = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)elif module_def["type"] == "shortcut":layer_i = int(module_def["from"])x = layer_outputs[-1] + layer_outputs[layer_i]elif module_def["type"] == "yolo":x, layer_loss = module[0](x, targets, img_dim)loss += layer_lossyolo_outputs.append(x)layer_outputs.append(x)yolo_outputs = to_cpu(torch.cat(yolo_outputs, 1))return yolo_outputs if targets is None else (loss, yolo_outputs)

上面写完整个神经网络根据cfg配置文件的代码框架就已经出来了。

如有错误,欢迎各位大佬指正!

cfg文件搭建神经网络并用pytorch读取创建模型相关推荐

  1. Ruby on Rails,创建模型,附赠模型与表名不一致时的解决方法

    在前文<Ruby on Rails,创建和执行migrations迁移文件>中我们提到过创建模型的事情,我们创建模型的同时生成迁移文件.那时候我们关注的是迁移文件,现在我们把目光投向模型这 ...

  2. Deep Learning:基于pytorch搭建神经网络的花朵种类识别项目(内涵完整文件和代码)—超详细完整实战教程

    基于pytorch的深度学习花朵种类识别项目完整教程(内涵完整文件和代码) 相关链接:: 超详细--CNN卷积神经网络教程(零基础到实战) 大白话pytorch基本知识点及语法+项目实战 文章目录 基 ...

  3. 基于pytorch搭建神经网络的花朵种类识别(深度学习)

    基于pytorch搭建神经网络的花朵种类识别(深度学习) 文章目录 基于pytorch搭建神经网络的花朵种类识别(深度学习) 一.知识点 1.特征提取.神经元逐层判断 2.中间层(隐藏层) 3.学习权 ...

  4. 【网络安全】如何搭建MySQL恶意服务器读取文件?

    前言 注:本文不涉及对MySQL协议报文研究,仅讲解原理,并且做部分演示. 搭建MySQL恶意服务器读取文件这件事,虽然直接利用门槛较高,但是由于在网上看到了一种比较新颖的利用方式(利用社会工程学引诱 ...

  5. python中怎么创建配置文件_在Python中创建游戏配置/选项(config.cfg)文件

    myConfig.cfg: [info] Width = 100 Height = 200 Name = My Game 解析python: import ConfigParser configPar ...

  6. c++读取cfg文件

    参考:https://www.cnblogs.com/zhuzhenwei918/p/8569160.html 四个文件 get_cfg.h,get_cfg.cpp,read_cfg.cpp,conf ...

  7. 我们用PyTorch搭建神经网络时,会遇到nn.ReLU(inplace=True),inplace=True是什么意思呢?

    我们用PyTorch搭建神经网络时,会遇到nn.ReLU(inplace=True),inplace=True是什么意思呢? nn.Conv2d(64,192,kernel_size=3,stride ...

  8. chrome js 读取文件夹_JS读取/创建本地文件及目录文件夹的方法

    注:以下操作只在IE下有效! Javascript是网页制作中离不开的脚本语言,依靠它,一个网页的内容才生动活泼.富有朝气.但也许你还没有发现并应用它的一些更高级的功能吧?比如,对文件和文件夹进行读. ...

  9. Javascript FileSystemObject 读取/创建本地文件及目录文件夹的方法

    注:以下操作只在IE下有效! Javascript是网页制作中离不开的脚本语言,依靠它,一个网页的内容才生动活泼.富有朝气.但也许你还没有发现并应用它的一些更高级的功能吧?比如,对文件和文件夹进行读. ...

最新文章

  1. PowerDesigner的样式设置
  2. entity、model和domain三者区别
  3. tessorflow实战
  4. MVC4.0网站发布和部署到IIS7.0上的方法
  5. 2pin接口耳机_悦耳好音质,续航10小时,用了小米生态链这款耳机,扔掉其它吧...
  6. mysql 占用swap_查看swap占用情况
  7. 本地主机作服务器解决AJAX跨域请求访问数据的方法
  8. mysql server 5.7.16_mysql 5.7.16 安装配置方法图文教程(ubuntu 16.04)
  9. win7计算机未连接网络连接,解决win7能上网但是网络图标显示未连接的方法-win7之家...
  10. 夜班工作有哪些优缺点?
  11. php程序员需要精通js的程度_PHP程序员基本要求和必备技能
  12. spring自动注入bean
  13. 新概念第三册背诵: Lesson 1 - A Puma at large
  14. SCC1传输请求(同系统跨Client)
  15. python登录代码_python自动登录126等邮箱的代码
  16. Latex 对号和叉号的
  17. 鼠标跟计算机的USB设备运行不正常,为什么鼠标跟这台计算机连接的一个USB设备运行不正常,windo? 爱问知识人...
  18. Vue 组件封装之 Questionnaire 问卷调查
  19. c语言编程第四版李丽娟课程,C语言程序设计教程 第4版 普通高等教育“十一五”国家级规划教材 教学课件 李丽娟 C语言程序设计教程(第4版)_第4章_分支结构.pdf...
  20. 电信 NB-IoT无缝对接阿里云IoT 物联网平台

热门文章

  1. 阅读作业二之waterfall——洪虹
  2. Console调试常用用法
  3. UEFI 基础教程 (零) - 目录
  4. 台灯哪个牌子的比较好保护视力的?推荐五款护眼台灯
  5. from PIL import Image不能导入Image
  6. 李维作答 《insideVCL》(转摘)
  7. 【图像分类】2021-EfficientNetV2 CVPR
  8. HibernateException - A collection with cascade=all-delete-orphan was no longer referenced by the
  9. 低代码如何助力广播媒体行业构建数字系统
  10. 11.01T2 树状数组维护动态LIS