下面我就以一些动漫头像为例,来说明怎样利用torch来进行训练和测试数据的预处理。下面是图片的格式:

上述图片一共有51223张,每个图片的大小为3*96*96。 下载地址为:百度云链接

网络的基本结构是通过 卷积层*2,全连接层*n,解码层(全连接层*m)输入和输出的数据是一样的,最多是压缩到三个神经元。压缩到三个神经元的目的有两个,一个是可以对图片进行可视化,三个神经元代表三个坐标轴XYZ,另一个目的就是通过对三个神经元的随机赋值,再通过解码层生成一个张图片,相当于使用自编码器作为一个生成模型(效果可能很差)。

下面是构造自编码网络和训练这个网络的代码:

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
#定义自编码器的网络结构
class AutoEncoder(nn.Module):def __init__(self):super(AutoEncoder, self).__init__()###############################################################self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,stride=1,padding=1,),#->(16,96,96)nn.ReLU(),#->(16,96,96)nn.MaxPool2d(kernel_size=2),#->(16,48,48))                     #->(16,48,48)###############################################################self.conv2=nn.Sequential(nn.Conv2d(#->(16,48,48)in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=1),#->(32,48,48)nn.ReLU(),nn.MaxPool2d(kernel_size=2),#->(32,24,24)            )###############################################################self.linear=nn.Sequential(nn.Linear( 32*24*24, 256 ), nn.Tanh(),  # 激活函数nn.Linear( 256, 64 ), nn.Tanh(),nn.Linear( 64, 12),nn.Tanh(),nn.Linear( 12 ,3),nn.Tanh())#)self.decoder=nn.Sequential(nn.Linear(3,12),nn.Tanh(),nn.Linear( 12, 64 ),nn.Tanh(),nn.Linear( 64, 128 ),nn.Tanh(),nn.Linear( 128, 96*96*3),nn.Sigmoid())def forward(self, x):x=self.conv1(x)x=self.conv2(x)x=x.view(x.size(0),-1)encoded=self.linear(x)decoded=self.decoder(encoded)return encoded,decoded
#训练并反向传播
def trainOneBatch(batch:torch.FloatTensor,raw:torch.FloatTensor):encoded,decoded=auto(batch)loss=loss_function(decoded,raw)optimizer.zero_grad()loss.backward()optimizer.step()
#前向传播获得误差
def testOneBatch(batch:torch.FloatTensor,raw:torch.FloatTensor):encoded,decoded=auto(batch)loss=loss_function(decoded,raw)return loss
#超参数
LR=0.001
BATCH_SIZE=100
EPOCHES=30
#获取gpu是不是可用
cuda_available=torch.cuda.is_available()
#实例化网络
auto=AutoEncoder() if cuda_available :auto.cuda()#定义优化器和损失函数
optimizer=torch.optim.Adam(auto.parameters(),lr=LR)
loss_function=nn.MSELoss()#数据准备
DIRECTORY= "E:\\DataSets\\facess\\faces"#这里是自己的图片的位置
files=os.listdir(DIRECTORY)
imgs=[]#构造一个存放图片的列表数据结构
for file in files:file_path=DIRECTORY+"\\"+fileimg=cv2.imread(file_path)imgs.append(img)print("train")#遍历迭代期
for i in range(EPOCHES):print(i)#打乱数据np.random.shuffle(imgs)count=0#count是为了凑齐成为一个batch_size的大小batch=[]for j in range(len(imgs)):img=imgs[j]count+=1batch.append(img)if count==BATCH_SIZE or j==len(imgs)-1:#这里就算最后#列表转成张量,再转换维度batch_train=torch.Tensor(batch).permute(0,3,2,1)/255#batch,3,96,96raw=batch_train.contiguous().view(batch_train.size(0),-1)#batch,3*96*96if cuda_available:raw=raw.cuda()#数据变换到gpu上batch_train=batch_train.cuda()trainOneBatch(batch_train,raw)#训练一个批次batch.clear()count=0batch.clear()#测试for j in range(100):batch.append(imgs[j])batch_train=torch.Tensor(batch).permute(0,3,2,1)/255raw=batch_train.contiguous().view(batch_train.size(0),-1)if cuda_available:raw=raw.cuda()batch_train=batch_train.cuda()#调用函数获得损失loss=testOneBatch(batch_train,raw)batch.clear()print(loss)#把训练的中间结果输出到本地文件torch.save(auto,"auto.pkl")

下面是读取训练完成之后的网络,然后进行生成图像的代码:


import torch
import torch.nn as nn
import numpy as npimport cv2
class AutoEncoder(nn.Module):def __init__(self):super(AutoEncoder, self).__init__()#self.encoder=nn.Sequential(     #->(3,96,96)###############################################################self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=16,kernel_size=3,stride=1,padding=1,),#->(16,96,96)nn.ReLU(),#->(16,96,96)nn.MaxPool2d(kernel_size=2),#->(16,48,48))                     #->(16,48,48)###############################################################self.conv2=nn.Sequential(nn.Conv2d(#->(16,48,48)in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=1),#->(32,48,48)nn.ReLU(),nn.MaxPool2d(kernel_size=2),#->(32,24,24)            )###############################################################self.linear=nn.Sequential(nn.Linear( 32*24*24, 256 ), nn.Tanh(),  # 激活函数nn.Linear( 256, 64 ), nn.Tanh(),nn.Linear( 64, 12),nn.Tanh(),nn.Linear( 12 ,3),nn.Tanh())#)self.decoder=nn.Sequential(nn.Linear(3,12),nn.Tanh(),nn.Linear( 12, 64 ),nn.Tanh(),nn.Linear( 64, 128 ),nn.Tanh(),nn.Linear( 128, 96*96*3),nn.Sigmoid())def forward(self, x):x=self.conv1(x)x=self.conv2(x)x=x.view(x.size(0),-1)encoded=self.linear(x)decoded=self.decoder(encoded)return encoded,decodedauto:AutoEncoder=torch.load("auto.pkl")
print(auto)
auto=auto.cuda()
for i in range(6):for j in range(6):for k in range(6):m=i/2.5-1n=i/2.5-1p=k/2.5-1print(m,n,p)x=torch.FloatTensor([m,n,p])decoded=auto.decoder(x.cuda())img=decoded.contiguous().view(3,96,96).permute(2,1,0).detach().cpu().numpy()*255cv2.imwrite("imgs/" + str(i)+str(j)+str(k) + ".jpg",cv2.cvtColor(img,cv2.COLOR_BGR2GRAY))cv2.waitKey(0)

下面是给解码器三个数字,然后生成的图像

我感觉生成的图像样式基本都一样,只是颜色不一样而已。我就怀疑这些生成的图片只是数据集中的图片的平均,然后就了一个程序,然后输出所有图片的平均值,得到的结果和使用自编码网络生成的图片基本一致,可能原因是网络太深,中间层的神经元数量太少。所以如果要用自编码作为生成器的话,可能还需要很多其他的策略。但是对于mnist数据来说的话,如果解码器生成的图片是所有数据集的平均的话,那么很显然得到的就不是一个数字了,实际程序发现,对于随机的输入一个数字,生成的确实也是数字的样子,并不是所有图片的平均,但是可能是某个类别的平均。

自编码器相当于通过编码器数据集投影到低纬度的表示空间,一般来说会损失一些信息,然后通过解码器把低维的数据,映射到高维的数据。那么利用训练好的网络中的编码器就可以实现降维并可视化的功能,利用解码器就可以从低维空间中选取一些点然后生成最后的数据。数据量再多也是不可能占满低维空间的所有位置的,因为空间中的点的个数是无限的,所以这就给扩充数据集提供了可能性。

以上均为自己理解,难免会有偏差,欢迎评论、指正!

PyTorch读取自己的本地图片数据集训练自编码器相关推荐

  1. pytorch加载自己的图片数据集的两种方法

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 接下来我们就可以构建我们的网络架构: 训练我们的网络: 保存网络模型(这 ...

  2. pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码)

    一,手提包数据集 数据集下载:用pytorch写FCN进行手提包的语义分割. training data(https://github.com/yunlongdong/FCN-pytorch-easi ...

  3. 【Pytorch实战4】基于CIFAR10数据集训练一个分类器

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch中文文档 先是数据的导入与预览. import torch import torchvisio ...

  4. TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练、评估(偶尔100%准确度,交叉熵验证)

    TF之LSTM:利用LSTM算法对mnist手写数字图片数据集(TF函数自带)训练.评估(偶尔100%准确度,交叉熵验证) 目录 输出结果 设计思路 代码设计 输出结果 第 0 accuracy 0. ...

  5. 【PyTorch】构造VGG19网络进行本地图片分类(超详细过程)——项目介绍

    本篇博客主要解决以下3个问题: 如何自定义网络(以VGG19为例). 如何自建数据集并加载至模型中. 如何使用自定义数据训练自定义模型. 第一篇:[PyTorch]构造VGG19网络进行本地图片分类( ...

  6. pytorch保存准确率_初学Pytorch:MNIST数据集训练详解

    前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...

  7. cifar10数据集测试有多少张图_pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)...

    首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层: 一,写VGG代码时,首先定义一个 vgg_block(n ...

  8. PyTorch 学习笔记(一):让PyTorch读取你的数据集

    本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial 文章目录 Dataset类 ...

  9. Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)

    Pytorch采用AlexNet实现猫狗数据集分类(训练与预测) 介绍 AlexNet网络模型 猫狗数据集 AlexNet网络训练 训练全代码 预测 预测图片 介绍 AlexNet模型是CNN网络中经 ...

最新文章

  1. mysql 写入400_MySQL5.7运行CPU达百分之400处理方案
  2. SAP PM 初级系列18 - 为维修工单分配Permit
  3. 一个球从100m高度自由落下,第10次反弹多高
  4. Vue 组件库 HeyUI@1.17.0 发布,新增 Skeleton 组件
  5. Objective C 基础教程
  6. 浅谈c/c++typedef和#define区别[转]
  7. rtmp测试地址_超详细搭建多码率测试环境(成为流媒体高手必经之路)
  8. mysql windows身份验证_SQL Server 2005 怎么就不能用Windows身份验证方式登录呢?
  9. 笨办法学python3_软件测试需要学什么(个人软件测试学习路线)
  10. 模板 - 计算几何(合集)
  11. photoshop 新建文档尺寸预设如何导出保存
  12. 唐宇迪学习笔记2:Python数据分析处理库——pandas
  13. matlab瑞利衰落信道仿真
  14. 电工学的MATLAB实践,基于Matlab/Simulink的电工学电路仿真
  15. 华为认证的考试费用和重认证
  16. android 集成腾讯定位,Android集成腾讯云通信IM
  17. 【面试篇】诚迈科技(外包)
  18. scratch做简单跑酷游戏_Scratch(七)篇外.用小动画和触碰能做大型游戏?
  19. 去携程实习了!半年时间,从机械转行 Java,二哥的读者真牛逼!
  20. 解决redis连接错误:MISCONF Redis is configured to save RDB snapshots, but it is currently not able to...

热门文章

  1. Bingo NFT 如何帮助交易者和投资者分析市场
  2. Pandas基础入门知识点总结
  3. [Python]This probably means that Tcl wasn‘t installed properly.(Windows 10)(pyinstaller库)
  4. Autocad2015点开闪退问题,线段等分
  5. 启动Intel TV-x设置
  6. Elasticsearch集群扩容踩坑记录
  7. 多元回归分析的心得(笔记)
  8. 力荐 50 个最实用的免费机器学习数据集
  9. 被一些数字整除的数字的特征
  10. 基于JavaScript网上商城开发设计 毕业设计-附源码261620