Pytorch学习笔记7——自定义数据集


1.读取数据

首先继承自torch.utils.data.Dataset
重写len与getitem

train就用train数据集,test就用test数据集。




自定义数据集的读取

import torch
import os,glob
import random,csvimport torchvision.datasets
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationModeclass Pokemon(Dataset):def __init__(self,root,resize,mode):super(Pokemon, self).__init__()self.root=rootself.resize=resizeself.name2label={}#字典表达映射关系label‘‘sq..’’:0for name in sorted(os.listdir(os.path.join(root))):#遍历根目录下所有文件假if not os.path.isdir(os.path.join(root,name)):#判断是否是文件夹continueself.name2label[name]=len(self.name2label.keys())print(self.name2label)self.images,self.labels=self.load_csv('images.csv')#得到的是images的路径,和对应的数字标签if mode=='train':self.images=self.images[:int(0.6*len(self.images))]self.labels=self.labels[:int(0.6*len(self.labels))]elif mode=='val':#20self.images=self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]self.labels=self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]else:self.images=self.images[int(0.8*len(self.images)):]self.labels=self.labels[int(0.8*len(self.labels)):]#创建数据对:path+labeldef load_csv(self,filename):if not os.path.exists(os.path.join(self.root,filename)):#如果已经有了,不需要再创建images=[]for name in self.name2label.keys():#key:valueimages+=glob.glob(os.path.join(self.root,name,'*.png'))#glob方法获取目录下所有满足的文件images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))#1165,pokeman/bulbasaur/00001.png#对应关系保存到csvrandom.shuffle(images)with open(os.path.join(self.root,filename),mode='w',newline='') as f:writer=csv.writer(f)for img in images:#pokeman/bulbasaur/00001.pngname=img.split(os.sep)[-2]label=self.name2label[name]#字典根据key找value存入labelwriter.writerow([img,label])#pokeman/bulbasaur/00001.png,0print('writen into csvfile:',filename)#read from csvimages,labels=[],[]with open(os.path.join(self.root,filename)) as f:reader=csv.reader(f)for row in reader:img,label=rowlabel=int(label)images.append(img)labels.append(label)assert len(images)==len(labels)return images,labelsdef __len__(self):return len(self.images)def denormalize(self,x_hat):#逆归一化已回复图片视觉效果mean = [0.845, 0.456, 0.406]std = [0.229, 0.224, 0.225]#x_hat=(x-mean)/std#x=x_hat*std+mean#x:[c,h,w]#mean:[3]=>[3,1,1]mean=torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)print(mean.shape,std.shape)x=x_hat*std+meanreturn xdef __getitem__(self, idx):#self.images,self.labels#idx-[0-len(images)]img,label=self.images[idx],self.labels[idx]#从csv获得的图片路径与labeltf=transforms.Compose([lambda x:Image.open(x).convert('RGB'),#string path=>image datatransforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),transforms.RandomRotation(15),transforms.CenterCrop(self.resize),transforms.ToTensor(),transforms.Normalize(mean=[0.845,0.456,0.406],std=[0.229,0.224,0.225])])img=tf(img)label=torch.tensor(label)return img,label
if __name__=='__main__':import visdomimport timeviz=visdom.Visdom()# tf = transforms.Compose([#     transforms.Resize((64,64)),#     transforms.ToTensor(),# ])# db=torchvision.datasets.ImageFolder(root='/home/lizheng/Study/yolov5-5.0/pytorch1/pokemon/pokeman',transform=tf)# loader=DataLoader(db,batch_size=32,shuffle=True)#一行代码完成数据集加载工作# print(db.class_to_idx)# for x,y in loader:#     viz.images(x,nrow=8,win='batch',opts=dict(title='batch'))#     viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))##     time.sleep(10)db=Pokemon('pokemon/pokeman',64,'train')x,y=next(iter(db))#利用迭代器输入路径获得具体图像,得到第一个样本,调用时自动使用getitem函数,此时x是图像print('sample:',x.shape,y.shape,y)viz.image(db.denormalize(x),win='sample_x',opts=dict(title='sample_x'))loader=DataLoader(db,batch_size=32,shuffle=True,num_workers=8)#不想一个一个取,想一个batch一个batch取for x,y in loader:viz.images(db.denormalize(x),nrow=8,win='batch',opts=dict(title='batch'))viz.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))time.sleep(10)

自定义神经网络模型的搭建

import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):def __init__(self,ch_in,ch_out,stride=1):''':param ch_in::param ch_out:'''super(ResBlk,self).__init__()#super方法避免父类的init函数被替换self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)self.bn1=nn.BatchNorm2d(ch_out)self.conv2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)self.bn2=nn.BatchNorm2d(ch_out)self.extra=nn.Sequential()if ch_out!=ch_in:self.extra=nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),#Sequential里面加入的都是类,因此可以自己写,然后加入nn.BatchNorm2d(ch_out)#这些类在调用时会自动调用forward函数,记得要写return)def forward(self,x):''':param x:[b,ch,h,w]:return:'''out=F.relu(self.bn1(self.conv1(x)))out=self.bn2(self.conv2(out))#short cut#extra module:[b,ch_in,h,w] with [b,ch_out,h,w]#element-wise add:\out=self.extra(x)+outreturn outclass ResNet18(nn.Module):def __init__(self,num_class):super(ResNet18,self).__init__()self.conv1=nn.Sequential(nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),nn.BatchNorm2d(16))#followed 4 blocks#[b,16,h,w]=>[b,32,h,w]self.blk1=ResBlk(16,32,stride=3)#增多通道,减少长宽,避免数据量过大#[b,32,h,w]=>[b,64,h,w]self.blk2 = ResBlk(32, 64,stride=3)#[b,64,h,w]=>[b,128,h,w]self.blk3 = ResBlk(64,128,stride=2)# [b,128,h,w]=>[b,256,h,w]self.blk4 = ResBlk(128,256,stride=2)#[b,256,7,7]self.outlayer=nn.Linear(256*3*3,num_class)#输入512通道,输出10通道def forward(self,x):''':param x::return:'''x=F.relu(self.conv1(x))#[b,64,h,w]=>[b,1024,h,w]x=self.blk1(x)x=self.blk2(x)x=self.blk3(x)x=self.blk4(x)# print('after conv:',x.shape)#[b,512,2,2]#    # [b,512,h,w]=>[b,512,2,2]#    x=F.adaptive_avg_pool2d(x,[1,1])# #   print('after pool:',x.shape)x=x.view(x.size(0),-1)x=self.outlayer(x)return xif __name__=='__main__':blk=ResBlk(64,128)tmp=torch.randn(2,64,224,224)out=blk(tmp)print('block',out.shape)model=ResNet18(5)#5分类tmp=torch.randn(2,3,224,224)out=model(tmp)print('resnet:',out.shape)p=sum(map(lambda p:p.numel(),model.parameters()))print('parameters size:',p)

自定义数据集的训练与测试:

import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoaderfrom pytorch1.pt3 import Pokemon
from resnet import ResNet18batchsz=32
lr=1e-3
epochs=10device=torch.device('cuda')
torch.manual_seed(1234)train_db=Pokemon('pokemon/pokeman',224,mode='train')#initial函数初始化训练集,
val_db=Pokemon('pokemon/pokeman',224,mode='val')
test_db=Pokemon('pokemon/pokeman',224,mode='test')train_loader=DataLoader(train_db,batch_size=batchsz,shuffle=True,num_workers=4)#loader里获得的都是真正是图片
val_loader=DataLoader(val_db,batch_size=batchsz,num_workers=4)
test_loader=DataLoader(test_db,batch_size=batchsz,num_workers=4)viz = visdom.Visdom()def evaluate(model,loader):correct=0total=len(loader.dataset)for x,y in loader:x,y=x.to(device),y.to(device)with torch.no_grad():logits=model(x)pred=logits.argmax(dim=1)correct+=torch.eq(pred,y).sum().float().item()return correct/totalif __name__=='__main__':model=ResNet18(5).to(device)optimizer=optim.Adam(model.parameters(),lr=lr)criteon=nn.CrossEntropyLoss()best_acc,best_epoch=0,0global_step=0viz.line([0],[-1],win='loss',opts=dict(title='loss'))viz.line([0],[-1],win='val_acc',opts=dict(title='val_acc'))#清空操作for epoch in range(epochs):for step,(x,y) in enumerate(train_loader):#x:[b,3,224,224],y:[b]x,y=x.to(device),y.to(device)logits=model(x)loss=criteon(logits,y)optimizer.zero_grad()loss.backward()optimizer.step()viz.line([loss.item()], [global_step], win='loss', update='append')global_step+=1if epoch%2==0:val_acc=evaluate(model,val_loader)if val_acc>best_acc:best_epoch=epochbest_acc=val_acctorch.save(model.state_dict(),'best.mdl')viz.line([val_acc], [global_step], win='val_acc',update='append')print('best acc:',best_acc,'best_epoch:',best_epoch)model.load_state_dict(torch.load('best.mdl'))print('loaded from ckpt!')test_acc=evaluate(model,test_loader)print('test acc:',test_acc)

实验效果:

Pytorch学习笔记7——自定义数据集相关推荐

  1. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  2. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  3. PyTorch学习笔记(六):PyTorch进阶训练技巧

    PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...

  4. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  5. PyTorch学习笔记(三):PyTorch主要组成模块

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  6. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  7. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

  8. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

  9. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

最新文章

  1. 目标检测--Training Region-based Object Detectors with Online Hard Example Mining
  2. 分分钟玩转多进程编程
  3. BufferedInputStream的read方法原理
  4. Powershell指令集_1
  5. 产品观,来自微信张小龙的
  6. 写出float x 与“零值”比较的if语句
  7. apache win下安装
  8. 力扣(LeetCode)56
  9. android中menu菜单扩增_创意菜单效果
  10. java将ppt转换成图片,图片以幻灯片的备注命名
  11. word2016毕业论文不同章节设置不同页眉方法
  12. 2023 新版 微信公众号无限回调系统源码
  13. 英语计算机自我介绍范文面试,计算机面试英文自我介绍范例
  14. 解决Ubuntu18.04 / 16.04和Win10双系统系统时间时间不同步
  15. Wordvice推出人工智能工具 免费论文润色功能受欢迎
  16. No adapter attached; skipping layout 原因、解决办法
  17. SM2 (含SM3、SM4)国密算法工具QT版,彻底搞懂sm2算法的使用
  18. LoadWebOffice实现在线编辑Word
  19. Autofac 批量注入
  20. 瑞萨电子Rcar-H3的qnx系统开发

热门文章

  1. 在C#中使用WIA获取扫描仪数据
  2. 文本内容相似度计算方法:simhash
  3. Swift表格Lxr
  4. 三款超火的国外壁纸应用,让你每天都用新手机
  5. flutter 如何判断页面渲染完毕
  6. 获取超级用户访问权限-redhat 7.2
  7. photoshop---压缩图片大小/给人物换衣服
  8. ps 图片从竖屏拉伸成横屏
  9. 论文阅读:Regularizing Deep Networks with Semantic Data Augmentation
  10. 淘宝最基础的优化:标题优化