目录

ImageFolder 加载数据集

使用pytorch提供的Dataset类创建自己的数据集。

Dataset加载数据集

接下来我们就可以构建我们的网络架构:

训练我们的网络:

保存网络模型(这里不止是保存参数,还保存了网络结构)


pytorch加载图片数据集有两种方法。

1.ImageFolder 适合于分类数据集,并且每一个类别的图片在同一个文件夹, ImageFolder加载的数据集, 训练数据为文件件下的图片, 训练标签是对应的文件夹, 每个文件夹为一个类别

在Flower_Orig_dataset文件夹下有flower_orig 和 sunflower这两个文件夹, 这两个文件夹下放着同一个类别的图片。 使用 ImageFolder 加载的图片, 就会返回图片信息和对应的label信息, 但是label信息是根据文件夹给出的, 如flower_orig就是标签0, sunflower就是标签1。

ImageFolder 加载数据集

1. 导入包和设置transform

from torchvision.datasets import ImageFolderimport torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoadertransforms = transforms.Compose([transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型
])

2.加载数据集: 将分类图片的父目录作为路径传递给ImageFolder(), 并传入transform。这样就有了要加载的数据集, 之后就可以使用DataLoader加载数据, 并构建网络训练。

path = r'D:\dataset_deep_learning\Flower_Orig_dataset'data_train = datasets.ImageFolder(path, transform=transforms)data_loader = DataLoader(data_train, batch_size=64, shuffle=True)for i, data in enumerate(data_loader):images, labels = data# 打印数据集中的图片img = torchvision.utils.make_grid(images).numpy()plt.imshow(np.transpose(img, (1, 2, 0)))plt.show()break

使用pytorch提供的Dataset类创建自己的数据集。

具体步骤:

1.  首先要有一个txt文件, 这个文件格式是: 图片路径   标签  图片文件夹.  这样的格式, 所以使用os库, 遍历自己的图片名, 并把标签和图片路径写入txt文件。

2. 有了这个txt文件, 我们就可以在类里面构造我们的数据集.

2.1    把图片路径和图片标签分割开, 有三个列表, 一个列表是图片路径名, 一个列表是标签号,一个列表是这类图片的文件夹 。 有一点就是第 i 个图片列表和 第 i 个标签是对应的

3. 重写__len__方法  和  __getitem__方法

3.1 getitem方法中, 获得对应的图片路径,并用PIL库读取文件把图片transfrom后, 在getitem函数中返回读取的图片和标签即可

4.就可以构建数据集实例和加载数据集.

文件结构如图:

定义一个用来生成[ 图片路径 标签  该类图片文件夹名] 这样的txt文件函数(因为用了a追加的方式,所以,flower_orig和sunflower两个文件夹下的都被写进data.txt文件了)

#打开存放图片的文件夹,然后遍历文件名,把文件名字, label 还有 文件夹名写入data.txt文件中。import osdef make_txt(root, file_name, label):path = os.path.join(root, file_name)  data = os.listdir(path)f = open(root + '\\' + 'data.txt', 'a')for line in data:f.write(line + ' ' + str(label) + ' ' + file_name + '\n')f.close()path = r'D:\dataset_deep_learning\Flower_Orig_dataset'# 调用函数生成两个文件夹下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

现在看看查看data.txt文件的格式如图:(由图中三部分组成)

现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息

Dataset加载数据集

#我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径
import osimport numpy as np
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transformstransforms = transforms.Compose([transforms.Resize(256),    # 将图片短边缩放至256,长宽比保持不变:transforms.CenterCrop(224),   #将图片从中心切剪成3*224*224大小的图片transforms.ToTensor()          #把图片进行归一化,并把数据转换成Tensor类型
])class MyDataset(Dataset):def __init__(self, img_path, transform=None):super(MyDataset, self).__init__()self.root = img_pathself.txt_root = self.root + '\\' + 'data.txt'f = open(self.txt_root, 'r')data = f.readlines()imgs = []labels = []for line in data:line = line.rstrip()word = line.split()#print(word[0], word[1], word[2])   #word[0]是图片名字.jpg  word[1]是label  word[2]是文件夹名,如sunflowerimgs.append(os.path.join(self.root,word[2], word[0]))labels.append(word[1])self.img = imgsself.label = labelsself.transform = transformdef __len__(self):return len(self.label)def __getitem__(self, item):img = self.img[item]label = self.label[item]img = Image.open(img).convert('RGB')# 此时img是PIL.Image类型   label是str类型if self.transform is not None:img = self.transform(img)label = np.array(label).astype(np.int64)label = torch.from_numpy(label)return img, label

加载我们的数据集并查看我们加载到图片:

path = r'D:\数据集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)for i, data in enumerate(data_loader):images, labels = data# 打印数据集中的图片img = torchvision.utils.make_grid(images).numpy()plt.imshow(np.transpose(img, (1, 2, 0)))plt.show()break

接下来我们就可以构建我们的网络架构:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3,16,3)self.maxpool = nn.MaxPool2d(2,2)self.conv2 = nn.Conv2d(16,5,3)self.relu = nn.ReLU()self.fc1 = nn.Linear(55*55*5, 1200)self.fc2 = nn.Linear(1200,64)self.fc3 = nn.Linear(64,2)def forward(self,x):x = self.maxpool(self.relu(self.conv1(x)))    #113x = self.maxpool(self.relu(self.conv2(x)))    #55x = x.view(-1, self.num_flat_features(x))x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_features

训练我们的网络:

model = Net()criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)epochs = 10
for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(data_loader):images, label = dataout = model(images)loss = criterion(out, label)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()if(i+1)%10 == 0:print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))running_loss = 0.0print('finished train')

保存网络模型(这里不止是保存参数,还保存了网络结构)

#保存模型
torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b权重值# 读取模型
model = torch.load('model_name.pth')

pytorch加载自己的图片数据集的两种方法相关推荐

  1. 【PyQt】pyqt加载调用ui界面文件的两种方法

    使用PyQt开发界面软件,自然会用到Qt Designer进行界面设计,拖拖按钮.设置菜单什么的,然后保存为 .ui 文件.但是在 Python 代码里面如何使用这些 .ui 文件呢? 有两种方法: ...

  2. qt 加载 图片旋转_QT 实现图片旋转的两种方法

    第一种方案 使用 QPixmap 的 transformed 函数来实现旋转,这个函数默认是以图片中心为旋转点,不能设置旋转的中心点,使用如下: QMatrix matrix; matrix.rota ...

  3. ie加载项存在残留是什么_Win7系统遇到IE加载项故障的原因及两种解决办法

    在我们使用的系统中,都是有自带的IE浏览器,当然在我们使用的Win7系统中也不例外,可是在使用浏览器的过程中,也会出现各种各样的情况,在出现问题的时候就要看你怎样进行解决了.而最近就有用户反映,在IE ...

  4. D3D中2D图片的绘制两种方法

    2014/09/19 (转载自:http://blog.csdn.net/rabbit729/article/details/6388703) 想要在D3D中加载2D图片可以使用如下两种方法(我只想到 ...

  5. android 图片叠加xml,Android实现图片叠加效果的两种方法

    本文实例讲述了Android实现图片叠加效果的两种方法.,具体如下: 效果图: 第一种: 第二种: 第一种是通过canvas画出来的效果: public void first(View v) { // ...

  6. android 画布叠加,Android实现图片叠加效果的两种方法

    本文实例讲述了Android实现图片叠加效果的两种方法.分享给大家供大家参考,具体如下: 效果图: 第一种: 第二种: 第一种是通过canvas画出来的效果: public void first(Vi ...

  7. java 图片压缩100k_Java 实现图片压缩的两种方法

    问题背景. 典型的情景:Nemo社区中,用户上传的图片免不了要在某处给用户做展示. 如用户上传的头像,那么其他用户在浏览该用户信息的时候,就会需要回显头像信息了. 用户上传的原图可能由于清晰度较高而体 ...

  8. 目标检测(3)—— 如何使用PyTorch加载COCO类型的数据集

    一.如何使用PyTorch加载COCO数据集 打开pytorch的官网 可以看到COCO数据集不提供下载 回顾json文件里面都有什么:"annotations"里面有" ...

  9. android获取位图字节数,Android中获取图片尺寸大小两种方法

    两种方法  建议用第二种 private void getPictureSize(String path) { /*第一种直接把bitmap加载到内存中,通过对bitmap的测量, 得出宽高,由于这个 ...

最新文章

  1. kong api gateway 初体验
  2. 熟悉交换机与路由器组网(图解)
  3. Oracle编程入门经典 第1章 了解Oracle
  4. 简述python程序执行原理_Python程序的执行原理(1)
  5. 中海达数据怎么转rinex_cors账号网最新实战教程,中海达 F61 Plus RTK连接千寻cors账号的方法...
  6. join left semi_Hive的left join、left outer join和left semi join三者的区别
  7. 【机器人操作系统】ROS话题编程
  8. gradle 不支持多级子模块_Apache NetBeans 11.0 正式发布 支持Java 12
  9. python中re怎么念_Python,Re模块的学习
  10. ENVI辐射校正(辐射定标+大气校正)
  11. 按键精灵python插件_按键精灵必须掌握的命令之插件命令
  12. java实现将汉字转为拼音并包含音调
  13. python母亲节代码_python 计算 母亲节
  14. 房地产微信营销方案微信“危”与“机”
  15. 【java】-XX:-OmitStackTraceInFastThrow只有空指针,没有堆栈信息
  16. stm32实现毫秒ms微秒us级延时
  17. 网络层——IP数据报详解
  18. 直流无刷电机【一】从零开始上手
  19. 笔记本onenote绘画快捷键_onenote快捷键
  20. Syclover战队专访 | 年度终局之战,键指圣诞狂欢

热门文章

  1. 局域网lan_什么是局域网(LAN)?
  2. 1.[QT | QCharts | 动态显示]折线图标题字体大小无法更改
  3. 4.加载FeatureLayer
  4. SPSS多元线性回归
  5. Learning Shape Representations for Clothing Variations in Person Re-Identification
  6. OpenCV计算机视觉(三) —— 图像的几何变换
  7. 携程java面经 一二HR面面经
  8. 如何给笔记本安装固态硬盘
  9. We Dont Kown ....
  10. 快速剪辑视频,每个视频按秒数快速分割,并保留原声