目录

DeamNet网络架构图

1.  定义各种类(class)

2.  定义编码-解码块

3. DEAM 模块和 NLO子网络

4. 图像域转换与逆转换

5. 整体DeamNet

DeamNet网络架构图

1.  定义各种类(class)

nn.Module类的基本定义

在定义网络时,需要继承nn.Module类,并重新实现构造函数__init__()和forward这两个方法。在构造函数__init__()中使用super(Model, self).init()来调用父类的构造函数,forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
1.一般把网络中具有可学习参数的层(如全连接层、卷积层)放在构造函数__init__()中。
2.一般把不具有可学习参数的层(如ReLU、dropout)可放在构造函数中,也可不放在构造函数中(在forward中使用nn.functional来调用)。

import torch
import torch.nn as nnclass ConvLayer1(nn.Module):  #ConvLayer1为子类,nn.Module为父类def __init__(self, in_channels, out_channels, kernel_size, stride):#  in_channel: 输入数据的通道数,out_channel: 输出数据的通道数,stride 步长super(ConvLayer1, self).__init__()   # 调用父类的构造函数self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride= stride)# nn.Conv2d 二维卷积可处理二维数据nn.init.xavier_normal_(self.conv2d.weight.data)# nn.init. 参数初始化方法  xavier初始化方式 normal_ 正态分布  # .weight.data:得到的是一个Tensor的张量(向量),不可训练的类型def forward(self, x): # out = self.reflection_pad(x)# out = self.conv2d(out)return self.conv2d(x)class ConvLayer(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride):super(ConvLayer, self).__init__()padding = (kernel_size - 1) // 2self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride),nn.ReLU())# self.属性名 = 属性的初始值.  添加属性并赋值  # nn.Sequential是一个有序的容器,神经网络模块按照传入构造器的顺序被添加到计算图中执行nn.init.xavier_normal_(self.block[0].weight.data)def forward(self, x):return self.block(x)class line(nn.Module):def __init__(self):super(line, self).__init__()#randn(*size, out=None, dtype=None) 返回一个张量,包含了从标准正态分布(均值为0,方差为1,即高斯白噪声)中抽取的一组随机数,张量的形状由sizes定义self.delta = nn.Parameter(torch.randn(1, 1))# torch.mul(input, other, *, out=None) 输入:两个张量矩阵;输出:他们的点乘运算结果def forward(self, x, y):return torch.mul((1 - self.delta), x) + torch.mul(self.delta, y)

torch.nn.Parameter()

首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

2.  定义编码-解码块

class Encoding_block(nn.Module):def __init__(self, base_filter, n_convblock):super(Encoding_block, self).__init__()self.n_convblock = n_convblockmodules_body = []   # 空列表 代表list列表数据类型for i in range(self.n_convblock - 1):modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1))#  append()函数用于在列表末尾添加新的对象modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=2))self.body = nn.Sequential(*modules_body)# nn.Sequential的定义来看,输入遇到list,必须用*号进行转化,否则会报错def forward(self, x):for i in range(self.n_convblock - 1):x = self.body[i](x)ecode = xx = self.body[self.n_convblock - 1](x)return ecode, xclass UpsampleConvLayer(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):super(UpsampleConvLayer, self).__init__()self.upsample = upsampleself.conv2d = ConvLayer(in_channels, out_channels, kernel_size, stride)def forward(self, x):x_in = xif self.upsample:# torch.nn.functional.interpolate实现插值和上采样x_in = torch.nn.functional.interpolate(x_in, scale_factor=self.upsample)out = self.conv2d(x_in)return outclass upsample1(nn.Module):def __init__(self, base_filter):super(upsample1, self).__init__()self.conv1 = ConvLayer(base_filter, base_filter, 3, stride=1)self.ConvTranspose = UpsampleConvLayer(base_filter, base_filter, kernel_size=3, stride=1, upsample=2) # 转置卷积self.cat = ConvLayer1(base_filter * 2, base_filter, kernel_size=1, stride=1)def forward(self, x, y):y = self.ConvTranspose(y)x = self.conv1(x)# torch.cat 在给定维度上对输入的张量序列seq进行拼接。return self.cat(torch.cat((x, y), dim=1))class Decoding_block2(nn.Module):def __init__(self, base_filter, n_convblock):super(Decoding_block2, self).__init__()self.n_convblock = n_convblockself.upsample = upsample1(base_filter)modules_body = []for i in range(self.n_convblock - 1):modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1))modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1))self.body = nn.Sequential(*modules_body)def forward(self, x, y):x = self.upsample(x, y)for i in range(self.n_convblock):x = self.body[i](x)return x

3. DEAM 模块和 NLO子网络

# Corresponds to DEAM Module in NLO Sub-network
class Attention_unet(nn.Module):# 注意力机制 参数reduction为缩减率def __init__(self, channel, reduction=16):super(Attention_unet, self).__init__()#  // 的用法还没搜到,猜测是换行self.conv_du = nn.Sequential(ConvLayer1(in_channels=channel, out_channels=channel // reduction, kernel_size=3, stride=1),nn.ReLU(inplace=True),ConvLayer1(in_channels=channel // reduction, out_channels=channel, kernel_size=3, stride=1),nn.Sigmoid())self.cat = ConvLayer1(in_channels=channel * 2, out_channels=channel, kernel_size=1, stride=1)self.C = ConvLayer1(in_channels=channel, out_channels=channel, kernel_size=3, stride=1)self.ConvTranspose = UpsampleConvLayer(channel, channel, kernel_size=3, stride=1, upsample=2)  # up-samplingdef forward(self, x, g):up_g = self.ConvTranspose(g)  # 对应文中SA上采样模块weight = self.conv_du(self.cat(torch.cat([self.C(x), up_g], 1)))rich_x = torch.mul((1 - weight), up_g) + torch.mul(weight, x)return rich_x  # 返回Deam模块的输出

self.conv_du 对应文中的weights mapping模块,即图4中的3维卷积+Relu+3维卷积+Sigmoid

前面加上一个1维卷积连接起来得到(WM模块),weight 即生成的加权张量α

# Corresponds to NLO Sub-network
class ziwangluo1(nn.Module):def __init__(self, base_filter, n_convblock_in, n_convblock_out):super(ziwangluo1, self).__init__()# ConvLayer(in_channels, out_channels, kernel_size, stride)self.conv_dila1 = ConvLayer1(64, 64, 3, 1)self.conv_dila2 = ConvLayer1(64, 64, 5, 1)self.conv_dila3 = ConvLayer1(64, 64, 7, 1)#  nn.Conv2d的参数dilation:膨胀卷积. Pytoch中dilation默认为1,但是实际为不膨胀self.cat1 = torch.nn.Conv2d(in_channels=64 * 3, out_channels=64, kernel_size=1, stride=1, padding=0,dilation=1, bias=True)nn.init.xavier_normal_(self.cat1.weight.data)self.e3 = Encoding_block(base_filter, n_convblock_in)self.e2 = Encoding_block(base_filter, n_convblock_in)self.e1 = Encoding_block(base_filter, n_convblock_in)self.e0 = Encoding_block(base_filter, n_convblock_in)# 文中迭代次数K=4self.attention3 = Attention_unet(base_filter)self.attention2 = Attention_unet(base_filter)self.attention1 = Attention_unet(base_filter)self.attention0 = Attention_unet(base_filter)# 定义的ConvLayer 包含Relu(线性修正单元)self.mid = nn.Sequential(ConvLayer(base_filter, base_filter, 3, 1),ConvLayer(base_filter, base_filter, 3, 1))self.de3 = Decoding_block2(base_filter, n_convblock_out)self.de2 = Decoding_block2(base_filter, n_convblock_out)self.de1 = Decoding_block2(base_filter, n_convblock_out)self.de0 = Decoding_block2(base_filter, n_convblock_out)# 定义的ConvLayer1 不包含Reluself.final = ConvLayer1(base_filter, base_filter, 3, stride=1)def forward(self, x):_input = xencode0, down0 = self.e0(x)encode1, down1 = self.e1(down0)encode2, down2 = self.e2(down1)encode3, down3 = self.e3(down2)# media_end = self.Encoding_block_end(down3)media_end = self.mid(down3)g_conv3 = self.attention3(encode3, media_end)up3 = self.de3(g_conv3, media_end)g_conv2 = self.attention2(encode2, up3)up2 = self.de2(g_conv2, up3)g_conv1 = self.attention1(encode1, up2)up1 = self.de1(g_conv1, up2)g_conv0 = self.attention0(encode0, up1)up0 = self.de0(g_conv0, up1)final = self.final(up0)return _input + finalclass line(nn.Module):def __init__(self):super(line, self).__init__()self.delta = nn.Parameter(torch.randn(1, 1))def forward(self, x, y):return torch.mul((1 - self.delta), x) + torch.mul(self.delta, y)
# 对应 DEAM 模块
class SCA(nn.Module):def __init__(self, channel, reduction=16):super(SCA, self).__init__()self.conv_du = nn.Sequential(ConvLayer1(in_channels=channel, out_channels=channel // reduction, kernel_size=3, stride=1),nn.ReLU(inplace=True),ConvLayer1(in_channels=channel // reduction, out_channels=channel, kernel_size=3, stride=1),nn.Sigmoid())def forward(self, x):y = self.conv_du(x)return yclass Weight(nn.Module):def __init__(self, channel):super(Weight, self).__init__()self.cat = ConvLayer1(in_channels=channel * 2, out_channels=channel, kernel_size=1, stride=1)self.C = ConvLayer1(in_channels=channel, out_channels=channel, kernel_size=3, stride=1)self.weight = SCA(channel)def forward(self, x, y):delta = self.weight(self.cat(torch.cat([self.C(y), x], 1)))return delta

4. 图像域转换与逆转换

根据文中介绍转换特征域(FD)与像素域的模块得到下面代码的网络结构

class transform_function(nn.Module):def __init__(self, in_channel, out_channel):super(transform_function, self).__init__()self.ext = ConvLayer1(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1)self.pre = torch.nn.Sequential(ConvLayer1(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1),nn.ReLU(inplace=True),ConvLayer1(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1),)def forward(self, x):y = self.ext(x)return y + self.pre(y)# 图像域变换与逆变换中定义的self.pre通道数不同class Inverse_transform_function(nn.Module):def __init__(self, in_channel, out_channel):super(Inverse_transform_function, self).__init__()self.ext = ConvLayer1(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1)self.pre = torch.nn.Sequential(ConvLayer1(in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1),nn.ReLU(inplace=True),ConvLayer1(in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1),)def forward(self, x):x = self.pre(x) + xx = self.ext(x)return x

5. 整体DeamNet

class Deam(nn.Module):def __init__(self, Isreal):super(Deam, self).__init__()if Isreal:# 由以下代码可知像素域通道数为3(或1,即彩色图像与灰度图),特征域通道数为64self.transform_function = transform_function(3, 64)self.inverse_transform_function = Inverse_transform_function(64, 3)else:self.transform_function = transform_function(1, 64)self.inverse_transform_function = Inverse_transform_function(64, 1)# 从网络结构看,转换后的X输入到NLO子网络和DEAM模块中self.line11 = Weight(64)self.line22 = Weight(64)self.line33 = Weight(64)self.line44 = Weight(64)self.net2 = ziwangluo1(64, 3, 2)   # DeamNet中NLO子网络的参数共享def forward(self, x):x = self.transform_function(x)y = x# Corresponds to NLO Sub-networkx1 = self.net2(y)# Corresponds to DEAM Moduledelta_1 = self.line11(x1, y)# 这里y对应低分辨率的分支,x1对应高分辨率分支x1 = torch.mul((1 - delta_1), x1) + torch.mul(delta_1, y)x2 = self.net2(x1)delta_2 = self.line22(x2, y)x2 = torch.mul((1 - delta_2), x2) + torch.mul(delta_2, y)x3 = self.net2(x2)delta_3 = self.line33(x3, y)x3 = torch.mul((1 - delta_3), x3) + torch.mul(delta_3, y)x4 = self.net2(x3)delta_4 = self.line44(x4, y)x4 = torch.mul((1 - delta_4), x4) + torch.mul(delta_4, y)x4 = self.inverse_transform_function(x4)return x4def print_network(net):num_params = 0for param in net.parameters():# += 先将运算符左边和右边的变量值相加,然后将相加的结果赋值给左边的变量# param.numel()  返回param中元素的数量num_params += param.numel()print(net) # 打印的是网络名,没有网络结构#  字符串输出: %d 有符号的十进制整数print('Total number of parameters: %d' % num_params)

net.parameters()

逐列表项输出列表元素。构建好神经网络后,网络的参数都保存在parameters()函数当中

与net.named_parameters()的输出相对比,net.parameters()的输出里只包含参数的值,不包含参数的所属信息。

DeamNet代码学习||网络框架核心代码 逐句查找学习相关推荐

  1. Java 线程池框架核心代码分析

    转载自 Java 线程池框架核心代码分析 前言 多线程编程中,为每个任务分配一个线程是不现实的,线程创建的开销和资源消耗都是很高的.线程池应运而生,成为我们管理线程的利器.Java 通过Executo ...

  2. c语言学习,使用文档来查找学习库函数

    目录 资源推荐 使用文档来查找学习库函数 以www.cplusplus.com学习为例 msdn http://zh.cppreference.com--c/c++的官网 资源推荐 www.cplus ...

  3. UCHome风格模版 框架核心代码提取

    uchome是个sns系统,但也是拥有深厚php技术积累的康盛公司的商业产品,本身有许多值得学习的地方,你可以用它来架设垂直的sns网站,也可以学习他的一些技巧,提高自己的代码水平,改善代码质量. 对 ...

  4. Java 线程池框架核心代码分析--转

    原文地址:http://www.codeceo.com/article/java-thread-pool-kernal.html 前言 多线程编程中,为每个任务分配一个线程是不现实的,线程创建的开销和 ...

  5. Java线程池框架核心代码分析

    前言 多线程编程中,为每个任务分配一个线程是不现实的,线程创建的开销和资源消耗都是很高的.线程池应运而生,成为我们管理线程的利器.Java 通过Executor接口,提供了一种标准的方法将任务的提交过 ...

  6. 黑马程序员视频教程学习mybatis框架常用注释SQL语句学习笔记?

    mybatis学习笔记 常用注释增删改查SQL语句 常用注释拓展SQL语句 解决实体类属性和数据库表中的属性名称不相同的问题: mybatis注解之一对一查询: mybatis注解之一对多查询: my ...

  7. 脑电EEG代码开源分享 【6. 分类模型-深度学习篇】

    往期文章 希望了解更多的道友点这里 0. 分享[脑机接口 + 人工智能]的学习之路 1.1 . 脑电EEG代码开源分享 [1.前置准备-静息态篇] 1.2 . 脑电EEG代码开源分享 [1.前置准备- ...

  8. 安卓网络框架,上传图片花图,上传状态411被服务器驳回

    先看下一开始使用的网络框架核心代码: private Message doPost(final String url, final Map<String, String> params, ...

  9. 9 行代码提高少样本学习泛化能力,代码已开源

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | 新智元 来源 | 知乎 作者 | 杨朔 本文介绍一篇最新发 ...

最新文章

  1. 软件测试中的冲突测试
  2. java增加缓存,java – 如何增加Integer对象的缓存大小
  3. 离散数学序关系与相容关系
  4. Ubuntu 14.04 LTS 配置 Juno 版 Keystone
  5. 前后端分离项目,后端是如何处理前端传递的token?
  6. Visual Stdio 注册表相关路径
  7. 李宏毅机器学习HW2-winner or loser-利用逻辑回归进行收入分类
  8. 【暂时完结】Prescan学习笔记
  9. Amos24程序安装及注意事项
  10. “薪水”种种表达方法
  11. Cheat Engine游戏脚本修改器通关教程(脑残版Step9)
  12. MySQL数据库 学习笔记 零基础入门 面试 整理
  13. SecureCRT + SecureFX 8.1 Bundle安装注册教程(完美版)
  14. Hadoop 之上的数据建模 - Data Vault 2.0
  15. SSM优秀宿舍评选系统毕业设计-附源码221511
  16. 江南百景图显示服务器错误,江南百景图通讯失败请保持网络畅通并重试
  17. 密码算法测试题解析之单选题(一)
  18. 在DX12中使用imgui 入门教程 立方体旋转+改变背景颜色
  19. NRF通信中使用的线圈、高频卡、低频卡
  20. php瑜伽馆源码,深蓝健身房瑜伽馆行业小程序源代码4.15.0

热门文章

  1. python语言处理初探——分词、词性标注、提取名词
  2. 黑客瞄上火热人气 网上观影勿忘防毒
  3. 崩坏学园3里离摄像机近距离的头发透明效果在unity里的实现方法
  4. 维也纳大学:光量子忆阻器有望解锁AI神经网络
  5. Unity3D实现地图编辑器的插件
  6. 谢希仁《计算机网络》笔记
  7. 16进制 转为图片 php_十六进制编辑器(010 Editor)官方版下载_十六进制编辑器(010 Editor) v11.0中文汉化版64位...
  8. loj 523 「LibreOJ β Round #3」绯色 IOI(悬念) 霍尔定理+基环树+线段树
  9. ePSXe 1.7.0与DAEMON Tools Lite配合的问题
  10. php模糊查询数组,php 数组模糊查询