本文主要改动自:https://github.com/sowmyay/medium/blob/master/CV-LossFunctions.ipynb

首先回顾下特征损失(Feature loss)或者感知损失(Perceptual Loss)的初衷:

许多损失函数,如L1 loss、L2 loss、BCE loss,他们都是通过逐像素比较差异,从而对误差进行计算。然而,有的时候看起来十分相似的两个图像(比如图A相对于图B只是整体移动了一个像素),此时对人来说是几乎看不出区别的,但是其像素级损失(pixel-wise loss)将会变的巨大。对于这种任务就不能简单地使用底层的像素损失了,需要设计一种损失来学习语义差异。

既然要比较语义差异,那我们就需要首先获得一张图像的高层特征,而这就可以通过输出卷积神经网络的前几层的输出来实现,他们提取的就是高层的特征。

也就是说,给定两张图,我们不直接比较他们的像素级差异,而是均将他们放入同一网络中,获取某一中间层的输出特征图,然后再用一些传统的loss计算特征图之间的差异即可。在Perceptual Losses for Real-Time Style Transfer and Super-Resolution一文中使用的网络是VGG16,也可以使用一些其他的预训练深度网络(如ResNet, GoogLeNet,VGG19),不过一般VGG16的效果最好。

代码如下,这里使用了MSE来计算特征图的loss。

import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision.models import vgg16_bnclass FeatureLoss(nn.Module):def __init__(self, loss, blocks, weights, device):super().__init__()self.feature_loss = lossassert all(isinstance(w, (int, float)) for w in weights)assert len(weights) == len(blocks)self.weights = torch.tensor(weights).to(device)#VGG16 contains 5 blocks - 3 convolutions per block and 3 dense layers towards the endassert len(blocks) <= 5assert all(i in range(5) for i in blocks)assert sorted(blocks) == blocksvgg = vgg16_bn(pretrained=True).featuresvgg.eval()for param in vgg.parameters():param.requires_grad = Falsevgg = vgg.to(device)bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]assert all(isinstance(vgg[bn], nn.BatchNorm2d) for bn in bns)self.hooks = [FeatureHook(vgg[bns[i]]) for i in blocks]self.features = vgg[0: bns[blocks[-1]] + 1]def forward(self, inputs, targets):# normalize foreground pixels to ImageNet statistics for pre-trained VGGmean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]inputs = F.normalize(inputs, mean, std)targets = F.normalize(targets, mean, std)# extract feature mapsself.features(inputs)input_features = [hook.features.clone() for hook in self.hooks]self.features(targets)target_features = [hook.features for hook in self.hooks]loss = 0.0# compare their weighted lossfor lhs, rhs, w in zip(input_features, target_features, self.weights):lhs = lhs.view(lhs.size(0), -1)rhs = rhs.view(rhs.size(0), -1)loss += self.feature_loss(lhs, rhs) * wreturn lossclass FeatureHook:def __init__(self, module):self.features = Noneself.hook = module.register_forward_hook(self.on)def on(self, module, inputs, outputs):self.features = outputsdef close(self):self.hook.remove()def perceptual_loss(x, y):F.mse_loss(x, y)def PerceptualLoss(blocks, weights, device):return FeatureLoss(perceptual_loss, blocks, weights, device)

参数:

  • blocks: 选取vgg的哪几块输出作为中间特征图,例如[0, 1, 2]选取前三块。
  • weights: 在计算最终loss时各个特征图loss的权重
  • device: 使用的设备,可以直接传入torch.device(“cuda” if torch.cuda.is_available() else “cpu”)

关键代码分析

首先是导入torchvision中的vgg16,利用eval和requires_grad=False将权重冻结,方便我们输出特征图:

vgg = vgg16_bn(pretrained=True).features
vgg.eval()
for param in vgg.parameters():param.requires_grad = False
vgg = vgg.to(device)

接着是取出vgg16中五个块的输出。这五个块都以max pool结尾,但是考虑到max pool层以及其上的relu层对比较特征图没有帮助(存疑),因此这里取出的是max pool前两层的batch norm层作为五个块的输出:

bns = [i - 2 for i, m in enumerate(vgg) if isinstance(m, nn.MaxPool2d)]

然后,对于我们指定的blocks(需要取出哪几层的输出),将相应bn层使用register_forward_hook方法来获取其输出:

self.hooks = [FeatureHook(vgg[bns[i]]) for i in blocks]

features其实就是一个精简的vgg16。我们需要哪几层的输出,就保留这几层之前的结构。如果我们只需要前两块的输出,那么后面三块其实就可以去掉了,减少运算量。

self.features = vgg[0: bns[blocks[-1]] + 1]

最后,将input和target输入网络,利用hook提取出特征图,对这些特征图进行对比,即可求解feature loss:

self.features(inputs)
input_features = [hook.features.clone() for hook in self.hooks]self.features(targets)
target_features = [hook.features for hook in self.hooks]

Pytorch Feature loss与Perceptual Loss的实现相关推荐

  1. 感知损失(perceptual loss)详解

    本文来自收费专栏:感知损失(perceptual loss)详解_南淮北安的博客-CSDN博客_感知损失 目录 一.感知损失 二.Loss_feature 三.Loss_style 感知损失的作用: ...

  2. 【损失函数:3】感知损失:Perceptual Loss、总变分损失(TV Loss)(附Pytorch实现)

    损失函数 一.感知损失(Perceptual Loss) 1.相关介绍 1)Perceptual Loss是什么? 2)Perceptual Loss如何构造? 3)代码实现 2.代码示例 二.总变分 ...

  3. 深度学习在单图像超分辨率上的应用:SRCNN、Perceptual loss、SRResNet

    单图像超分辨率技术涉及到增加小图像的大小,同时尽可能地防止其质量下降.这一技术有着广泛用途,包括卫星和航天图像分析.医疗图像处理.压缩图像/视频增强及其他应用.我们将在本文借助三个深度学习模型解决这个 ...

  4. 损失函数——感知损失(Perceptual Loss)

    感知损失(Perceptual Loss)是一种基于深度学习的图像风格迁移方法中常用的损失函数.与传统的均方误差损失函数(Mean Square Error,MSE)相比,感知损失更注重图像的感知质量 ...

  5. Perceptual Loss(感知损失)论文笔记

    "Perceptual Losses for Real-Time Style Transfer and Super-Resolution"论文出自斯坦福大学李飞飞团队,发表于ECC ...

  6. pytorch查看loss曲线_pytorch loss总结与测试

    pytorch loss 参考文献: loss 测试 import torch from torch.autograd import Variable ''' 参考文献: https://blog.c ...

  7. Perceptual Loss(感知损失)Perceptual Losses for Real-Time Style Transferand Super-Resolution论文解读

    由于传统的L1,L2 loss是针对于像素级的损失计算,且L2 loss与人眼感知的图像质量并不匹配,单一使用L1或L2 loss对于超分等任务来说恢复出来的图像往往细节表现都不好. 现在的研究中,L ...

  8. Perceptual Loss

    出自2016年李飞飞团队的Perceptual Losses for Real-Time Style Transfer and Super-Resolution 目标是加速图片转换的速度,因为当时的图 ...

  9. CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss、Center Loss)简介、使用方法之详细攻略

    CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss.Center Loss)简介.使用方法之详细攻略 目录 T1.Triplet Loss 1.英文原文解释 ...

最新文章

  1. 黑马程序员Linux系统开发视频之gdb调试方法
  2. Hadoop与Hbase基本配置
  3. reactNative 打包那些事儿
  4. Eclipse配置的tomcat用debug模式启动不了start可以启动
  5. SpringBoot+Vue+Redis实现前后端分离的字典缓存机制
  6. 第十九节:Asp.Net Core WebApi知识总结(一)
  7. 【pyqt5学习】——graphicView显示matplotlib图像
  8. leetcode237 删除链表中的节点(你意想不到的做法,注意细节)
  9. 程序员求助:被领导强行要求写Bug该怎么办?网友的回答让我笑翻
  10. .net开源CMS系统使用教程之:如何用We7 CMS建设全新网站
  11. Servlet技术 - Servlet应用
  12. Unity Editor 判断在哪个视图选中对象(Hierachy, Porject)
  13. 《别做正常的傻瓜》1——结果偏见
  14. LabVIEW控制Arduino驱动1602液晶显示屏(基础篇—10)
  15. postgresql 手动启动_PostGreSql 手动安装
  16. Win10下C:\Users\***修改用户名(完全修改)
  17. 如何使用MyBatis的plugin插件实现多租户的数据过滤?
  18. python 1104: 求因子和(函数专题)
  19. 黑鹰VIP教程超级大全集!!!千G容量!!!
  20. 中秋未到却卖到断货的月饼,究竟有多好吃?

热门文章

  1. android标题栏不被顶上去,Android仿微信QQ聊天顶起输入法不顶起标题栏的问题
  2. php和会计,财务跟会计有什么区别
  3. php sphinx mysql_windows7使用Sphinx+PHP+MySQL详细介绍
  4. php筛选怎么做,thinkphp条件筛选 例子
  5. vi测试仪维修成功率高吗?_欧森杰检测仪:臭氧检测仪的六大特点,您真的了解吗?...
  6. 计算机算法设计与分析 大学生电影节观影问题
  7. AcWing1074. 二叉苹果树(树形DP)题解
  8. 残差网络ResNet
  9. 【干货】Python编程惯例
  10. 【实用】Pyinstaller UnicodeDecodeError: ‘utf-8‘ codec can‘t decode byte 0xce in position解决方案