Pytorch学习笔记7——自定义数据集
Pytorch学习笔记7——自定义数据集
1.读取数据
首先继承自torch.utils.data.Dataset
重写 `__len__` 与 `__getitem__` 两个方法:前者返回样本数量,后者根据索引返回一个(图像, 标签)样本。
训练时使用 train 划分,测试时使用 test 划分:通过 mode 参数选择,'train' 取前 60%,'val' 取中间 20%,其余取值(如 'test')取最后 20%。
自定义数据集的读取
import csv
import glob
import os
import random

import torch
import torchvision.datasets
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode


class Pokemon(Dataset):
    """Image-classification dataset read from a folder-per-class directory.

    Every sub-directory of ``root`` is one class. The (path, label) pairs are
    shuffled once and cached in ``<root>/images.csv`` so that the
    train/val/test split stays the same across runs.
    """

    # Per-channel ImageNet statistics, shared by Normalize() and denormalize().
    # The first mean used to read 0.845, a digit transposition of 0.485.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    def __init__(self, root, resize, mode):
        """
        :param root: directory containing one sub-directory per class
        :param resize: side length (pixels) of the square output image
        :param mode: 'train' (first 60%), 'val' (next 20%); any other value
            selects the final 20% (test)
        """
        super().__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Class name -> integer label, numbered in sorted directory order.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        # Image paths and their numeric labels, in the cached shuffled order.
        images, labels = self.load_csv('images.csv')
        total = len(images)
        train_end, val_end = int(0.6 * total), int(0.8 * total)
        if mode == 'train':
            window = slice(None, train_end)
        elif mode == 'val':
            window = slice(train_end, val_end)
        else:
            window = slice(val_end, None)
        self.images = images[window]
        self.labels = labels[window]

        self.transform = self._build_transform()

    def _build_transform(self):
        """Build the PIL-image -> normalised-tensor pipeline once per dataset."""
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            # Random augmentation is restricted to training so that val/test
            # accuracy is deterministic.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``<root>/<filename>``.

        The CSV is created first if it does not exist: every png/jpg/jpeg file
        in each class directory is collected, shuffled and written as
        ``path,label`` rows.

        :param filename: CSV file name, relative to ``self.root``
        :return: list of image paths and the parallel list of int labels
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                class_dir = os.path.join(self.root, name)
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(class_dir, pattern))

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name; dirname/
                    # basename handle mixed path separators, unlike split().
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csvfile:', filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines, e.g. from a hand-edited file.
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo ``Normalize`` so that an image can be displayed.

        Computes ``x = x_hat * std + mean``. The statistics are reshaped to
        ``[3, 1, 1]`` so they broadcast over ``[3, h, w]`` and
        ``[b, 3, h, w]`` inputs, and are created on the input's device and
        dtype.

        :param x_hat: normalised image tensor
        :return: tensor of the same shape with the normalisation reversed
        """
        mean = torch.tensor(self.MEAN, dtype=x_hat.dtype, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self.STD, dtype=x_hat.dtype, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Load sample ``idx``.

        :param idx: index into the selected split
        :return: ``(image, label)`` -- the transformed image tensor and the
            label as a 0-dim tensor
        """
        path, label = self.images[idx], self.labels[idx]
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        return self.transform(img), torch.tensor(label)
if __name__ == '__main__':
    import time

    import visdom

    viz = visdom.Visdom()

    # For a plain folder-per-class layout, torchvision.datasets.ImageFolder
    # with a Resize((64, 64)) + ToTensor() transform wrapped in a DataLoader
    # is a one-line alternative to the custom dataset used below.

    db = Pokemon('pokemon/pokeman', 64, 'train')

    # Pull the first sample through __getitem__: an image tensor and its label.
    sample_img, sample_label = next(iter(db))
    print('sample:', sample_img.shape, sample_label.shape, sample_label)
    viz.image(
        db.denormalize(sample_img),
        win='sample_x',
        opts={'title': 'sample_x'},
    )

    # Fetch whole batches instead of single samples.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for batch_imgs, batch_labels in loader:
        viz.images(
            db.denormalize(batch_imgs),
            nrow=8,
            win='batch',
            opts={'title': 'batch'},
        )
        viz.text(
            str(batch_labels.numpy()),
            win='label',
            opts={'title': 'batch-y'},
        )
        # Pause so that each batch stays visible in the visdom window.
        time.sleep(10)
自定义神经网络模型的搭建
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Residual block: two 3x3 conv+BN layers plus a shortcut connection."""

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first convolution (spatial down-sampling)
        """
        super().__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default. A 1x1 projection is required whenever
        # the main path changes the tensor shape -- a different channel count
        # OR a stride other than 1 -- otherwise the addition in forward()
        # fails on mismatched shapes.
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w'], where h', w' follow from ``stride``
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Element-wise add of the (possibly projected) input.
        # NOTE(review): no ReLU follows the addition, unlike the canonical
        # ResNet block -- kept as is so trained weights behave the same.
        out = self.extra(x) + out
        return out


class ResNet18(nn.Module):
    """Small ResNet-style classifier: a stem convolution and four ResBlk stages."""

    # Spatial size the classifier expects; a 224x224 input reaches exactly this
    # size after blk4 (224 -> 74 -> 25 -> 9 -> 5 -> 3).
    _POOLED_HW = (3, 3)

    def __init__(self, num_class):
        """
        :param num_class: number of output classes (logits)
        """
        super().__init__()

        # Stem: [b, 3, h, w] => [b, 16, h/3, w/3]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # Each stage widens the channels and shrinks the spatial size to keep
        # the activation volume manageable.
        self.blk1 = ResBlk(16, 32, stride=3)    # 16 => 32 channels
        self.blk2 = ResBlk(32, 64, stride=3)    # 32 => 64 channels
        self.blk3 = ResBlk(64, 128, stride=2)   # 64 => 128 channels
        self.blk4 = ResBlk(128, 256, stride=2)  # 128 => 256 channels

        # Flattened [b, 256, 3, 3] features => class logits.
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, h, w] image batch
        :return: [b, num_class] logits
        """
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Identity for 224x224 inputs (the feature map is already 3x3); for
        # any other input size it brings the map to the shape outlayer expects.
        x = F.adaptive_avg_pool2d(x, self._POOLED_HW)
        x = torch.flatten(x, 1)
        x = self.outlayer(x)
        return x


if __name__ == '__main__':
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print('block', out.shape)

    model = ResNet18(5)  # 5-way classification
    tmp = torch.randn(2, 3, 224, 224)
    out = model(tmp)
    print('resnet:', out.shape)

    p = sum(map(lambda p: p.numel(), model.parameters()))
    print('parameters size:', p)
自定义数据集的训练与测试:
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader

from pytorch1.pt3 import Pokemon
from resnet import ResNet18

# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 10

# Use the GPU when one is present, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

# The three splits of the same image directory, produced at 224x224.
train_db = Pokemon('pokemon/pokeman', 224, mode='train')
val_db = Pokemon('pokemon/pokeman', 224, mode='val')
test_db = Pokemon('pokemon/pokeman', 224, mode='test')

# The loaders yield batches of image tensors and labels; only the training
# loader shuffles.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=4)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=4)

viz = visdom.Visdom()


def evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode for the duration of the call, so BatchNorm
    layers use their running statistics and do not update them, and is
    returned to its previous mode afterwards.

    :param model: network mapping an image batch to ``[b, num_class]`` logits
    :param loader: DataLoader yielding ``(x, y)`` batches
    :return: fraction of correctly classified samples; 0.0 for an empty dataset
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)

    return correct / total


if __name__ == '__main__':
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    checkpoint_saved = False
    global_step = 0
    # Seed both visdom windows with a dummy point; this clears old curves.
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate every second epoch and keep the best weights on disk.
        if epoch % 2 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                checkpoint_saved = True
            # Plotted at every validation so the curve shows the full history.
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best_epoch:', best_epoch)

    # Only restore a checkpoint written by this run; if validation accuracy
    # never rose above zero, nothing was saved and the final weights are used.
    if checkpoint_saved:
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)
实验效果:
Pytorch学习笔记7——自定义数据集相关推荐
- 深度学习入门之PyTorch学习笔记:卷积神经网络
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...
- Pytorch学习笔记总结
往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...
- PyTorch学习笔记(六):PyTorch进阶训练技巧
PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...
- PyTorch学习笔记(四):PyTorch基础实战
PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...
- PyTorch学习笔记(三):PyTorch主要组成模块
往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...
- PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard
文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...
- PyTorch学习笔记(二)——回归
PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...
- pytorch 学习笔记目录
1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...
- 深度学习入门之PyTorch学习笔记:多层全连接网络
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...
最新文章
- 目标检测--Training Region-based Object Detectors with Online Hard Example Mining
- 分分钟玩转多进程编程
- BufferedInputStream的read方法原理
- Powershell指令集_1
- 产品观,来自微信张小龙的
- 写出float x 与“零值”比较的if语句
- apache win下安装
- 力扣(LeetCode)56
- android中menu菜单扩增_创意菜单效果
- java将ppt转换成图片,图片以幻灯片的备注命名
- word2016毕业论文不同章节设置不同页眉方法
- 2023 新版 微信公众号无限回调系统源码
- 英语计算机自我介绍范文面试,计算机面试英文自我介绍范例
- 解决Ubuntu18.04 / 16.04和Win10双系统系统时间时间不同步
- Wordvice推出人工智能工具 免费论文润色功能受欢迎
- No adapter attached; skipping layout 原因、解决办法
- SM2 (含SM3、SM4)国密算法工具QT版,彻底搞懂sm2算法的使用
- LoadWebOffice实现在线编辑Word
- Autofac 批量注入
- 瑞萨电子Rcar-H3的qnx系统开发