学习目标:

  1. 理解自动编码器的基本原理。
  2. 掌握利用自动编码器进行图像去燥的方法。

学习内容:

  1. 对cifar10数据库,变为黑白图像,划分训练集和测试集,加上噪音。构造自动编解码器,用加上噪音的训练集和原图进行训练,然后用等于加上噪音的测试集去噪。

学习过程:

模型各层的输出大小:

自动解码模型的输出大小:

一次迭代后的测试图片:

测试图片输出时变得更加模糊了;

可以看出去噪后的图片比去噪前的图片更加难以辨认了;


源码:

# In[0]: 读取数据
from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.datasets import cifar10
from keras import backend as Kimport numpy as np
import matplotlib.pyplot as plt#加载手写数字图片数据
(x_train, _), (x_test, _) = cifar10.load_data()
image_size = x_train.shape[1]#把彩色图转化为灰度图,如果当前像素点为[r,g,b],那么对应的灰度点为0.299*r+0.587*g+0.114*b
def rgb2gray(rgb):return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])x_train = rgb2gray(x_train)
x_test = rgb2gray(x_test)#把图片大小统一转换成28*28,并把像素点值都转换为[0,1]之间
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255# In[1]: 构造编码网络input_shape = (image_size, image_size, 1)
batch_size = 32
#对图片做3*3分割
kernel_size = 3
#让编码器将输入图片编码成含有16个元素的向量
latent_dim = 16
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
'''
编码器含有两个卷积层,卷积核的大小为3*3
第一个卷积层有32个输出通道,第二个卷积层有64个输出通道
'''
layer_filters = [32, 64]
#stride=2表明每次挪到2个像素,如此一来做一次卷积运算后输出大小会减半
x = Conv2D(filters = layer_filters[0], kernel_size = kernel_size, activation='relu',strides = 2, padding = 'same')(x)
x = Conv2D(filters = layer_filters[1], kernel_size = kernel_size, activation='relu',strides = 2, padding = 'same')(x)shape = K.int_shape(x)
print('shape: ', shape)
print(shape[1])
x = Flatten()(x)
#最后一层全连接网络输出含有16个元素的向量
latent = Dense(latent_dim, name = 'latent_vector')(x)
encoder = Model(inputs, latent, name='encoder')
encoder.summary()# In[2]:构造解码器,解码器的输入正好是编码器的输出结果latent_inputs = Input(shape = (latent_dim, ), name = 'decoder_input')
'''
它的结构正好和编码器相反,它先是一个全连接层,然后是两层反卷积网络
'''
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)#两层与编码器对应的反卷积网络
x = Conv2DTranspose(filters = layer_filters[1], kernel_size = kernel_size,activation='relu', strides = 2, padding = 'same')(x)
x = Conv2DTranspose(filters = layer_filters[0], kernel_size = kernel_size,activation='relu', strides = 2, padding = 'same')(x)outputs = Conv2DTranspose(filters = 1, kernel_size = kernel_size,activation = 'sigmoid',padding = 'same',name = 'decoder_output')(x)
decoder = Model(latent_inputs, outputs, name = 'decoder')
decoder.summary()# In[3]: 联合上述的编码器和解码器,构造自动编解码器autoencoder = Model(inputs, decoder(encoder(inputs)), name = 'autoencoder')
autoencoder.summary()'''
网络训练时,我们采用最小和方差,也就是我们希望解码器输出的图片与输入编码器的图片,在像素上的差异
尽可能的小
'''
autoencoder.compile(loss='mse', optimizer='adam')
autoencoder.fit(x_train, x_train, validation_data=(x_test, x_test),epochs = 1,batch_size = batch_size)'''
x_test是输入编码器的测试图片,我们看看解码器输出的图片与输入时是否差别不大
'''
x_decoded = autoencoder.predict(x_test)
#把测试图片集中的前8张显示出来,看看解码器生成的图片是否与原图片足够相似
imgs = np.concatenate([x_test[:8], x_decoded[: 8]])
imgs = imgs.reshape((4, 4, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure()
plt.axis('off')
plt.title('Input: 1st 2 rows, Decoded: last 2 rows')
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.show()# In[4]: 为图像像素点增加高斯噪音noise = np.random.normal(loc=0.5, scale = 0.5, size = x_train.shape)
x_train_noisy = x_train + noise
noise = np.random.normal(loc=0.5, scale = 0.5, size = x_test.shape)
x_test_noisy = x_test + noise
#添加噪音值后,像素点值可能会超过1或小于0,我们把这些值调整到[0,1]之间
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)# In[5]: 利用自动编解码器进行图像去噪autoencoder = Model(inputs, decoder(encoder(inputs)), name = 'autoencoder')
autoencoder.compile(loss='mse', optimizer='adam')
autoencoder.fit(x_train_noisy, x_train, validation_data = (x_test_noisy, x_test),epochs = 2, # 迭代两次batch_size = batch_size)# In[6]: 获取去噪后的图片x_decode = autoencoder.predict(x_test_noisy)
'''
将去噪前和去噪后的图片显示出来,第一行是原图片,第二行时增加噪音后的图片,
第三行时去除噪音后的图片
'''
rows , cols = 3, 9
num = rows * cols
imgs = np.concatenate([x_test[:num], x_test_noisy[:num], x_decode[:num]])
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
imgs = np.vstack(np.split(imgs, rows, axis = 1))
imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
imgs = (imgs * 255).astype(np.uint8)
plt.figure(dpi=120)
plt.axis('off')
plt.title('first row: original image , middle row: noised image, third row: denoised image')
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.show()

源码下载


学习产出:

  1. 图片去噪后的效果变差了。

人工智能--自动编码器相关推荐

  1. 开源人工智能使用卷积网格自动编码器生成3D面部

    开源人工智能使用卷积网格自动编码器生成3D面部摘要:人脸的学习3D表示对于计算机视觉问题是有用的,例如3D面部跟踪和从图像重建,以及诸如角色生成和动画的图形应用.传统模型使用线性子空间或高阶张量概括来 ...

  2. 本科-人工智能复习题

    填空题 首次提出"人工智能"是在(1956)年 . 下列不属于人工智能的研究方法的是(动作模拟法). (语义网络法)是知识的一种结构化图解表示,它由节点和弧线或链线组成,能把实体的 ...

  3. 不要上手就学深度学习!超详细的人工智能专家路线图,GitHub数天获2.1k星

    来源:机器之心 本文约1600字,建议阅读5分钟 这个学习路线图几乎涵盖了人工智能领域的所有内容,点点鼠标,就能链接所需知识. 想从事人工智能领域的研究,盲目地在网上购买了一本又一本的参考资料,学习视 ...

  4. 谷歌15个人工智能开源免费项目!开发者:懂了

    2019-11-21 14:37:20 关于人工智能的开源项目,相信开发者们已经目睹过不少了,Github上也有大把的资源.不过笔者今天说的并非来自Github,而是来自科技"大厂" ...

  5. 人工智能之机器学习算法体系汇总

    https://www.toutiao.com/i6638371599303049731/ 2018-12-24 09:52:12 此处梳理出面向人工智能的机器学习方法体系,主要体现机器学习方法和逻辑 ...

  6. 人工智能之机器学习常见算法

    https://blog.csdn.net/BaiHuaXiu123/article/details/51475384 摘要 之前一直对机器学习很感兴趣,一直没时间去研究,今天刚好是周末,有时间去各大 ...

  7. 干货丨机器学习必备:前20名Python人工智能和机器学习开源项目

    如今机器学习和人工智能已经变得家喻户晓,有很多爱好者进入了该领域.但是,什么才是能够进入该领域的正确路径呢?如何保持自己跟上该领域的发展步伐呢? 为了解决以上两个问题,可以通过利用高级专业人员每天使用 ...

  8. 扩散模型就是自动编码器!DeepMind研究学者提出新观点并论证

    来源:明敏 发自 凹非寺 量子位 | 公众号 QbitAI 由于在图像生成效果上可以与GAN媲美,扩散模型最近成为了AI界关注的焦点. 谷歌.OpenAI过去一年都提出了自家的扩散模型,效果也都非常惊 ...

  9. 他们提出了一个大胆的猜想:GWT(深度学习)→通用人工智能

    来源:AI科技评论 编译 :陈彩娴 近日,有一篇发表在arXiv的论文"Deep Learning and the Global Workspace Theory"提出了一个大胆的 ...

最新文章

  1. Linux文本处理(二)
  2. Visio替代图表工具 - 为什么Visual Paradigm Online?
  3. 华为如何打造智能终端的有趣灵魂?(下)
  4. 统计学习笔记(4)——朴素贝叶斯法
  5. docker设置http_proxy https_proxy解决gcr.io/kaniko-project/executor:v1.7.0之类的镜像拉取问题
  6. 在一台机器上运行多个ActiveMQ实例
  7. request.getAttribute()的数据类型转换问题
  8. 支持向量机: Maximum Margin Classifier
  9. 供应商去市网维护银行账号信息_供应商信息中心是BBP系统中一项很重要的内容...
  10. java拦截器_springMVC入门(八)------拦截器
  11. “欣喜”和“郁闷”交织的2006
  12. 学术会议html模板,学术会议poster模板
  13. psftp文件的上传下载
  14. 日记、2021/9/30
  15. 数据中心与灾备中心建设总结
  16. 软件工程专业的论文答辩_软件工程专业本科毕业答辩?
  17. RadASM的主题更换!
  18. Java初级程序员面试中应该如何准备?一般公司对Java开发的要求有哪些?
  19. Java实现九宫格游戏
  20. 乾坤大挪移——使用PQ分区魔术师扩大C盘空间

热门文章

  1. 基于平均不同分辨率的共振峰跟踪算法matlab仿真
  2. 黑莓9900(9930)网页在线播放音乐
  3. 新技术到底靠不靠谱?在中国用一下就知道了
  4. 单片机基础知识之定时计数器和寄存器
  5. 关于Tex的一般用法汇总(各种操作链接自己使用 一直更新)
  6. 电话号码区号插件vue-country-diacode-selector
  7. 安装NGrabLite 录像DM500
  8. 安装最好用的计算机软件,8个职场人必装的电脑软件,用过以后就离不开了,太好用!...
  9. 全部汽车零部件更换周期 汽车零部件固定更换周期
  10. Revit项目和族文件升级后出现无响应死机情况