tensorflow 卷积、反卷积形式的去噪自编码器

对于去噪自编码器,网上好多都是利用全连接神经网络进行构建,我自己写了一个卷积、反卷积形式的去噪自编码器,其中的参数调优如果有兴趣的话,可以自行修改查看结果。


数据集我使用最简单的mnist:



网络结构:

mnist输入(28*28=784向量) => 28*28*1矩阵 => 卷积层1 => 14*14*64 => 卷积层2 => 7*7*64 => 卷积层3 => 4*4*32 => 反卷积层1 => 7×7*32 => 反卷积层2 => 14*14*64 => 反卷积层3 => 28*28*64 => 卷积层X => 28×28*1


训练:

我用train集训练train_epochs轮,然后用test集对训练好的模型进行评测,同时保存加噪图像及对应的去噪图像。


Code:

  1. #! -*- coding: utf-8 -*-
  2. ## by Colie (lijixiang)
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import numpy as np
  6. from PIL import Image
  7. train_epochs = 35  ## int(1e5+1)
  8. INPUT_HEIGHT = 28
  9. INPUT_WIDTH = 28
  10. batch_size = 256
  11. noise_factor = 0.5  ## (0~1)
  12. ## 原始输入是28×28*3
  13. input_x = tf.placeholder(tf.float32, [None, INPUT_HEIGHT * INPUT_WIDTH], name='input_with_noise')
  14. input_matrix = tf.reshape(input_x, shape=[-1, INPUT_HEIGHT, INPUT_WIDTH, 1])
  15. input_raw = tf.placeholder(tf.float32, shape=[None, INPUT_HEIGHT * INPUT_WIDTH], name='input_without_noise')
  16. ## 1 conv layer
  17. ## 输入28*28*3
  18. ## 经过卷积、激活、池化,输出14*14*64
  19. weight_1 = tf.Variable(tf.truncated_normal(shape=[3, 3, 1, 64], stddev=0.1, name = 'weight_1'))
  20. bias_1 = tf.Variable(tf.constant(0.0, shape=[64], name='bias_1'))
  21. conv1 = tf.nn.conv2d(input=input_matrix, filter=weight_1, strides=[1, 1, 1, 1], padding='SAME')
  22. conv1 = tf.nn.bias_add(conv1, bias_1, name='conv_1')
  23. acti1 = tf.nn.relu(conv1, name='acti_1')
  24. pool1 = tf.nn.max_pool(value=acti1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='max_pool_1')
  25. ## 2 conv layer
  26. ## 输入14*14*64
  27. ## 经过卷积、激活、池化,输出7×7×64
  28. weight_2 = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 64], stddev=0.1, name='weight_2'))
  29. bias_2 = tf.Variable(tf.constant(0.0, shape=[64], name='bias_2'))
  30. conv2 = tf.nn.conv2d(input=pool1, filter=weight_2, strides=[1, 1, 1, 1], padding='SAME')
  31. conv2 = tf.nn.bias_add(conv2, bias_2, name='conv_2')
  32. acti2 = tf.nn.relu(conv2, name='acti_2')
  33. pool2 = tf.nn.max_pool(value=acti2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='max_pool_2')
  34. ## 3 conv layer
  35. ## 输入7*7*64
  36. ## 经过卷积、激活、池化,输出4×4×32
  37. ## 原始输入是28*28*3=2352,转化为4*4*32=512,大量噪声会在网络中过滤掉
  38. weight_3 = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 32], stddev=0.1, name='weight_3'))
  39. bias_3 = tf.Variable(tf.constant(0.0, shape=[32]))
  40. conv3 = tf.nn.conv2d(input=pool2, filter=weight_3, strides=[1, 1, 1, 1], padding='SAME')
  41. conv3 = tf.nn.bias_add(conv3, bias_3)
  42. acti3 = tf.nn.relu(conv3, name='acti_3')
  43. pool3 = tf.nn.max_pool(value=acti3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name='max_pool_3')
  44. ## 1 deconv layer
  45. ## 输入4*4*32
  46. ## 经过反卷积,输出7*7*32
  47. deconv_weight_1 = tf.Variable(tf.truncated_normal(shape=[3, 3, 32, 32], stddev=0.1), name='deconv_weight_1')
  48. deconv1 = tf.nn.conv2d_transpose(value=pool3, filter=deconv_weight_1, output_shape=[batch_size, 7, 7, 32], strides=[1, 2, 2, 1], padding='SAME', name='deconv_1')
  49. ## 2 deconv layer
  50. ## 输入7*7*32
  51. ## 经过反卷积,输出14*14*64
  52. deconv_weight_2 = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 32], stddev=0.1), name='deconv_weight_2')
  53. deconv2 = tf.nn.conv2d_transpose(value=deconv1, filter=deconv_weight_2, output_shape=[batch_size, 14, 14, 64], strides=[1, 2, 2, 1], padding='SAME', name='deconv_2')
  54. ## 3 deconv layer
  55. ## 输入14*14*64
  56. ## 经过反卷积,输出28*28*64
  57. deconv_weight_3 = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 64], stddev=0.1, name='deconv_weight_3'))
  58. deconv3 = tf.nn.conv2d_transpose(value=deconv2, filter=deconv_weight_3, output_shape=[batch_size, 28, 28, 64], strides=[1, 2, 2, 1], padding='SAME', name='deconv_3')
  59. ## conv layer
  60. ## 输入28*28*64
  61. ## 经过卷积,输出为28*28*1
  62. weight_final = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 1], stddev=0.1, name = 'weight_final'))
  63. bias_final = tf.Variable(tf.constant(0.0, shape=[1], name='bias_final'))
  64. conv_final = tf.nn.conv2d(input=deconv3, filter=weight_final, strides=[1, 1, 1, 1], padding='SAME')
  65. conv_final = tf.nn.bias_add(conv_final, bias_final, name='conv_final')
  66. ## output
  67. ## 输入28*28*1
  68. ## reshape为28*28
  69. output = tf.reshape(conv_final, shape=[-1, INPUT_HEIGHT * INPUT_WIDTH])
  70. ## loss and optimizer
  71. loss = tf.reduce_mean(tf.pow(tf.subtract(output, input_raw), 2.0))
  72. optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)
  73. with tf.Session() as sess:
  74. mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  75. n_samples = int(mnist.train.num_examples)
  76. print('train samples: %d' % n_samples)
  77. print('batch size: %d' % batch_size)
  78. total_batch = int(n_samples / batch_size)
  79. print('total batchs: %d' % total_batch)
  80. init = tf.global_variables_initializer()
  81. sess.run(init)
  82. for epoch in range(train_epochs):
  83. for batch_index in range(total_batch):
  84. batch_x, _ = mnist.train.next_batch(batch_size)
  85. noise_x = batch_x + noise_factor * np.random.randn(*batch_x.shape)
  86. noise_x = np.clip(noise_x, 0., 1.)
  87. _, train_loss = sess.run([optimizer, loss], feed_dict={input_x: noise_x, input_raw: batch_x})
  88. print('epoch: %04d\tbatch: %04d\ttrain loss: %.9f' % (epoch + 1, batch_index + 1, train_loss))
  89. ## 训练结束后,用测试集测试,并保存加噪图像、去噪图像
  90. n_test_samples = int(mnist.test.num_examples)
  91. test_total_batch = int(n_test_samples / batch_size)
  92. for i in range(test_total_batch):
  93. batch_test_x, _ = mnist.test.next_batch(batch_size)
  94. noise_test_x = batch_test_x + noise_factor * np.random.randn(*batch_test_x.shape)
  95. noise_test_x = np.clip(noise_test_x, 0., 1.)
  96. test_loss, pred_result = sess.run([loss, conv_final], feed_dict={input_x: noise_test_x, input_raw: batch_test_x})
  97. print('test batch index: %d\ttest loss: %.9f' % (i + 1, test_loss))
  98. for index in range(batch_size):
  99. array = np.reshape(pred_result[index], newshape=[INPUT_HEIGHT, INPUT_WIDTH])
  100. array = array * 255
  101. image = Image.fromarray(array)
  102. if image.mode != 'L':
  103. image = image.convert('L')
  104. image.save('./pred/' + str(i * batch_size + index) + '.png')
  105. array_raw = np.reshape(noise_test_x[index], newshape=[INPUT_HEIGHT, INPUT_WIDTH])
  106. array_raw = array_raw * 255
  107. image_raw = Image.fromarray(array_raw)
  108. if image_raw.mode != 'L':
  109. image_raw = image_raw.convert('L')
  110. image_raw.save('./pred/' + str(i * batch_size + index) + '_raw.png')
  111. #break</span>

去噪效果:

tensorflow 卷积、反卷积形式的去噪自编码器相关推荐

  1. DL之CNN:卷积神经网络算法简介之卷积矩阵、转置卷积(反卷积Transpose)、膨胀卷积(扩张卷积Dilated/带孔卷积atrous)之详细攻略

    DL之CNN:卷积神经网络算法简介之卷积矩阵.转置卷积(反卷积Transpose).膨胀卷积(扩张卷积Dilated/带孔卷积atrous)之详细攻略 目录 卷积矩阵的简介 卷积.转置卷积--Tran ...

  2. 卷积 反卷积 上采样 下采样 区别

    1.卷积 就是利用卷积核  步长前进 卷积整个图片 2.反卷积 反卷积的具体操作 原图输入尺寸为[1,3,3,3]对应[batch_size,channels,width,height] 反卷积tco ...

  3. tensorflow:双线性插值反卷积

    首先生成3×3×3的黑色图片 """ 生成3×3×3黑色图像 """ def produce_image():size = 3x, y = ...

  4. tensorflow实现反卷积

    先看ogrid用法 from numpy import ogrid,repeat,newaxis from skimage import io import numpy as np size=3 x, ...

  5. 通过图+代码来理解tensorflow中反卷积

    反卷积这个东西老是容易忘,而且很多文章理论讲的很详细,但反卷积实际怎么操作的却没有概念,因此想以自己喜欢的方式(直接上图和代码)写一篇,以便随时翻阅. 卷积 tf中的padding方式有两种,SAME ...

  6. 卷积/反卷积前后的张量尺寸计算

    1.下采样/卷积: 先定义几个参数 输入图片大小 :W×W Filter大小 :k×k 步长 :S padding的像素数 :P 输出图片大小为: N×N 于是我们可以得出:N = (W − k + ...

  7. 分组卷积/转置卷积/空洞卷积/反卷积/可变形卷积/深度可分离卷积/DW卷积/Ghost卷积/

    文章目录 1. 常规卷积 2. 分组卷积 3. 转置卷积 4. 空洞卷积 5. 可变形卷积 6. 深度可分离卷积(Separable Convolution) 6.1 Depthwise Convol ...

  8. python 反卷积(DeConv) tensorflow反卷积(DeConv)(实现原理+手写)

    Tensorflow反卷积(DeConv)实现原理+手写python代码实现反卷积(DeConv) 理解: https://www.zhihu.com/question/43609045/answer ...

  9. Tensorflow反卷积(DeConv)实现原理+手写python代码实现反卷积(DeConv)

    最近看到一个巨牛的人工智能教程,分享一下给大家.教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家.平时碎片时间可以当小说看,[点这里可以去膜拜一下大神的" ...

最新文章

  1. maven学习(4)-Maven 构建Web 项目
  2. 利用WampServer挂载MySQL数据库
  3. php客服窗口,制作一个客服小界面
  4. 深入struts2.0(七)--ActionInvocation接口以及3DefaultActionInvocation类
  5. 湖南大学让晶体管小至3纳米,沟道长度仅一层原子 | Nature子刊
  6. Queue接口及是实现类PriorityQueue介绍
  7. gridview自动换行
  8. 一周学习总结PPT-学会VLOOKUP函数,1分钟搞定数据汇总
  9. python自动化库_Python自动化测试常用库整理
  10. mysql boost 5.7.21_mysql 5.7.21 安装配置方法图文教程(window)
  11. [Java基础]Lambda表达式的省略模式
  12. C/S、B/S的区别
  13. Minidao_1.6.2版本发布,超轻量Java持久化框架
  14. 剑指offer面试题47. 礼物的最大价值(动态规划)
  15. linux文本编辑命令vim查找,Linux编辑器vi中文本搜索与替换操作
  16. NXP iMX8系列处理器核心性能对比测试
  17. Activity启动模式singleTask模式
  18. 虚拟现实计算机理论文献,虚拟现实在计算机教学中的应用研究
  19. CloudCompare:V2.6.3 菜单栏和工具栏 中英文对照 功能简述
  20. 基于Python的自动聊天机器人

热门文章

  1. mysql怎么在海量数据上ddl_浅谈MySQL Online DDL(中)
  2. css限制字体三行_CSS美化网页
  3. 分层和分段用什么符号_如何划分段落层次,如何给段落分层
  4. 在蓄电池管理系统中计算机应用,汽车电器与电子技术.docx
  5. mysql半备份_MySQL半同步复制与增强半同步复制详解及安装
  6. 据说电脑上可以刷朋友圈啦!又多了个上班摸鱼的途径?
  7. StackOverflow热帖:Java整数相加溢出怎么办?
  8. 赠书:聊聊「分布式架构」那些事儿
  9. JDK 14 里的调试神器了解一下?
  10. 你没见过Java台式计算机和Java操作系统吧