李沐d2l《动手学深度学习》第二版——风格迁移源码详解
本文是对李沐Dive to DL《动手学深度学习》第二版13.12节风格迁移的源码详解,整体由Jupyter+VSCode完成,几乎所有重要代码均给出了注释,一看就懂。需要的同学可以在文末链接处下载原文件。
话不多说,我们开始。
图1中的内容图像为本书作者在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而风格图像则是一幅主题为秋天橡树的油画。最终输出的合成图像应用了风格图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。
图1 风格迁移效果示意图
1.一种简单的风格迁移方法
(1) 初始化合成图像,例如将其初始化为内容图像(content image);
(2) 利用预训练网络(如VGG-19)的某些层抽取内容图像与合成图像的内容特征,再用某些层抽取风格图像与合成图像的风格特征;
(3) 根据抽取出来的content feature map和style feature map计算出内容损失(content loss,使合成图像与内容图像在内容特征上接近)和风格损失(style loss,使合成图像与风格图像在风格特征上接近);
(4) 根据当前的合成图像自身计算出全变分损失(total variation loss,有助于减少合成图像中的噪点);
(5) 将这三个损失按一定比例加权(主观更倾向于合成什么样的图像),计算出最终的总损失;
(6) 根据损失反向传播误差,逐步更新合成图像的参数,降低损失,最终结束训练,图像风格迁移成功。
图2 风格迁移的三类损失
2.先观察一下内容图像(content image)和风格图像(style image)
导包:
import torch
import torchvision
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
读取图像并显示:
content_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\rainier.jpg')
style_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\autumn_oak.jpg')
plt.imshow(content_img)
plt.show() # Display all open figures.
plt.imshow(style_img)
plt.show()
3.图像预处理和后处理
预处理函数preprocess()对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入 格式。
后处理函数postprocess()则将输出图像反标准化,输出能正常显示的人眼看得懂的图像。
# content_img[1364, 2047, 2] # imread读取的图像前两个维度是高和宽,第三个维度表示选择RGB三个通道中的哪个通道# ImageNet先验归一化
# 该均值和标准差来源于ImageNet数据集统计得到,如果建立的数据集分布和ImageNet数据集数据分布类似(来自生活真实场景,例如人像、风景、交通工具等),或者使用PyTorch提供的预训练模型,推荐使用该参数归一化。如果建立的数据集并非是生活真实场景(如生物医学图像),则不推荐使用该参数
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])# 在使用深度学习框架构建训练数据时,通常需要数据归一化,以利于网络的训练
# 输入图像必须是PIL/np.ndaray
def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),# ToTensor()将其图像由(h,w,c)转置为(c,h,w),再把像素值从[0,255]变换到[0,1]torchvision.transforms.ToTensor(),# 标准化torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0) # unsqueeze将图像升至4维,增加批次数=1,便于后续图像处理可以更好地进行批操作# 在训练过程可视化中,通常需要反归一化,以显示能用人眼看得懂的正常的图
def postprocess(img):# 将图像从训练的GPU环境移至CPU并转为3维(c,h,w)(消去第四维batch)img = img[0].to(rgb_std.device)# 先将(c,h,w)转化为(h,w,c)实施反归一化,并将输出范围限制到[0,1]img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 再将(h,w,c)转化为(c,h,w),ToPILImage()将Tensor的每个元素乘以255;将数据由Tensor转化成Uint8# ToPILImage()要求输入图像若是tensor,则shape必须是(c,h,w)形式return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
我们可以写一段测试代码看看postprocess()干了些啥:
# 这里content_img是PIL格式
# img是tensor格式
img = preprocess(content_img, (224,244))
img.shape # 输出:torch.Size([1, 3, 224, 244])
img = img[0].to(rgb_std.device)
img.shape # 输出:torch.Size([3, 224, 244])
img = img.permute(1, 2, 0)
img.shape # 输出:torch.Size([224, 244, 3])
img = torch.clamp(img * rgb_std + rgb_mean, 0, 1)
img_before = torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))
plt.imshow(img_before) # PIL格式,不可索引单个像素,不可输出shape
输出标准化+反标准化后的图像:
4. 抽取图像特征
我们使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征 [Gatys et al., 2016]。
pretrained_net = torchvision.models.vgg19(pretrained=True)
pretrained_net # 输出vgg-19的网络结构:5个卷积块,前两个块中有2个卷积层,后三个块中有4个卷积层
为了抽取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。 一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 为了避免合成图像过多保留内容图像的细节(我们只需要保留一个大概的主题及轮廓即可,细节方面由风格特征把握),我们选择VGG较靠近输出的层来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。 VGG网络使用了5个卷积块,实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。 这些层的索引可以通过打印pretrained_net实例获取。
style_idx = [0, 5, 10, 19, 28] # 各卷积块中的第一个卷积层的索引分别是Sequential中的0, 5, 10, 19, 28
content_idx = [25] # 第4个卷积块中最后一个卷积层的索引是Sequential中的25# 构建一个新的网络net,它只保留需要用到的VGG的所有层。
# 因为用到最深的层是第28层,因此我们要保留vgg网络中第28层及前面的所有层,而28层以后的均不要保留
# pretrained_net.features输出vgg网络的feature属性(即平均池化之前的所有层)
layers = []
# max(content_layers + style_layers)求用到的最深层的索引
for i in range(max(style_idx + content_idx) + 1):# 逐层加入到所需的layers列表中layers.append(pretrained_net.features[i])
# 将layers逐元素加入到Sequential中
net = nn.Sequential(*layers)
net
# 风格特征提取层如下:
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# 内容特征提取层如下:
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。 由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出。
def extract_features(x, content_idx, style_idx):content_features = []style_features = []for i in range(len(net)):# 计算当前层输出temp_layer = net[i]x = temp_layer(x)# 若当前层索引在风格索引列表中if i in style_idx:style_features.append(x)# 若当前层索引在内容索引列表中if i in content_idx:content_features.append(x)return content_features, style_features
下面定义两个函数:get_contents函数对内容图像抽取内容特征; get_styles函数对风格图像抽取风格特征。 因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。 由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和风格特征。
# 提取内容图像的内容特征
def get_content_features(content_img, image_shape):# 对content_img先进行预处理并移至gpu,便于直接输入网络content_x = preprocess(content_img, image_shape).cuda()# 提取内容图像的内容特征content_features_x, _ = extract_features(content_x, content_idx, style_idx)return content_x, content_features_x# 提取风格图像的风格特征
def get_style_features(style_img, image_shape):# 对style_img先进行预处理并移至gpu,便于直接输入网络style_x = preprocess(style_img, image_shape).cuda()# 提取风格图像的风格特征_, style_features_x = extract_features(style_x, content_idx, style_idx)return style_x, style_features_x
5.定义损失函数:由内容损失、风格损失和全变分损失3部分组成
5.1. 内容损失
内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异。 平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。
def calc_contentloss(Y_hat, Y):# 从动态计算梯度的树中分离目标# 计算所有通道对应矩阵的差的平方和,再除以所有元素个数# 这里把Y detach一下是因为原始内容图像的特征图无需参与反向传播(视为已知常量),所以将它从计算图中分离,否则的话反向更新会影响该值return torch.square(Y_hat - Y.detach()).mean()
测试一下这函数干了啥:
Y_hat = torch.randn(512,2,2)
Y = torch.randn(512,2,2)
torch.square(Y_hat - Y.detach()).mean() == torch.square(Y_hat - Y.detach()).sum() / 2048
# 返回tensor(True)
5.2. 风格损失
风格损失与内容损失类似,也通过平方误差函数衡量合成图像与风格图像在风格上的差异。 为了表达风格层输出的风格,我们先通过extract_features函数计算风格层的输出。 假设该输出的样本数为1,通道数为ccc,高和宽分别为hhh和www ,我们可以将此输出转换为矩阵XXX,其有ccc行和hwhwhw列(相当把一个通道的矩阵拉成一个行向量)。 这个矩阵可以被看作是由ccc个长度为hwhwhw的向量x1,...,xcx_{1},...,x_{c}x1,...,xc组合而成的。其中向量xix_{i}xi代表了通道iii上的风格特征(其实就是该通道的所有像素点)。
在这些向量的格拉姆矩阵G=XX⊤∈Rc×cG = XX^⊤∈R^{c×c}G=XX⊤∈Rc×c中,iii行jjj列的元素xijx_{ij}xij即向量xix_{i}xi和xjx_{j}xj的内积。它表达了通道iii和通道jjj上风格特征的相关性(emmm…姑且认为一个像素代表一个特征吧)。我们用这样的格拉姆矩阵来表达风格层输出的风格。 需要注意的是,当hwhwhw的值较大时,格拉姆矩阵中的元素容易出现较大的值。 此外,格拉姆矩阵的高和宽皆为通道数 c 。 为了让风格损失不受这些值的大小影响,下面定义的calc_gram函数将格拉姆矩阵除以了矩阵中元素的个数,即chwchwchw。
# 输入是vgg某层输出的特征图,尺寸为(c,h,w)
def calc_gram(x):c = x.shape[1] # c是输出的风格特征图的通道数hw = x.shape[2] * x.shape[3] # hw是一张特征图矩阵中所有元素的个数x = x.reshape((c, hw)) # 将(c,h,w)变换为(c,h*w)return torch.matmul(x, x.T) / (c * hw) # matmul是矩阵乘法# 计算风格损失,这里假设风格图像的格拉姆矩阵已经提前计算好了
def calc_styleloss(Y_hat, gram_Y):# 这里把gram_Y detach一下是因为原始风格图像的格拉姆矩阵无需参与反向传播(视为已知常量),所以将它从计算图中分离,否则的话反向更新会影响该值return torch.square(calc_gram(Y_hat) - gram_Y.detach()).mean()
5.3. 全变分损失
有时候,我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 一种常见的去噪方法是全变分去噪(total variation denoising): 假设xi,jx_{i,j}xi,j表示坐标(i,j)(i,j)(i,j)处的像素值,则全变分损失定义为:
∑i,j∣xi,j−xi+1,j∣+∑i,j∣xi,j−xi,j+1∣\sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \sum_{i, j} \left|x_{i, j} - x_{i, j+1}\right|∑i,j∣xi,j−xi+1,j∣+∑i,j∣xi,j−xi,j+1∣
这里我们用如下公式计算单幅图像的全变分损失:
loss=12(lossvertical+losshorizontal)=12(1chw∑c∑i,j∣xi,j−xi+1,j∣+1chw∑c∑i,j∣xi,j−xi,j+1∣)loss = \frac{1}{2}(loss_{vertical} + loss_{horizontal}) = \frac{1}{2}(\frac{1}{chw}\sum_{c}\sum_{i, j}\left|x_{i, j} - x_{i+1, j}\right| + \frac{1}{chw}\sum_{c}\sum_{i, j}\left|x_{i, j} - x_{i, j+1}\right|)loss=21(lossvertical+losshorizontal)=21(chw1∑c∑i,j∣xi,j−xi+1,j∣+chw1∑c∑i,j∣xi,j−xi,j+1∣)
其中∑c\sum_{c}∑c表示对各通道求和,ccc同时表示通道数。
def calc_tvloss(Y_hat):# [:, :, 1:, :]表示取各通道图像矩阵的第1行至最后一行(起始行为0行)# [:, :, :-1, :]表示取各通道图像矩阵的第0行至倒数第二行# 矩阵相减取绝对值再在各通道上取平均(所有元素加起来除以总元素数)return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())
测试一下:
y = torch.randn(1,3,4,4)
torch.abs(y[:,:,1:,:] - y[:,:,:-1,:]).mean()
torch.abs(y[:,:,:,1:] - y[:,:,:,:-1]).mean()
# 输出 tensor(0.9967)
5.4. 损失函数
风格转移任务的损失函数是内容损失、风格损失和全变分损失的加权和。 通过调节三者的权重超参数,我们可以权衡合成图像在保留内容、保留风格及去噪三方面的相对重要性。
content_weight, style_weight, tv_weight = 1, 1000, 10
# 计算总的损失函数值
def compute_loss(X, content_Y, content_Y_hat, style_Y_gram, style_Y_hat):# 分别计算内容损失、风格损失和全变分损失# 对1对y,y_hat求内容损失,乘以权重后添加到列表中content_l = [calc_contentloss(Y_hat, Y) * content_weight for Y_hat, Y in zip(content_Y_hat, content_Y)]# 对5对y,y_hat分别求风格损失,乘以权重后添加到列表中style_l = [calc_styleloss(Y_hat, Y) * style_weight for Y_hat, Y in zip(style_Y_hat, style_Y_gram)]# 求总变差损失,乘以权重tv_l = calc_tvloss(X) * tv_weight# 对所有损失求和(5个风格损失,1个内容损失,1个总变差损失)l = sum(10 * style_l + content_l + [tv_l]) #style_l乘10干啥?return content_l, style_l, tv_l, l
6.初始化合成图像
在风格迁移中,合成的图像是训练期间唯一需要更新的变量。因此,我们可以定义一个简单的模型SynthesizedImage,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可。
class SynthesizedImage(nn.Module):def __init__(self, img_shape):super(SynthesizedImage, self).__init__()self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight
下面,我们定义get_inits函数。该函数创建了合成图像的模型实例,并将其初始化为图像X。风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。
def get_inits(X, lr, style_Y):# X是内容图像的预处理结果gen_img = SynthesizedImage(X.shape).cuda()# 将初始化的weight参数改为已有的图像X的参数(即像素)gen_img.weight.data.copy_(X.data)# 定义优化器trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)# 对各风格特征图计算其格拉姆矩阵,并依次存于列表中style_Y_gram = [calc_gram(Y) for Y in style_Y]# !!!gen_img()!!!括号return gen_img(), style_Y_gram, trainer
7.训练模型
在训练模型进行风格迁移时,我们不断抽取合成图像的内容特征和风格特征,然后计算损失函数。下面定义了训练循环。
def train(X, content_Y, style_Y, lr, num_epochs, lr_decay_epoch):# X是初始化的合成图像,style_Y_gram是原始风格图像的格拉姆矩阵列表X, style_Y_gram, trainer = get_inits(X, lr, style_Y)# 定义学习率下降调节器scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)for epoch in range(num_epochs):trainer.zero_grad()# Y_hat是用合成图像计算出的特征图content_Y_hat, style_Y_hat = extract_features(X, content_idx, style_idx)content_l, style_l, tv_l, l = compute_loss(X, content_Y, content_Y_hat, style_Y_gram, style_Y_hat)# 反向传播误差(计算l对合成图像像素矩阵的导数,因为l的唯一自变量是合成图像像素矩阵)l.backward()# 更新一次合成图像的像素参数trainer.step()# 更新学习率超参数scheduler.step()# 每5个epoch记录一次loss信息if (epoch + 1) % 5 == 0:# 由于风格损失列表有5项,因此算个总损失输出print('迭代次数:{} 内容损失:{:.9f} 风格损失:{:.9f} 总变差损失:{:.9f}' .format(epoch+1, sum(content_l).item(), sum(style_l).item(), tv_l.item()))# 训练结束后返回合成图像return X
8.开始训练+输出风格迁移图像
首先导入内容图像和风格图像并进行预处理,同时事先计算好内容特征和风格特征。
content_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\rainier.jpg')
style_img = Image.open(r'C:\Users\HP\Desktop\风格迁移\autumn_oak.jpg')
image_shape = (300, 450)
net = net.cuda()
# 计算内容图像的预处理结果(因为我们将内容图像作为合成图像的初始化图像作为网络的初始输入)和抽取到的内容特征
X, content_features_Y = get_content_features(content_img, image_shape)
# 计算风格图像抽取到的风格特征
_, style_features_Y = get_style_features(style_img, image_shape)
开始训练:
output = train(X, content_features_Y, style_features_Y, lr = 0.3, num_epochs = 500, lr_decay_epoch = 50)
# 调用后处理函数处理最终的合成图像,将其转换为正常格式的可视化图像
output = postprocess(output)
# 显示图像
plt.imshow(output)
plt.show()
结果:
9.Pytorch源码
本文所用的Pytorch源码可直接跑通(环境:Win10+VScode+Cuda11.3 +CuDnn),链接如下:
Pytorch实现风格迁移源码
欢迎各位小伙伴star or fork~
李沐d2l《动手学深度学习》第二版——风格迁移源码详解相关推荐
- 【李沐:动手学深度学习pytorch版】第2章:预备知识
第2章 预备知识 2.1 数据操作 2.1.1 入门 导入的是torch而不是pytorch import torch 一个数叫标量 一个轴叫向量 两个轴叫矩阵 arange # 生成行向量 x = ...
- 李沐《动手学深度学习》第二版 pytorch笔记1 环境搭建
李沐<动手学深度学习>第二版pytorch笔记1 搭建环境 文章目录 李沐<动手学深度学习>第二版pytorch笔记1 搭建环境 此时尚有耐心 虚拟环境搭建 创建虚拟环境 查看 ...
- 李沐《动手学深度学习》第二版比赛2-Classify Leaves
李沐<动手学深度学习>第二版比赛2-Classify Leaves 我的偶像,李沐大神主讲的<动手学深度学习>(使用Pytorch框架,第一版使用的是MXNet框架)目前已经进 ...
- 李沐《动手学深度学习》d2l——安装和使用
今天想要跟着沐神学习一下循环神经网络,在跑代码的时候,d2l出现了问题,这里记录一下解决的过程,方便以后查阅. 李沐<动手学深度学习>d2l--安装和使用 安装d2l 解决 Import ...
- 李沐「动手学深度学习」中文课程笔记来了!代码还有详细中文注释
关注公众号,发现CV技术之美 本文转自机器之心,编辑张倩. markdown笔记与原课程视频一一对应,Jupyter代码均有详细中文注释,这份学习笔记值得收藏. 去年年初,机器之心知识站上线了亚马逊资 ...
- 李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
点击上方"AI有道",选择"星标"公众号 重磅干货,第一时间送达 李沐,亚马逊 AI 主任科学家,名声在外!半年前,由李沐.Aston Zhang 等人合力打造 ...
- 李沐《动手学深度学习》新增PyTorch和TensorFlow实现,还有中文版
李沐老师的<动手学深度学习>已经有Pytorch和TensorFlow的实现了,并且有了中文版. 网址:http://d2l.ai/ 简介 李沐老师的<动手学深度学习>自一年前 ...
- PyTorch实现的李沐《动手学深度学习》,登上GitHub热榜,获得1000+星
点击我爱计算机视觉标星,更快获取CVML新技术 晓查 发自 凹非寺 量子位 报道 | 公众号 QbitAI 李沐老师的<动手学深度学习>是一本入门深度学习的优秀教材,也是各大在线书店的计 ...
- 重磅!李沐「动手学深度学习」中文课程笔记来了!
点击 机器学习算法与Python学习 ,选择加星标 精彩内容不迷路 机器之心报道 markdown笔记与原课程视频一一对应,Jupyter代码均有详细中文注释,这份学习笔记值得收藏. 亚马逊资深首席科 ...
最新文章
- u-boot移植:解决 Retry count exceeded; starting again
- 这五款Python工具都是最常用的,尤其是第三种,初学者必须掌握的
- 捕捉mysql中不可忽视的知识点(二)
- java 窗口锁定_使用Java锁定屏幕
- Codeforces Round #260 (Div. 1) A - Boredom DP
- ES6-4/5 解构赋值、函数默认值、数组解构、对象解构
- 组合的输出(信息学奥赛一本通-T1317)
- 人口、人口密度分析项目-条形图
- python专题-读取xml文件
- Unix网络编程卷一第三章笔记
- centos7搭建radius认证服务器
- 网络PPTP协议代理加速器的应用
- Python 调用谷歌翻译(2021年9月测试可用)
- 光威猛将240固态掉盘开卡教程
- 【CSS3 】css样式的计算calc属性
- 腾讯Android开发面试记录,附大厂真题面经
- 通过ipmitool监控机房内服务器温度
- Exception in thread “main“ java.lang.NullPointerException问题
- css 隐藏滚动条 但是可以滚动
- UWP开发:获取用户当前所在的网络环境(WiFi、移动网络、LAN…)