GAN的量化评估方法

  • IS
    • IS简介
    • IS代码
  • FID
    • FID简介
    • FID代码

IS

IS基于谷歌的Inception Net-V3,输入是图像,输出是1000维的向量,输出响亮的每个维度,代表着对应的属于某一类的概率。
IS用来衡量GAN网络的两个指标:

  1. 生成图片的质量
  2. 多样性

IS简介

定义:

推导出上式的意义:

  1. 对于单一的生成图像,Inceptoin输出的概率分布应该尽量小,越小说明生成图像越可能属于某个类别,图像的质量越高。
  2. 对于生成器生成一批图像而言,Inception输出的平均概率分布熵值应该尽量大,代表着生成器生成的多样性。

IS代码

参考代码:https://github.com/xml94/open/blob/master/compute_IS_for_GAN
本着能不动手就不动手的原则,试了试上面的代码。但是这个需要自己写dataloader函数,还要与代码中匹配,我试了半天也没有成功,所以就自己参考这个写了一个。
只需要把要测试的图片的路径放入path即可:
于此对应的datakoader函数见下面。

from datasets import *import torch.nn as nn
import torch.nn.functional as F
import torchimport torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data
from scipy.stats import entropy
from torchvision.models.inception import inception_v3path = '/'
count = 0
for root,dirs,files in os.walk(path):    #遍历统计for each in files:count += 1   #统计文件夹下文件个数
print(count)
batch_size = 64
transforms_ = [transforms.Resize((256, 256), Image.BICUBIC),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]val_dataloader = DataLoader(ISImageDataset(path, transforms_=transforms_),batch_size = batch_size,
)cuda = True if torch.cuda.is_available() else False
print('cuda: ',cuda)
tensor = torch.cuda.FloatTensorinception_model = inception_v3(pretrained=True, transform_input=False).cuda()
inception_model.eval()
up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).cuda()def get_pred(x):if True:x = up(x)x = inception_model(x)return F.softmax(x, dim=1).data.cpu().numpy()print('Computing predictions using inception v3 model')
preds = np.zeros((count, 1000))for i, data in enumerate(val_dataloader):data = data.type(tensor)batch_size_i = data.size()[0]preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(data)print('Computing KL Divergence')
split_scores = []
splits=10
N = count
for k in range(splits):part = preds[k * (N // splits): (k + 1) * (N // splits), :] # split the whole data into several partspy = np.mean(part, axis=0)  # marginal probabilityscores = []for i in range(part.shape[0]):pyx = part[i, :]  # conditional probabilityscores.append(entropy(pyx, py))  # compute divergencesplit_scores.append(np.exp(np.mean(scores)))mean, std  = np.mean(split_scores), np.std(split_scores)
print('IS is %.4f' % mean)
print('The std is %.4f' % std)

dataloader结构体:

import glob
import random
import os
import numpy as npfrom torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transformsclass ISImageDataset(Dataset):def __init__(self, root, transforms_=None):self.transform = transforms.Compose(transforms_)self.files = sorted(glob.glob(os.path.join(root) + "/*.jpg"))def __getitem__(self, index):img = Image.open(self.files[index % len(self.files)]).convert('RGB')      item_image = self.transform(img)return item_imagedef __len__(self):return len(self.files)

将结构体放在为:dataset.py中,IS代码中将其import进来。
最终会输出两个值,一个是IS,一个是std。

但是听说:

由于 Inception V3 是在 ImageNet 上训练的,用 Inception V3 时,应该保证生成模型也在 ImageNet上训练并生成 ImageNet 相似的图片,而不是把什么生成模型生成的图片(卧室,花,人脸)都往 Inception V3中套,那种做法没有任何意义。
不能在一个数据集上训练分类模型,用来评估另一个数据集上训练的生成模型

FID

FID分数是在IS基础上改进的,同样基于Inception Net-V3,它删除了模型原本的输出层,于是输出层变成了最后一层池化层,输出是2048维向量,因此每个图像都被预测为2048个特征。

FID简介

Frechet Inception 距离得分(Frechet Inception Distance score,FID)是计算真实图像和生成图像的特征向量之间距离的一种度量。
假如一个随机变量服从高斯分布,这个分布可以用一个均值和方差来确定。那么两个分布只要均值和方差相同,则两个分布相同。我们就利用这个均值和方差来计算这两个单变量高斯分布之间的距离。但我们这里是多维的分布,我们知道协方差矩阵可以用来衡量两个维度之间的相关性。所以,我们使用均值和协方差矩阵来计算两个分布之间的距离

FID越小代表着生成分布和真实图片之间越接近。

FID代码

可以通过pip之间安装:

pip install pytorch-fid

配置要求如下:

python3
pytorch
torchvision
pillow
numpy
scipy

使用非常的简单:

python -m pytorch_fid path/to/dataset1 path/to/dataset2

把生成图片的路径和真实图片的路径放进去即可,和顺序无关。
也可以选择与–dims N标志一起使用的特征维数,其中N是特征的维数。

64: first max pooling features
192: second max pooling featurs
768: pre-aux classifier features
2048: final average pooling features (this is the default)

比如:

python -m pytorch_fid path/to/dataset1 path/to/dataset2 --dims 2048

一般都是使用默认的2048
FID参考链接:官方github

推荐博客:推荐
英文的:英文

一般的评价图像质量的指标还有SSIM和PSNR,可以参看SSIM和PSNR

GAN的量化评估方法——IS和FID,及其pytorch代码相关推荐

  1. 软件开发工作量及费用量化评估方法在金融行业的应用

    面临的问题 随着国内金融行业市场化进程持续加快以及互联网金融的兴起,信息技术尤其是软件技术的应用对于金融科技创新至关重要.各大金融机构在持续加大科技创新力度的同时,如何科学.高效地管控应用开发的投入并 ...

  2. 生成式对抗网络GAN必读十篇论文(附论文和代码地址)

    目录索引 一.DCGAN 二.Improved Techniques for Training GANs 三.Conditional GANs 四.Progressively Growing of G ...

  3. 涵盖18+ SOTA GAN实现,这个图像生成领域的PyTorch库火了

    视学算法报道 转载自:机器之心 作者:杜伟.陈萍 GAN 自从被提出后,便迅速受到广泛关注.我们可以将 GAN 分为两类,一类是无条件下的生成:另一类是基于条件信息的生成.近日,来自韩国浦项科技大学的 ...

  4. 给GAN一句描述,它就能按要求画画,微软CVPR新研究 | 附PyTorch代码

    晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 让AI认得图像,根据自己的理解给出一段叙述,已经不是什么新鲜事了.从图像到文字容易,把这个过程反过来却很难. 让AI画图有了成熟的解决方案 ...

  5. 图像去模糊代码 python_用Keras搭建GAN:图像去模糊中的应用(附代码)

    雷锋网 (公众号:雷锋网) 按:本文为 雷锋字幕组 编译的技术博客,原标题GAN with Keras: Application to Image Deblurring,作者为Raphaël Meud ...

  6. 股票量化分析工具QTYX使用攻略代码说明——高速版本地行情源v2.5.1

    搭建自己的量化系统 如果要长期在市场中立于不败之地!必须要形成一套自己的交易系统.否则,赚钱或者亏钱我们很难归纳总结,往往是凭借运气赚钱,而不是合理的系统模型,一时凭借运气赚的钱长期来看会因为实力还回 ...

  7. “华为杯”第十五届中国研究生数学建模竞赛-对恐怖袭击事件记录数据的量化分析(Python,Pandas,Scikit-learn,PyTorch,Matplotlib,seaborn)

    首先先说一下编程的工具 Python:编程语言 Pandas:数据处理,清洗,分析的工具 Scikit-learn:机器学习工具箱 PyTorch:深度学习搭建神经网络,训练等的工具 Matplotl ...

  8. 基金反买,别墅靠海?每年买倒数前十基金,能赚这么多?Python量化分析告诉你答案【附代码】-邢不行

     引言: 邢不行的系列帖子"量化小讲堂",通过实际案例教初学者使用python进行量化投资,了解行业研究方向,希望能对大家有帮助 这是邢不行第 83 期量化小讲堂的分享 作者 | ...

  9. Python量化交易策略--双均线策略及代码

    双均线策略是比较经典的策略,股票的价格均线是投资参考的重要指标.均线有快线和慢线之分,当快线向上穿过慢线则是金叉,一般执行买入操作,当快线向下穿过慢线时则形成死叉,一般执行卖出操作.基于这个基本思路, ...

  10. 量化投资(一):十行代码实现一个量化交易入门程序

    1 在浏览器中打开 www.joinquant.com 2 编写策略代码 点击顶部的"我的策略",选择子菜单"我的策略",在左侧输入python代码 def i ...

最新文章

  1. 安卓高级6 拍照或者从相册获取图片 并检测旋转角度或者更新画册扫描
  2. shell脚本采用crontab定时备份数据库日志
  3. Opencv 图片 读取,显示,保存基本操作
  4. disruptor模拟高速处理大规模订单类业务场景
  5. Serializable序列化
  6. mysql-installer-web-community和mysql-installer-communityl两个版本的区别
  7. 信息搜集 - 二层发现 arping
  8. sr650安装linux网卡驱动,SR650安装Windows2016添加RAID卡驱动
  9. JAVAWEB开发Myeclipse 项目中报“无法解析类型 java.io.ObjectInputStream,从必需的 .class 文件间接引用了它”解决办法
  10. echart 三维可视化地图_ECharts实现三维可视化
  11. 进去服务器bios设置u盘启动不了系统,重装系统怎么进不了bios设置u盘启动
  12. 【C语言小游戏】答题系统
  13. MySql ocp认证之备份与恢复(四)
  14. 如何在网页中添加 GitHub Corners
  15. 整理2020java面试题
  16. 我奋斗了18年才和你坐在一起喝咖啡 原作者:麦子
  17. 60度斜坡怎么计算_坡度怎么算
  18. 怎么把视频里的音乐提取成音频,怎么提取视频中的音频
  19. The Guru Myth
  20. 【渝粤教育】广东开放大学 电子支付与安全 形成性考核 (59)

热门文章

  1. 业务常见面试题(数据分析)
  2. 分销系统之项目架构(第一篇)
  3. [TSP-FCOS]Rethinking Transformer-based Set Prediction for Object Detection
  4. 发动机冒黑烟_发动机冒黑烟常见的24个原因和解决方法!
  5. c语言程序窗口设计,C语言窗口程序设计简介.pdf
  6. 信号的同调性(Coherence)分析及MATLAB实例
  7. Android跳转第三方App,淘宝,微信,QQ等。
  8. 非法集资(Illegal Fund-raising)
  9. C#中check和uncheck
  10. war3第一视角集合 UD篇