tf9: PixelCNN
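The previous post generated music; this one generates images. This post uses TensorFlow to implement the paper "Conditional Image Generation with PixelCNN Decoders". That model is built on the PixelCNN architecture, which first appeared in the paper "Pixel Recurrent Neural Networks".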
Image data
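I originally wanted to use ImageNet as the image source, as the paper does. ImageNet images come already categorized and are easy to fetch, but many of its download sources are blocked by the firewall, so downloads were painfully slow. (Related post: "Object Detection with Haar Cascades in OpenCV".)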
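I noticed plenty of Python scripts online for scraping pictures of pretty girls, so, um, I spent several days and nights scraping girl pictures (the especially revealing kind). Um, I just wanted to see what on earth PixelCNN would end up generating.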
If you can't be bothered to scrape images yourself, you can use the ones I scraped (split into two parts):
- https://pan.baidu.com/s/1kVSA8z9 (password: atqm)
- https://pan.baidu.com/s/1ctbd9O (password: kubu)
Data preprocessing
The downloaded images are spread across many directories. Gather them into one new directory:
```python
import os

old_dir = 'images'
new_dir = 'girls'
if not os.path.exists(new_dir):
    os.makedirs(new_dir)

count = 0
for (dirpath, dirnames, filenames) in os.walk(old_dir):
    for filename in filenames:
        if filename.endswith('.jpg'):
            # Rename to a sequential number so files from different directories don't collide
            new_filename = str(count) + '.jpg'
            os.rename(os.sep.join([dirpath, filename]), os.sep.join([new_dir, new_filename]))
            print(os.sep.join([dirpath, filename]))
            count += 1
print("Total pictures: ", count)
```
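Use open_nsfw (see the post "open_nsfw: a Caffe-Based Model for Detecting Adult Images") to weed out images that have nothing to do with girl pictures. Give open_nsfw an image and it returns a score from 0 to 1; the higher the score, the more explicit the image. Doing this with OpenCV should not be hard either.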
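A minimal sketch of the filtering step, assuming the open_nsfw scores have already been computed for every file and collected into a dict (running the Caffe model itself is not shown). It treats low-scoring images as unrelated. The function name `select_by_nsfw_score` and the cut-off of 0.2 are hypothetical; pick the threshold by inspecting a few scored images.

```python
def select_by_nsfw_score(scores, threshold=0.2):
    """Split file names into (keep, drop) using precomputed open_nsfw scores.

    scores: dict mapping file name -> score in [0, 1], assumed to have been
    produced beforehand by open_nsfw (this function does not run the model).
    threshold: hypothetical cut-off; images scoring below it are dropped.
    """
    keep, drop = [], []
    for name, score in sorted(scores.items()):
        if not 0.0 <= score <= 1.0:
            raise ValueError(f"score for {name} is outside [0, 1]: {score}")
        # Scores at or above the threshold count as related images
        (keep if score >= threshold else drop).append(name)
    return keep, drop
```

The names in `drop` can then be deleted with `os.remove` before resizing.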
To reduce the amount of computation, I scale the images down to 64×64 pixels:
```python
import os

import cv2

image_dir = 'girls'
new_girl_dir = 'little_girls'
if not os.path.exists(new_girl_dir):
    os.makedirs(new_girl_dir)

for img_file in os.listdir(image_dir):
    img_file_path = os.path.join(image_dir, img_file)
    img = cv2.imread(img_file_path)
    if img is None:
        print("failed to read image:", img_file_path)
        continue
    height, width, channel = img.shape
    # Skip small images (under 200 px on either side) and non-3-channel images
    if height < 200 or width < 200 or channel != 3:
        continue
    # You can also convert to grayscale (channel=1) to speed up training
    # Scale the image to 64x64
    img = cv2.resize(img, (64, 64))
    new_file = os.path.join(new_girl_dir, img_file)
    cv2.imwrite(new_file, img)
    print(new_file)
```
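The script's comment mentions converting to grayscale. A minimal NumPy sketch of that conversion, using the BT.601 luma weights that OpenCV also uses for `cv2.COLOR_BGR2GRAY` (in practice `cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)` does the same in one call, possibly differing by rounding). `bgr_to_gray` is a hypothetical helper name.

```python
import numpy as np

# BT.601 luma weights in B, G, R order, because cv2.imread returns BGR images
BGR_TO_GRAY = np.array([0.114, 0.587, 0.299])


def bgr_to_gray(img):
    """Convert an (H, W, 3) uint8 BGR image to an (H, W, 1) uint8 grayscale image."""
    gray = img.astype(np.float64) @ BGR_TO_GRAY            # weighted sum -> (H, W)
    gray = np.clip(np.rint(gray), 0, 255).astype(np.uint8)
    return gray[..., np.newaxis]                           # keep a channel axis for NHWC input
```

If you train on grayscale images, the training script also needs `image_channel = 1`, and images should be read back with `cv2.imread(path, cv2.IMREAD_GRAYSCALE)[..., None]`, since `cv2.imread` returns 3 channels by default.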
Remove duplicate images:
```python
import os

import cv2
import numpy as np

image_dir = 'little_girls'


# Check whether two images are pixel-for-pixel identical (hashing would be much faster)
def is_same_image(img_file1, img_file2):
    img1 = cv2.imread(img_file1)
    img2 = cv2.imread(img_file2)
    if img1 is None or img2 is None:
        return False
    # array_equal also returns False when the shapes differ
    return np.array_equal(img1, img2)


# Remove duplicate images
file_list = os.listdir(image_dir)
try:
    # Iterate over a copy, because file_list is modified inside the loop
    for img1 in list(file_list):
        print(len(file_list))
        for img2 in file_list:
            if img1 != img2 and is_same_image(os.path.join(image_dir, img1),
                                              os.path.join(image_dir, img2)):
                print(img1, img2)
                os.remove(os.path.join(image_dir, img1))
                file_list.remove(img1)
                # img1 has been deleted, so stop comparing it
                break
except Exception as e:
    print(e)
```
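As the comment in `is_same_image` says, hashing is much faster: the pairwise loop decodes and compares O(n²) pairs of images, while hashing each decoded image once needs a single O(n) pass. A sketch, assuming the decoded images fit in memory as a dict (modest for 64×64 images); `find_duplicates` is a hypothetical helper, it uses the same pixel-exact criterion as `is_same_image`, and SHA-1 collisions are ignored.

```python
import hashlib

import numpy as np


def find_duplicates(images):
    """Return names of images whose pixels exactly match an earlier image.

    images: dict mapping file name -> decoded image array (e.g. from cv2.imread,
    with failed reads already filtered out). In each group of identical images,
    the first name in sorted order is kept and the rest are returned.
    """
    first_seen = {}
    duplicates = []
    for name in sorted(images):
        img = np.ascontiguousarray(images[name])
        # Shape and dtype are part of the key, so arrays with the same bytes
        # but different shapes are never treated as duplicates
        key = (img.shape, img.dtype.str, hashlib.sha1(img.tobytes()).hexdigest())
        if key in first_seen:
            duplicates.append(name)
        else:
            first_seen[key] = name
    return duplicates
```

Each returned name can then be passed to `os.remove(os.path.join('little_girls', name))`.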
Complete code: generating girl pictures with PixelCNN
The code below implements only the unconditional model; it does not implement the conditional or autoencoder models. See the paper for details.
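For reference, the unconditional and conditional models differ only in what each pixel is conditioned on. Writing an n×n image as pixels x_1, …, x_{n²} in raster order (row by row, left to right), the papers factorize the image distribution as in the first two lines below, where h is a conditioning vector (for example a class label or an embedding) that this code does not use. The last line is the loss a 256-way softmax output is trained with: the mean cross-entropy over all N·H·W·C sub-pixel intensities v ∈ {0, …, 255}, with z the corresponding row of 256 logits.

```latex
p(\mathbf{x}) = \prod_{i=1}^{n^2} p(x_i \mid x_1, \dots, x_{i-1})

p(\mathbf{x} \mid \mathbf{h}) = \prod_{i=1}^{n^2} p(x_i \mid x_1, \dots, x_{i-1}, \mathbf{h})

\mathcal{L} = -\frac{1}{NHWC} \sum_{\text{sub-pixels}} \log \frac{\exp(z_v)}{\sum_{u=0}^{255} \exp(z_u)}
```

Because the masks in `get_weights` are purely spatial, this implementation additionally treats the colour channels of a pixel as independent given the earlier pixels, p(x_i | x_<i) = ∏_c p(x_{i,c} | x_<i), whereas the papers also order R, G and B within each pixel.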
```python
# -*- coding: utf-8 -*-
# Written against the TensorFlow 1.x graph API (tf.placeholder, tf.Session, tf.contrib)
import os

import cv2
import numpy as np
import tensorflow as tf

# Set MNIST to True to train on the MNIST dataset instead
MNIST = False

if MNIST:
    from tensorflow.examples.tutorials.mnist import input_data
    data = input_data.read_data_sets('/tmp/')
    image_height = 28
    image_width = 28
    image_channel = 1
    batch_size = 128
    n_batches = data.train.num_examples // batch_size
else:
    picture_dir = 'little_girls'
    picture_list = []
    # Loading every image into memory at once is not recommended;
    # to save memory, it is better to load images as they are used
    for (dirpath, dirnames, filenames) in os.walk(picture_dir):
        for filename in filenames:
            if filename.endswith('.jpg'):
                picture_list.append(os.sep.join([dirpath, filename]))
    print("Total images: ", len(picture_list))

    # Image size and number of channels
    image_height = 64
    image_width = 64
    image_channel = 3

    # Number of samples per training step
    batch_size = 128
    n_batches = len(picture_list) // batch_size

    # Images in the NHWC layout of the input X, scaled to [0, 1]
    # (float32 halves the memory of NumPy's default float64)
    img_data = np.array([cv2.imread(f) for f in picture_list], dtype=np.float32) / 255.0
    # print(img_data.shape)  # (44112, 64, 64, 3)

COLOR = 256  # each sub-pixel is classified into one of 256 intensity levels


def gated_cnn(W_shape_, fan_in, gated=True, payload=None, mask=None, activation=True):
    # W_shape_ = [filter_height, filter_width, out_channels]
    W_shape = [W_shape_[0], W_shape_[1], int(fan_in.get_shape()[-1]), W_shape_[2]]
    b_shape = [W_shape_[2]]

    def get_weights(shape, name, mask=None):
        weights_initializer = tf.contrib.layers.xavier_initializer()
        W = tf.get_variable(name, shape, tf.float32, weights_initializer)
        if mask:
            # Hide positions right of the centre in the centre row and all rows below,
            # so each output only sees pixels above it or to its left
            filter_mid_x = shape[0] // 2
            filter_mid_y = shape[1] // 2
            mask_filter = np.ones(shape, dtype=np.float32)
            mask_filter[filter_mid_x, filter_mid_y+1:, :, :] = 0.
            mask_filter[filter_mid_x+1:, :, :, :] = 0.
            if mask == 'a':
                # Mask 'a' (first layer only) also hides the current pixel itself
                mask_filter[filter_mid_x, filter_mid_y, :, :] = 0.
            W = W * mask_filter
        return W

    if gated:
        W_f = get_weights(W_shape, "v_W", mask=mask)
        W_g = get_weights(W_shape, "h_W", mask=mask)
        b_f = tf.get_variable("v_b", b_shape, tf.float32, tf.zeros_initializer())
        b_g = tf.get_variable("h_b", b_shape, tf.float32, tf.zeros_initializer())
        conv_f = tf.nn.conv2d(fan_in, W_f, strides=[1, 1, 1, 1], padding='SAME')
        conv_g = tf.nn.conv2d(fan_in, W_g, strides=[1, 1, 1, 1], padding='SAME')
        if payload is not None:
            conv_f += payload
            conv_g += payload
        # Gated activation unit: tanh(conv_f) * sigmoid(conv_g)
        fan_out = tf.multiply(tf.tanh(conv_f + b_f), tf.sigmoid(conv_g + b_g))
    else:
        W = get_weights(W_shape, "W", mask=mask)
        b = tf.get_variable("b", b_shape, tf.float32, tf.zeros_initializer())
        conv = tf.nn.conv2d(fan_in, W, strides=[1, 1, 1, 1], padding='SAME')
        if activation:
            fan_out = tf.nn.relu(tf.add(conv, b))
        else:
            fan_out = tf.add(conv, b)
    return fan_out


def pixel_cnn(X, layers=12, f_map=32):
    v_stack_in, h_stack_in = X, X
    for i in range(layers):
        filter_size = 3 if i > 0 else 7
        mask = 'b' if i > 0 else 'a'
        residual = i > 0
        i = str(i)
        with tf.variable_scope("v_stack" + i):
            v_stack = gated_cnn([filter_size, filter_size, f_map], v_stack_in, mask=mask)
            v_stack_in = v_stack
        # 1x1 convolutions only see the current position of feature maps that already
        # depend only on earlier pixels, so they use mask 'b'
        # (mask 'a' would zero an entire 1x1 filter)
        with tf.variable_scope("v_stack_1" + i):
            v_stack_1 = gated_cnn([1, 1, f_map], v_stack_in, gated=False, mask='b')
        with tf.variable_scope("h_stack" + i):
            h_stack = gated_cnn([1, filter_size, f_map], h_stack_in, payload=v_stack_1, mask=mask)
        with tf.variable_scope("h_stack_1" + i):
            h_stack_1 = gated_cnn([1, 1, f_map], h_stack, gated=False, mask='b')
            if residual:
                h_stack_1 += h_stack_in
            h_stack_in = h_stack_1

    with tf.variable_scope("fc_1"):
        fc1 = gated_cnn([1, 1, f_map], h_stack_in, gated=False, mask='b')
    with tf.variable_scope("fc_2"):
        fc2 = gated_cnn([1, 1, image_channel * COLOR], fc1, gated=False, mask='b', activation=False)
        # One row of 256 logits per sub-pixel, in the same order as tf.reshape(X, [-1])
        fc2 = tf.reshape(fc2, (-1, COLOR))
    return fc2


def build_model():
    # Input images scaled to [0, 1]; the same tensor also supplies the training targets
    X = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel])
    return X, pixel_cnn(X)


def train_pixel_cnn():
    # Build in a fresh graph so train and generate can run in the same process
    with tf.Graph().as_default():
        X, output = build_model()
        # Targets are integer intensities 0..255; round first, because k/255 * 255
        # can land just below k and truncating would give k-1
        targets = tf.cast(tf.round(tf.reshape(X, [-1]) * 255.0), tf.int32)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=output))
        trainer = tf.train.RMSPropOptimizer(1e-3)
        gradients = trainer.compute_gradients(loss)
        clipped_gradients = [(tf.clip_by_value(g, -1, 1), v) for g, v in gradients if g is not None]
        optimizer = trainer.apply_gradients(clipped_gradients)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.trainable_variables())
            for epoch in range(50):
                for batch in range(n_batches):
                    if MNIST:
                        batch_X, _ = data.train.next_batch(batch_size)
                        batch_X = batch_X.reshape([batch_size, image_height, image_width, image_channel])
                    else:
                        batch_X = img_data[batch_size * batch: batch_size * (batch + 1)]
                    _, cost = sess.run([optimizer, loss], feed_dict={X: batch_X})
                    print("epoch:", epoch, ' batch:', batch, ' cost:', cost)
                # Saves epochs 0, 7, ..., 49, so girl.ckpt-49 is the final model
                if epoch % 7 == 0:
                    saver.save(sess, "girl.ckpt", global_step=epoch)


def generate_girl():
    with tf.Graph().as_default():
        X, output = build_model()
        # tf.multinomial expects unnormalized log-probabilities, so pass the logits directly
        predict = tf.reshape(tf.multinomial(output, num_samples=1, seed=100), tf.shape(X))
        # predict_argmax = tf.reshape(tf.argmax(output, axis=1), tf.shape(X))
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.trainable_variables())
            saver.restore(sess, 'girl.ckpt-49')
            pics = np.zeros((1, image_height, image_width, image_channel), dtype=np.float32)
            # Autoregressive sampling: fill in one sub-pixel at a time, in raster order
            for i in range(image_height):
                for j in range(image_width):
                    for k in range(image_channel):
                        next_pic = sess.run(predict, feed_dict={X: pics})
                        # Samples are intensities 0..255; rescale to the [0, 1] input range
                        pics[:, i, j, k] = next_pic[:, i, j, k] / 255.0
            cv2.imwrite('girl.jpg', (pics[0] * 255).astype(np.uint8))
            print('Generated image: girl.jpg')


# Train
train_pixel_cnn()
# Generate an image
generate_girl()
```
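To see exactly what `get_weights` masks, here is the same mask construction pulled out into plain NumPy; `make_mask` is a hypothetical helper written only for illustration. It also shows why mask 'a' must not be applied to 1×1 filters: the centre is their only position, so mask 'a' zeroes the whole filter.

```python
import numpy as np


def make_mask(shape, mask_type):
    """Build the causal mask used in get_weights for a [h, w, in, out] filter.

    Positions right of the centre in the centre row, and all rows below it,
    are always hidden. Mask 'a' (first layer) also hides the centre itself;
    mask 'b' keeps it.
    """
    mid_x, mid_y = shape[0] // 2, shape[1] // 2
    mask = np.ones(shape, dtype=np.float32)
    mask[mid_x, mid_y + 1:, :, :] = 0.0   # right of the centre in the centre row
    mask[mid_x + 1:, :, :, :] = 0.0       # every row below the centre
    if mask_type == 'a':
        mask[mid_x, mid_y, :, :] = 0.0    # the current pixel itself
    return mask


if __name__ == '__main__':
    # 3x3 mask 'a': rows [1 1 1], [1 0 0], [0 0 0]
    print(make_mask((3, 3, 1, 1), 'a')[:, :, 0, 0])
    # 3x3 mask 'b': rows [1 1 1], [1 1 0], [0 0 0]
    print(make_mask((3, 3, 1, 1), 'b')[:, :, 0, 0])
    # A 1x1 filter with mask 'a' has no visible position left
    print(make_mask((1, 1, 4, 4), 'a').sum())
```

For the horizontal stack's [1, filter_size] filters, the centre row is the only row, so only the positions to the right of the centre are hidden.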
Um, the girl-picture model is still training…
Bonus exercise: use OpenCV to extract the faces from the images, then train the model above on them and see what it generates.
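A possible starting point for the exercise, assuming the pip package opencv-python, which ships the Haar cascade files under `cv2.data.haarcascades`. The helper names `detect_faces` and `crop_face`, the detector parameters and the 20% margin are illustrative choices, not tuned values. OpenCV is imported inside `detect_faces` so that the pure-NumPy cropping helper works on its own.

```python
import numpy as np


def detect_faces(img_bgr):
    """Detect faces with OpenCV's bundled frontal-face Haar cascade.

    Returns a sequence of (x, y, w, h) boxes (empty if no face is found).
    """
    import cv2  # local import so crop_face can be used without OpenCV
    cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    return cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)


def crop_face(img, box, margin=0.2):
    """Crop a face box, enlarged by `margin` of its size on each side and clamped to the image."""
    x, y, w, h = box
    dx, dy = int(w * margin), int(h * margin)
    height, width = img.shape[:2]
    x0, y0 = max(x - dx, 0), max(y - dy, 0)
    x1, y1 = min(x + w + dx, width), min(y + h + dy, height)
    return img[y0:y1, x0:x1]
```

Each crop can then go through the same `cv2.resize(face, (64, 64))` step as the preprocessing script. Loading the cascade once outside the loop would be faster when processing many images.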