Neural Style原理

CnnCnnC_{nn} 是一个预先训练好的深度卷积神经网络, XXX 是输入图片. Cnn(X)" role="presentation" style="position: relative;">Cnn(X)Cnn(X)C_{nn}(X) 是输入了图像 XXX 的网络. FXL∈Cnn(X)" role="presentation" style="position: relative;">FXL∈Cnn(X)FXL∈Cnn(X)F_{XL} \in C_{nn}(X) 是 LLL 层的特征图. 我们通过 FXL" role="presentation" style="position: relative;">FXLFXLF_{XL} 定义了LLL 层 X" role="presentation" style="position: relative;">XXX 的Content. 如果 YYY 是与 X" role="presentation" style="position: relative;">XXX 同样大小的另一个图片,则 LLL 层的Content的距离误差定义如下:

(1)DCL(X,Y)=‖FXL−FYL‖2=∑i(FXL(i)−FYL(i))2" role="presentation" style="position: relative;">DLC(X,Y)=∥FXL−FYL∥2=∑i(FXL(i)−FYL(i))2(1)(1)DCL(X,Y)=‖FXL−FYL‖2=∑i(FXL(i)−FYL(i))2

\begin{align}D_C^L(X,Y) = \|F_{XL} - F_{YL}\|^2 = \sum_i (F_{XL}(i) - F_{YL}(i))^2\end{align}

FXL(i)FXL(i)F_{XL}(i) 是 FXLFXLF_{XL} 的第 iii 个元素 .GXL" role="presentation" style="position: relative;">GXLGXLG_{XL} 是一个 KKK\ x\ K" role="presentation" style="position: relative;">KKK 大小的矩阵. GXLGXLG_{XL}中 第 kkk 行 第l" role="presentation" style="position: relative;">lll列元素 GXL(k,l)GXL(k,l)G_{XL}(k,l)是 FkXLFXLkF_{XL}^k 和FlXLFXLlF_{XL}^l 的积:

GXL(k,l)=⟨FkXL,FlXL⟩=∑iFkXL(i).FlXL(i)(2)(2)GXL(k,l)=⟨FXLk,FXLl⟩=∑iFXLk(i).FXLl(i)

\begin{align}G_{XL}(k,l) = \langle F_{XL}^k, F_{XL}^l\rangle = \sum_i F_{XL}^k(i) . F_{XL}^l(i)\end{align}

FkXL(i)FXLk(i)F_{XL}^k(i) 是 FkXLFXLkF_{XL}^k 的第 iii 个元素. GXL(k,l)" role="presentation" style="position: relative;">GXL(k,l)GXL(k,l)G_{XL}(k,l) 用来度量 kkk 与 l" role="presentation" style="position: relative;">lll 之间的相关性. GXLGXLG_{XL}表示 LLL 层 X" role="presentation" style="position: relative;">XXX 特征图的相关矩阵. 定义 LLL 层Style的距离误差如下:

(3)DSL(X,Y)=‖GXL−GYL‖2=∑k,l(GXL(k,l)−GYL(k,l))2" role="presentation" style="position: relative;">DLS(X,Y)=∥GXL−GYL∥2=∑k,l(GXL(k,l)−GYL(k,l))2(3)(3)DSL(X,Y)=‖GXL−GYL‖2=∑k,l(GXL(k,l)−GYL(k,l))2

\begin{align}D_S^L(X,Y) = \|G_{XL} - G_{YL}\|^2 = \sum_{k,l} (G_{XL}(k,l) - G_{YL}(k,l))^2\end{align}

为了最小化关于可变图像 XXX 和目标图像 C" role="presentation" style="position: relative;">CCC 的 DC(X,C)DC(X,C)D_C(X,C)以及关于 XXX 和目标样式 S" role="presentation" style="position: relative;">SSS的DS(X,S)DS(X,S)D_S(X,S), 我们计算每个想要的layer上的梯度,并求和:

∇extittotal(X,S,C)=∑LCwCLC.∇LCextitcontent(X,C)+∑LSwSLS.∇LSextitstyle(X,S)(4)(4)∇extittotal(X,S,C)=∑LCwCLC.∇extitcontentLC(X,C)+∑LSwSLS.∇extitstyleLS(X,S)

\begin{align}\nabla_{ extit{total}}(X,S,C) = \sum_{L_C} w_{CL_C}.\nabla_{ extit{content}}^{L_C}(X,C) + \sum_{L_S} w_{SL_S}.\nabla_{ extit{style}}^{L_S}(X,S)\end{align}

LCLCL_C 和 LSLSL_S分别是所需的图层的Content和Style, wCLCwCLCw_{CL_C} 和 wSLSwSLSw_{SL_S} 分别为与之相关的权值. 梯度下降如下:

X←X−α∇extittotal(X,S,C)(19)(19)X←X−α∇extittotal(X,S,C)

\begin{align}X \leftarrow X - \alpha \nabla_{ extit{total}}(X,S,C)\end{align}

基于TensorFlow的图像艺术化实现:

#导入库
import numpy as np
import tensorflow as tf
import scipy.io as sio
from PIL import Image
import matplotlib.pyplot as plt#设置参数
#style权重和content权重可以控制结果是趋于风格还是趋于内容
STYLE_WEIGHT=1.5
CONTENT_WEIGHT=1
#style的层数越多,就越能挖掘更多的风格特征,content的层数越深,得到的特征越抽象
STYLE_LAYERS=['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1']
CONTENT_LAYERS=['relu4_2', 'relu5_2']
VGG_PATH = 'imagenet-vgg-verydeep-19.mat'
#VGG模型结构
VGG_LAYERS=('conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1','conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2','conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4','pool3','conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4','conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5')
POOL_TYPE='max'

定义网络:

#定义vgg网络
def net_vgg19(input_image,layers,vgg_path,pool_type='max'):weights=sio.loadmat(vgg_path)['layers'][0]net=input_imagenetwork={}for i,name in enumerate(layers):layer_type=name[:4]if layer_type=='conv':kernels,bias=weights[i][0][0][0][0]kernels=np.transpose(kernels,(1,0,2,3))conv=tf.nn.conv2d(net,tf.constant(kernels),strides=(1,1,1,1),padding='SAME',name=name)net=tf.nn.bias_add(conv,bias.reshape(-1))net=tf.nn.relu(net)elif layer_type=='pool':if pool_type == 'avg':net=tf.nn.avg_pool(net, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),padding='SAME')else:net=tf.nn.max_pool(net, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),padding='SAME')network[name]=netreturn network

定义损失误差:

#定义损失误差
def loss_function(style_image,content_image,target_image):style_features=net_vgg19([style_image],VGG_LAYERS,VGG_PATH,POOL_TYPE)content_features=net_vgg19([content_image],VGG_LAYERS,VGG_PATH,POOL_TYPE)target_features=net_vgg19([target_image],VGG_LAYERS,VGG_PATH,POOL_TYPE)loss=0.0for layer in CONTENT_LAYERS:_,height,width,channel=map(lambda i:i.value,content_features[layer].get_shape())content_size=height*width*channelloss_content=tf.nn.l2_loss(target_features[layer]-content_features[layer])/content_sizeloss+=CONTENT_WEIGHT*loss_contentfor layer in STYLE_LAYERS:target_feature=target_features[layer]style_feature=style_features[layer]_,height,width,channel=map(lambda i:i.value,target_feature.get_shape())style_size=height*width*channeltarget_feature=tf.reshape(target_feature,(-1,channel))target_gram=tf.matmul(tf.transpose(target_feature),target_feature)/style_sizestyle_feature=tf.reshape(style_feature,(-1,channel))style_gram=tf.matmul(tf.transpose(style_feature),style_feature)/style_sizeloss_style=tf.nn.l2_loss(target_gram-style_gram)/style_sizeloss+=STYLE_WEIGHT*loss_stylereturn loss

训练部分:

#定义stylize函数,进行训练
def stylize(style_image,content_image,learning_rate=0.1,epochs=100):target = tf.Variable(tf.random_normal(content_image.shape),dtype=tf.float32)style_input = tf.constant(style_image,dtype=tf.float32)content_input = tf.constant(content_image, dtype=tf.float32)cost=loss_function(style_input,content_input,target)#定义优化器train=tf.train.AdamOptimizer(learning_rate).minimize(cost)with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:tf.global_variables_initializer().run()for i in range(epochs):_,loss,target_img=sess.run([train,cost,target])if(i+1)%100==0:print('迭代: %d ,loss: %.8f'%(i+1,loss))image=np.clip(target_img+128,0,255).astype(np.uint8)img=Image.fromarray(image)plt.imshow(img)plt.axis('on')plt.title('Image')plt.show()

读入数据部分:

style_image=Image.open('3-style.jpg')
style_image=np.array(style_image).astype(np.float32)-128.0
content_image=Image.open('shan2.jpg')
content_image=np.array(content_image).astype(np.float32)-128.0
stylize(style_image,content_image,0.2,1000)

本次案例迭代了1000次,每100次边显示艺术效果,可以看看艺术化的变化效果:










注:本案例需要下载已训练好的vgg19模型
下载地址:
https://blog.csdn.net/zyb228/article/details/80140951
或者:
http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat

更多人工智能技术干货请关注:

深度学习 Neural Style 之TensorFlow实践相关推荐

  1. 深度学习之Style Transfer

    Style Transfer 1.引入 最近看了一些基于深度学习的Style Transfer, 也就是风格迁移相关的paper,感觉挺有意思的. 所谓风格迁移,其实就是提供一幅画(Reference ...

  2. 福利 | Python、深度学习、机器学习、TensorFlow 好书推荐

    在上次的送书活动中,营长做了个调查问卷,结果显示大家更喜欢深度学习.Python以及TensorFlow方面的书,所以这期送书活动一并满足大家.本期图书选自人民邮电出版社图书,包括:近期AI圈儿比较流 ...

  3. 两个月入门深度学习,全靠动手实践!一位前端小哥的经验分享

    两个月入门深度学习,全靠动手实践!一位前端小哥的经验分享   在当前社会,技术日新月异,一个全栈工程师不及时学习新知识,掌握AI技能,再过两年就算不上"全栈"了. 产品发烧友.前端 ...

  4. 【百家稷学】深度学习与嵌入式平台AI实践(北京交通大学实训)

    继续咱们百家稷学专题,本次是有三AI在北京交通大学进行的暑期课程教学.百家稷学专题的目标,是走进100所高校和企业进行学习与分享. 分享主题 本次分享是在北京交通大学计算机与信息技术学院进行,主题是& ...

  5. numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践

    <<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗  读完<<深度学习框架PyTorc ...

  6. 深度学习与 Spark 和 TensorFlow

    2019独角兽企业重金招聘Python工程师标准>>> 深度学习与 Spark 和 TensorFlow 在过去几年中,神经网络领域的发展非常迅猛,也是现在图像识别和自动翻译领域中最 ...

  7. R语言︱H2o深度学习的一些R语言实践——H2o包

    每每以为攀得众山小,可.每每又切实来到起点,大牛们,缓缓脚步来俺笔记葩分享一下吧,please~ --------------------------- R语言H2o包的几个应用案例 笔者寄语:受启发 ...

  8. 华为在深度学习平台上的优化实践

    "Kubernetes Meetup 中国 2017"--北京站3.18落幕啦!本次分享嘉宾彭靖田来自华为,他的分享题目是<华为在深度学习平台上的优化实践>.实录将从深 ...

  9. 第一门课 神经网络和深度学习(Neural Networks and Deep Learning)

    第一门课 神经网络和深度学习(Neural Networks and Deep Learning) 文章目录 第一门课 神经网络和深度学习(Neural Networks and Deep Learn ...

最新文章

  1. SylixOS 基于STM32平台的GPIO模仿I2C总线的驱动开发流程
  2. java 线程崩溃_java语言中application异常退出和线程异常崩溃的捕获方法,并且在捕获的钩子方法中进行异常处理...
  3. C++归并排序递归写法
  4. shell脚本游戏之:剪刀石头布
  5. 基于SpringBoot的CodeGenerator
  6. c++面向对象高级编程 学习四 静态、类模板、函数模板
  7. Mars的mp3实例
  8. iOS即时通讯输入框随字数自适应高度
  9. Redis学习---(8)Redis 哈希(Hash)
  10. linux 统计 程序运行时间
  11. Mac电脑:调整 VMware中Windows10 屏幕分辨率(解决win10与Mac界面切换后,分辨率改变问题)
  12. 干货干货:px和毫米之间的转换
  13. table与tr td样式重叠 table样式边框变细
  14. C++--第26课 - 异常处理 - 下
  15. DVWA high暴力破解
  16. 广度优先搜索和深度优先搜索
  17. 任天堂(Nintendo)(什么是ps4,什么是ns(switch))
  18. 【每日新闻】百度云王龙:数据库与AI的融合主要分三个阶段 | 中国移动研究院:5G第一个版本出炉...
  19. 微信撤回消息在服务器可以看到吗,微信撤回消息可以查看了,对方撤回了什么一目了然...
  20. openfiler服务器打不开web管理页面

热门文章

  1. Spark SQL中出现 CROSS JOIN 问题解决
  2. Springboot 抛出Failed to determine a suitable driver class异常原因
  3. Android 解决不同进程发送KeyEvent 的问题
  4. Spark应用程序第三方jar文件依赖解决方案
  5. 处理错误:ORA-27101: shared memory realm does not exist 解决方案
  6. 为什么要使用String.Equals over ==? [重复]
  7. 在图像旁边垂直对齐文字?
  8. JAVA 多用户商城系统b2b2c-服务容错保护(Hystrix依赖隔离)
  9. 第 3 章 kickstart
  10. Kibana——数据图形化制作