An earlier post, 《论文阅读笔记之——《Multi-level Wavelet-CNN for Image Restoration》及基于pytorch的复现》, studied MWCNN. In this post, the pooling in octave convolution is replaced with the DWT.

Code

Similar to the post 《实验笔记之——octave conv (without pooling)》, the structure of the octave layer is modified as follows:

Implementing the discrete wavelet transform in PyTorch:

https://github.com/fbcotter/pytorch_wavelets

git clone https://github.com/fbcotter/pytorch_wavelets

cd pytorch_wavelets
pip install .
pip install -r tests/requirements.txt

Testing
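Before wiring the transform into a network, it is worth sanity-checking the forward/inverse pair on random data. The round trip below is a minimal sketch adapted from the pytorch_wavelets usage examples; the tensor sizes are arbitrary.

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

# forward and inverse dual-tree complex wavelet transforms, with the same
# settings as the first version of the code below
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')

X = torch.randn(10, 5, 64, 64)
Yl, Yh = xfm(X)           # Yl: coarse lowpass; Yh: list with one tensor per scale
for j, y in enumerate(Yh):
    print(j, y.shape)     # each is (N, C, 6, H_j, W_j, 2): 6 orientations, real/imag
Y = ifm((Yl, Yh))
print(Y.shape)            # reconstruction with the input's spatial size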

The modified code is as follows:

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
        self.ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
        self.stride = stride
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        # the DTCWT replaces the original average pooling of the low branch
        X_l, X_ll = self.xfm(X_l)
        X_h2h = self.h2h(X_h)
        # reconstruct the low branch before the cross-frequency convolutions
        X_l2ha = self.ifm((X_l, X_ll))
        X_l2h = self.l2h(X_l2ha)
        X_l2l = self.l2l(X_l2ha)
        X_h2l = self.h2l(X_h)
        X_h2l, X_h2ll = self.xfm(X_h2l)
        X_l2l, X_lla = self.xfm(X_l2l)
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l
        # fuse the highpass coefficients of the two paths, scale by scale
        X_ll[0] = X_lla[0] + X_h2ll[0]
        X_ll[1] = X_lla[1] + X_h2ll[1]
        X_ll[2] = X_lla[2] + X_h2ll[2]
        X_l = self.ifm((X_l, X_ll))
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l


class DWT_FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride
        # low-frequency branch
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        # high-frequency branch
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h = self.h2h(x)
        X_l = self.h2l(x)
        if self.n_h and self.n_l:  # batch norm
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:  # activation layer
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l


class DWT_LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.stride = stride
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        X_h = X_h2h + X_l2h
        if self.n_h:
            X_h = self.n_h(X_h)
        if self.a:
            X_h = self.a(X_h)
        return X_h


class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        # x is the (X_h, X_l) pair; dense connections concatenate per branch
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),
                         torch.cat((x[1], x1[1]), dim=1)))
        x3 = self.conv3((torch.cat((x[0], x1[0], x2[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1]), dim=1)))
        x4 = self.conv4((torch.cat((x[0], x1[0], x2[0], x3[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1], x3[1]), dim=1)))
        res = (x4[0].mul(0.2), x4[1].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1])
        return x


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        res = (out[0].mul(0.2), out[1].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1])
        return x
##################this is ESRGAN based on DWT_octave
class DWT_Octave_RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, alpha=0.125, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(DWT_Octave_RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1
        fea_conv1 = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        fea_conv = B.DWT_FirstOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1,
                                         bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        rb_blocks = [B.DWT_octave_RRDBTiny(nf, kernel_size=3, gc=32, alpha=alpha, stride=1, bias=True,
                                           pad_type='zero', norm_type=norm_type, act_type=act_type, mode='CNA')
                     for _ in range(nb)]
        LR_conv = B.DWT_LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1,
                                       bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        self.model = B.sequential(fea_conv1,
                                  B.ShortcutBlock(B.sequential(fea_conv, *rb_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x
##############################################################################################
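A hypothetical usage sketch of the network above. It assumes the surrounding project provides the B block module and the helper functions (act, norm, get_valid_padding) from the original octave-conv post, plus a CUDA device, since the wavelet modules call .cuda() in their constructors; the parameter values are illustrative, not the training settings.

import torch

# hypothetical instantiation; in_nc/out_nc/nf/nb values are assumptions
model = DWT_Octave_RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32,
                           alpha=0.125, upscale=4).cuda()
lr = torch.randn(1, 3, 32, 32).cuda()
sr = model(lr)
print(sr.shape)  # expect a 4x upscaled output, e.g. torch.Size([1, 3, 128, 128])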

Experiments

Improvement

Further testing turned up a new finding:

https://blog.csdn.net/qq_40587575/article/details/83154042

The test above shows that J=1 is sufficient: a single level of the DWT already splits a feature map into the LL subband plus the three high-frequency subbands (LH, HL, HH), each at roughly half resolution.
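A minimal forward-only sketch of this (arbitrary tensor sizes; with wave='db3' and mode='zero', boundary padding makes the subbands slightly larger than exactly half the input):

import torch
from pytorch_wavelets import DWTForward

xfm = DWTForward(J=1, wave='db3', mode='zero')
X = torch.randn(1, 64, 32, 32)
Yl, Yh = xfm(X)
print(Yl.shape)     # LL subband: (1, 64, H', W') at roughly half resolution
print(len(Yh))      # J=1, so a single decomposition level
print(Yh[0].shape)  # (1, 64, 3, H', W'): LH, HL and HH stacked on dim 2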

The improved code is as follows:

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        #self.xfm = DWTForward(J=1, wave='db3', mode='zero')
        #self.ifm = DWTInverse(wave='db3', mode='zero')
        self.stride = stride
        # every DWT subband keeps the full channel count, so the alpha-split
        # convolutions of the first version become full-channel convolutions
        self.l2l = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc, kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_ll, X_lh, X_hl, X_hh = x
        # one convolution per subband; the forward/inverse DWT now lives only
        # in the first and last octave layers
        X_hh = self.h2h(X_hh)
        X_lh = self.l2h(X_lh)
        X_ll = self.l2l(X_ll)
        X_hl = self.h2l(X_hl)
        if self.n_h and self.n_l:
            X_hh = self.n_h(X_hh)
            X_lh = self.n_h(X_lh)
            X_ll = self.n_h(X_ll)
            X_hl = self.n_h(X_hl)
        if self.a:
            X_hh = self.a(X_hh)
            X_lh = self.a(X_lh)
            X_ll = self.a(X_ll)
            X_hl = self.a(X_hl)
        return X_ll, X_lh, X_hl, X_hh


class DWT_FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
        self.stride = stride
        # kept from the previous version but unused in forward():
        # the DWT itself performs the four-way split
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        # A is the LL subband; B[0] stacks LH/HL/HH along dim 2
        A, B = self.xfm(x)
        X_ll = A
        X_lh = B[0][:, :, 0]
        X_hl = B[0][:, :, 1]
        X_hh = B[0][:, :, 2]
        return X_ll, X_lh, X_hl, X_hh


class DWT_LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
        self.stride = stride
        # kept from the previous version but unused in forward()
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_ll, X_lh, X_hl, X_hh = x
        A = X_ll
        c = A.shape
        # placeholder filled subband-by-subband; torch.stack((X_lh, X_hl, X_hh), dim=2)
        # would build the same (N, C, 3, H, W) tensor directly
        C = torch.randn(c[0], c[1], 3, c[-2], c[-1])
        C[:, :, 0] = X_lh
        C[:, :, 1] = X_hl
        C[:, :, 2] = X_hh
        C_ = [C.cuda()]
        X_h = self.ifm((A, C_))
        return X_h


class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        # x is the 4-tuple (X_ll, X_lh, X_hl, X_hh); dense connections are
        # concatenated per subband
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),
                         torch.cat((x[1], x1[1]), dim=1),
                         torch.cat((x[2], x1[2]), dim=1),
                         torch.cat((x[3], x1[3]), dim=1)))
        x3 = self.conv3((torch.cat((x[0], x1[0], x2[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1]), dim=1),
                         torch.cat((x[2], x1[2], x2[2]), dim=1),
                         torch.cat((x[3], x1[3], x2[3]), dim=1)))
        x4 = self.conv4((torch.cat((x[0], x1[0], x2[0], x3[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1], x3[1]), dim=1),
                         torch.cat((x[2], x1[2], x2[2], x3[2]), dim=1),
                         torch.cat((x[3], x1[3], x2[3], x3[3]), dim=1)))
        res = (x4[0].mul(0.2), x4[1].mul(0.2), x4[2].mul(0.2), x4[3].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3])
        return x


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        res = (out[0].mul(0.2), out[1].mul(0.2), out[2].mul(0.2), out[3].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3])
        return x
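The pack/unpack step in DWT_LastOctaveConv above allocates a torch.randn placeholder and fills it subband-by-subband. A self-contained sketch of the same step using torch.stack (the approach hinted at by the commented-out unsqueeze/cat lines in the original) avoids the placeholder and stays on the input's device:

import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=1, wave='db3', mode='zero')
ifm = DWTInverse(wave='db3', mode='zero')

x = torch.randn(1, 16, 32, 32)
X_ll, Yh = xfm(x)
X_lh, X_hl, X_hh = Yh[0][:, :, 0], Yh[0][:, :, 1], Yh[0][:, :, 2]

# ... per-subband convolutions would go here ...

C = torch.stack((X_lh, X_hl, X_hh), dim=2)  # (N, C, 3, H', W')
y = ifm((X_ll, [C]))
print(y.shape)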

Experimental results:

Improvement 2

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
        self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
        self.stride = stride
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        # DWT for X_h
        A, B = self.xfm(X_h)
        X_hll = A
        X_hlh = B[0][:, :, 0]
        X_hhl = B[0][:, :, 1]
        X_hhh = B[0][:, :, 2]
        # transfer high branch to low
        X_hll2l = self.h2l(X_hll)
        X_hlh2l = self.h2l(X_hlh)
        X_hhl2l = self.h2l(X_hhl)
        X_hhh2l = self.h2l(X_hhh)
        # DWT for X_l
        C, D = self.xfm(X_l)
        X_lll = C
        X_llh = D[0][:, :, 0]
        X_lhl = D[0][:, :, 1]
        X_lhh = D[0][:, :, 2]
        # transfer low branch to low
        X_lll2l = self.l2l(X_lll)
        X_llh2l = self.l2l(X_llh)
        X_lhl2l = self.l2l(X_lhl)
        X_lhh2l = self.l2l(X_lhh)
        X_h = X_h2h + X_l2h
        # fuse the low branch subband-by-subband in the wavelet domain, then invert
        E = X_lll2l + X_hll2l
        f = E.shape
        F = torch.randn(f[0], f[1], 3, f[-2], f[-1])
        F[:, :, 0] = X_llh2l + X_hlh2l
        F[:, :, 1] = X_lhl2l + X_hhl2l
        F[:, :, 2] = X_lhh2l + X_hhh2l
        F_ = [F.cuda()]
        X_l = self.ifm((E, F_))
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l


class DWT_FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()  # defined but unused in this version
        self.stride = stride
        # low-frequency branch
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        # high-frequency branch
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h = self.h2h(x)
        X_l = self.h2l(x)
        if self.n_h and self.n_l:  # batch norm
            X_h = self.n_h(X_h)
            X_l = self.n_l(X_l)
        if self.a:  # activation layer
            X_h = self.a(X_h)
            X_l = self.a(X_l)
        return X_h, X_l


class DWT_LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        #self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
        self.stride = stride
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l = x
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        X_h = X_h2h + X_l2h
        if self.n_h:
            X_h = self.n_h(X_h)
        if self.a:
            X_h = self.a(X_h)
        return X_h


class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),
                         torch.cat((x[1], x1[1]), dim=1)))
        x3 = self.conv3((torch.cat((x[0], x1[0], x2[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1]), dim=1)))
        x4 = self.conv4((torch.cat((x[0], x1[0], x2[0], x3[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1], x3[1]), dim=1)))
        res = (x4[0].mul(0.2), x4[1].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1])
        return x


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        res = (out[0].mul(0.2), out[1].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1])
        return x
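The core idea in this version is fusing the two branches in the wavelet domain: both X_h and X_l are decomposed with a one-level DWT, corresponding subbands are summed, and the result is reconstructed. A hedged, self-contained sketch of just that step (shapes are illustrative; the real module also applies convolutions to each subband first):

import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=1, wave='db3', mode='zero')
ifm = DWTInverse(wave='db3', mode='zero')

X_h = torch.randn(1, 16, 32, 32)   # high-frequency branch
X_l = torch.randn(1, 16, 32, 32)   # low-frequency branch (same size here)

A, B = xfm(X_h)
C, D = xfm(X_l)
E = A + C                          # fused LL subband
F = B[0] + D[0]                    # fused LH/HL/HH, shape (N, C, 3, H', W')
X_fused = ifm((E, [F]))
print(X_fused.shape)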

Results

Improvement 3

##################################################################################
##################################################################################
##################################################################################
#DWT octave
# Block for OctConv
####################
class DWT_OctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_OctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
        self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
        self.stride = stride
        self.l2l = nn.Conv2d(int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = nn.Conv2d(in_nc - int(alpha * in_nc), int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh = x
        # X_h to h
        X_h2h = self.h2h(X_h)
        # rebuild X_l from its carried subbands
        H = X_l_ll
        j = H.shape
        J = torch.randn(j[0], j[1], 3, j[-2], j[-1])
        J[:, :, 0] = X_l_lh
        J[:, :, 1] = X_l_hl
        J[:, :, 2] = X_l_hh
        J_ = [J.cuda()]
        X_l = self.ifm((H, J_))
        # X_l to h
        X_l2h = self.l2h(X_l)
        # DWT for X_h
        A, B = self.xfm(X_h)
        X_hll = A
        X_hlh = B[0][:, :, 0]
        X_hhl = B[0][:, :, 1]
        X_hhh = B[0][:, :, 2]
        # transfer high branch to low
        X_hll2l = self.h2l(X_hll)
        X_hlh2l = self.h2l(X_hlh)
        X_hhl2l = self.h2l(X_hhl)
        X_hhh2l = self.h2l(X_hhh)
        # transfer the X_l series (X_l_ll, X_l_lh, X_l_hl, X_l_hh)
        X_lll2l = self.l2l(X_l_ll)
        X_llh2l = self.l2l(X_l_lh)
        X_lhl2l = self.l2l(X_l_hl)
        X_lhh2l = self.l2l(X_l_hh)
        # output high branch
        X_h = X_h2h + X_l2h
        # output low-branch subbands, fused per subband
        X_l_ll = X_lll2l + X_hll2l
        X_l_lh = X_llh2l + X_hlh2l
        X_l_hl = X_lhl2l + X_hhl2l
        X_l_hh = X_lhh2l + X_hhh2l
        if self.n_h and self.n_l:
            X_h = self.n_h(X_h)
            X_l_ll = self.n_l(X_l_ll)
            X_l_lh = self.n_l(X_l_lh)
            X_l_hl = self.n_l(X_l_hl)
            X_l_hh = self.n_l(X_l_hh)
        if self.a:
            X_h = self.a(X_h)
            X_l_ll = self.a(X_l_ll)
            X_l_lh = self.a(X_l_lh)
            X_l_hl = self.a(X_l_hl)
            X_l_hh = self.a(X_l_hh)
        return X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh


class DWT_FirstOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_FirstOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.xfm = DWTForward(J=1, wave='db3', mode='zero').cuda()
        self.stride = stride
        # low-frequency branch
        self.h2l = nn.Conv2d(in_nc, int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        # high-frequency branch
        self.h2h = nn.Conv2d(in_nc, out_nc - int(alpha * out_nc),
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, int(out_nc * (1 - alpha))) if norm_type else None
        self.n_l = norm(norm_type, int(out_nc * alpha)) if norm_type else None

    def forward(self, x):
        X_h = self.h2h(x)
        X_l = self.h2l(x)
        # decompose the low branch once; its subbands are carried through the network
        A, B = self.xfm(X_l)
        X_l_ll = A
        X_l_lh = B[0][:, :, 0]
        X_l_hl = B[0][:, :, 1]
        X_l_hh = B[0][:, :, 2]
        if self.n_h and self.n_l:  # batch norm
            X_h = self.n_h(X_h)
            X_l_ll = self.n_l(X_l_ll)
            X_l_lh = self.n_l(X_l_lh)
            X_l_hl = self.n_l(X_l_hl)
            X_l_hh = self.n_l(X_l_hh)
        if self.a:  # activation layer
            X_h = self.a(X_h)
            X_l_ll = self.a(X_l_ll)
            X_l_lh = self.a(X_l_lh)
            X_l_hl = self.a(X_l_hl)
            X_l_hh = self.a(X_l_hh)
        return X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh


class DWT_LastOctaveConv(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size, alpha=0.5, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA'):
        super(DWT_LastOctaveConv, self).__init__()
        assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
        padding = get_valid_padding(kernel_size, dilation) if pad_type == 'zero' else 0
        self.ifm = DWTInverse(wave='db3', mode='zero').cuda()
        self.stride = stride
        self.l2h = nn.Conv2d(int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = nn.Conv2d(in_nc - int(alpha * in_nc), out_nc,
                             kernel_size, 1, padding, dilation, groups, bias)
        self.a = act(act_type) if act_type else None
        self.n_h = norm(norm_type, out_nc) if norm_type else None

    def forward(self, x):
        X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh = x
        A = X_l_ll
        c = A.shape
        # pack the three high-frequency subbands back into the (N, C, 3, H, W)
        # layout expected by DWTInverse (torch.stack would do this directly)
        C = torch.randn(c[0], c[1], 3, c[-2], c[-1])
        C[:, :, 0] = X_l_lh
        C[:, :, 1] = X_l_hl
        C[:, :, 2] = X_l_hh
        C_ = [C.cuda()]
        X_l = self.ifm((A, C_))
        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)
        X_h = X_h2h + X_l2h
        if self.n_h:
            X_h = self.n_h(X_h)
        if self.a:
            X_h = self.a(X_h)
        return X_h


class DWT_octave_ResidualDenseBlockTiny_4C(nn.Module):
    '''
    Residual Dense Block
    style: 4 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, alpha=0.5, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_ResidualDenseBlockTiny_4C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = DWT_OctaveConv(nc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = DWT_OctaveConv(nc + gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = DWT_OctaveConv(nc + 2 * gc, gc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv4 = DWT_OctaveConv(nc + 3 * gc, nc, kernel_size, alpha, stride, bias=bias,
                                    pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        # x is the 5-tuple (X_h, X_l_ll, X_l_lh, X_l_hl, X_l_hh)
        x1 = self.conv1(x)
        x2 = self.conv2((torch.cat((x[0], x1[0]), dim=1),
                         torch.cat((x[1], x1[1]), dim=1),
                         torch.cat((x[2], x1[2]), dim=1),
                         torch.cat((x[3], x1[3]), dim=1),
                         torch.cat((x[4], x1[4]), dim=1)))
        x3 = self.conv3((torch.cat((x[0], x1[0], x2[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1]), dim=1),
                         torch.cat((x[2], x1[2], x2[2]), dim=1),
                         torch.cat((x[3], x1[3], x2[3]), dim=1),
                         torch.cat((x[4], x1[4], x2[4]), dim=1)))
        x4 = self.conv4((torch.cat((x[0], x1[0], x2[0], x3[0]), dim=1),
                         torch.cat((x[1], x1[1], x2[1], x3[1]), dim=1),
                         torch.cat((x[2], x1[2], x2[2], x3[2]), dim=1),
                         torch.cat((x[3], x1[3], x2[3], x3[3]), dim=1),
                         torch.cat((x[4], x1[4], x2[4], x3[4]), dim=1)))
        res = (x4[0].mul(0.2), x4[1].mul(0.2), x4[2].mul(0.2), x4[3].mul(0.2), x4[4].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3], x[4] + res[4])
        return x


class DWT_octave_RRDBTiny(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=16, stride=1, alpha=0.5, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(DWT_octave_RRDBTiny, self).__init__()
        self.RDB1 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)
        self.RDB2 = DWT_octave_ResidualDenseBlockTiny_4C(nc=nc, kernel_size=kernel_size, alpha=alpha, gc=gc,
                                                         stride=stride, bias=bias, pad_type=pad_type,
                                                         norm_type=norm_type, act_type=act_type, mode=mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        res = (out[0].mul(0.2), out[1].mul(0.2), out[2].mul(0.2), out[3].mul(0.2), out[4].mul(0.2))
        x = (x[0] + res[0], x[1] + res[1], x[2] + res[2], x[3] + res[3], x[4] + res[4])
        return x
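This version carries a five-component state through the blocks: the high-frequency map at full resolution plus the low-frequency map stored as its four one-level DWT subbands. A short sketch of building that state (channel counts follow the alpha=0.125, nf=64 split used above, but are illustrative assumptions):

import torch
from pytorch_wavelets import DWTForward

xfm = DWTForward(J=1, wave='db3', mode='zero')

X_h = torch.randn(1, 56, 32, 32)   # high-frequency branch, (1 - alpha) share
X_l = torch.randn(1, 8, 32, 32)    # low-frequency branch, alpha share
A, B = xfm(X_l)
state = (X_h, A, B[0][:, :, 0], B[0][:, :, 1], B[0][:, :, 2])
for t in state:
    print(t.shape)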

Results
