目录

  • Pokemon Dataset
  • 数据集加载
    • 自定义数据集
    • 数据预处理
    • 图像数据存储结构
    • 代码
  • 构建模型
  • 训练模型
  • 迁移学习

收集、读取、预处理数据,模型搭建、训练。

Pokemon Dataset


数据集加载

自定义数据集

__len__()函数返回数据集的数量,限制数据集迭代次数;
__getitem__索引样本;

import torch
from  torch.utils.data import Dataset class NumberDataset(Dataset):def __init__(self, training = True) -> None:super().__init__()self.training = trainingif training:self.samples = list(range(1, 1001))else:self.samples = list(range(1001, 1501))def __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]if __name__ == '__main__':data = NumberDataset(True)print(len(data))print(data[10])输出:
1000
11

数据预处理

  1. resize
  2. 数据增强
    增加数据集规模,辅助性提升一部分性能;
  3. 归一化
    将数据分布缩放为一个指定均值和方差的正态分布;
  4. 转换为Tensor
    将其它数据类型转换为pytorch的Tensor

图像数据存储结构

推荐采用一个label文件夹存储该label的图像;pytorch易于管理,它提供了一个API可以直接读取出这种存储结构的数据,而不用我们人为去写一个读取这些数据的代码;

代码

import torch
import os, csv
import random, glob
from  torch.utils.data import Dataset
import visdom,time
from torchvision import transforms
from PIL import Imageclass Pokemon(Dataset):def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = dict()  # 将string转换为labelfor name in sorted(os.listdir(root)):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys())# image, label  将图像数据和label一一对应self.images, self.labels = self.load_csv('img_label.csv')if mode == 'train': # 60%self.images = self.images[:int(0.6*len(self.images))]  #取数据集的前60%作为训练集self.labels = self.labels[:int(0.6*len(self.labels))]elif mode == 'val': # 20% : 60%->80%self.images = self.images[int(0.6*len(self.images)): int(0.8*len(self.images))]  #取数据集的前60%-80%作为验证集self.labels = self.labels[int(0.6*len(self.labels)): int(0.8*len(self.labels))]else:  #20% : 80%->100%self.images = self.images[int(0.8*len(self.images)):]  #取数据集的最后20%作为测试集self.labels = self.labels[int(0.8*len(self.labels)):]def load_csv(self, filename):if not os.path.exists(os.path.join(self.root, filename)):images = []#将每一类图像的path提取出来存入imagefor name in self.name2label.keys():images += glob.glob(os.path.join(self.root, name, '*.png'))  #Return a list of paths matching a pathname pattern.images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))print(len(images))  #保存image path和label的对应关系,这里保存到csv文件中,节约内存with open(os.path.join(self.root, filename), mode = 'w', newline='') as f:writer = csv.writer(f)for img in images:name = img.split(os.sep)[-2]label = self.name2label[name]writer.writerow([img, label])print('write to csv file:', filename)images = []labels = []#将image path和label的对应关系再重新读取出来with open(os.path.join(self.root, filename), mode='r') as f:reader = csv.reader(f)for row in reader:img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images,labelsdef denormalize(self, x):  #c,h,wmean = [0.485, 0.456, 0.406] # cstd = [0.229, 0.224, 0.225]  # cx = x*(torch.tensor(std).unsqueeze(1).unsqueeze(1)) + \torch.tensor(mean).unsqueeze(1).unsqueeze(1)return xdef __len__(self):return len(self.images)def __getitem__(self, index):# indx: [0~len(slef.images)]img_path = self.images[index]label = self.labels[index]tf = transforms.Compose([lambda x: Image.open(img_path).convert('RGB'), # string path => imagetransforms.Resize((int(self.resize*1.5), int(self.resize*1.5))),transforms.RandomRotation(15), #如果旋转角度太大,可能会导致网络不收敛transforms.CenterCrop(self.resize), #裁剪成一个指定大小的形状transforms.ToTensor(),transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])])img = tf(img_path)label = torch.tensor(label)return img, labeldef main():vis = visdom.Visdom()data = Pokemon('G:\BaiduNetdiskDownload\pokemon\pokeman', 224, 'train')x,y = next(iter(data))print(x.shape)print(y.shape)vis.image(data.denormalize(x), win = 'sample_x', opts=dict(title='sample_x'))
if __name__ == '__main__':main()输出:
Setting up a new session...
torch.Size([3, 224, 224])
torch.Size([])

visdom:http://localhost:8097/

上述代码,可以通过torchvision.datasets.ImageFolder实现数据的读取和封装;
这种方式不是适合所有情况,只适合数据非常规整的存储了,并且如果对数据有一些额外的操作,还是要自己定义数据类。

import torchvisiontf = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor(),])db = torchvision.datasets.ImageFolder(root='G:\BaiduNetdiskDownload\pokemon\pokeman', transform=tf)print(len(db))x,y = next(iter(db))print(x.shape)print(y)print(db.class_to_idx)vis.image(x, win = 'sample_x', opts=dict(title='sample_x'))
输出:
1167
torch.Size([3, 64, 64])
0
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

visdom:http://localhost:8097/

构建模型

ResNet18

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ResBlk(nn.Module):def __init__(self, in_channels, out_channels, stride):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride= stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.extra = nn.Sequential()if in_channels != out_channels:self.extra = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),nn.BatchNorm2d(out_channels))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out = F.relu(out + self.extra(x))return outclass ResNet(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=0),nn.BatchNorm2d(16))# fllow 4 blocksself.block1 = ResBlk(16, 32 ,2)self.block2 = ResBlk(32, 64, 2)self.block3 = ResBlk(64, 128, 2)self.block4 = ResBlk(128, 256, 2)self.pool = nn.AdaptiveAvgPool2d((1,1))self.outlayer = nn.Linear(256, 5)def forward(self, x):x = self.conv1(x)x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.pool(x).flatten(1)logits = self.outlayer(x)return logits
if __name__ == '__main__':x = torch.rand(3,3, 224, 224)model = ResNet()out = model(x)p = sum(map(lambda p:p.numel(), model.parameters())) # torch.numel()函数,查看一个张量有多少元素print(out.shape)print('parameters size:', p)输出:
torch.Size([3, 5])
parameters size: 1224645

训练模型

import torch
import torch.nn as nn
import torchvision
import visdom
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet import ResNetbatch_size = 32
lr = 1e-3
epoches = 10
device = torch.device('cpu')
torch.manual_seed(1234)
root = 'G:\BaiduNetdiskDownload\pokemon\pokeman'def evaluate(model, loader):correct = 0.for x,y in loader:x, y  =  x.to(device), y.to(device)with torch.no_grad():logits = model(x)preds = logits.argmax(dim = 1)correct += preds.eq(y).sum().float().item() print('total correct:', correct)acc = correct / len(loader.dataset)return accdef train():vis = visdom.Visdom()train_db = Pokemon(root, 224, 'train')val_db = Pokemon(root, 224, 'val')test_db = Pokemon(root, 224, 'test')print('train:', len(train_db))print('val:', len(val_db))print('test:', len(test_db))train_loader = DataLoader(train_db, batch_size=batch_size, shuffle = True, num_workers=4)val_loader = DataLoader(val_db, batch_size=batch_size, num_workers=2)test_loader = DataLoader(test_db, batch_size=batch_size, num_workers=2)model = ResNet().to(device)criteon = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr = lr)best_acc = 0.best_epoch = 0global_step = 0for epoch in range(epoches):model.train()for step, (x,y) in enumerate(train_loader):x,y = x.to(device), y.to(device)# print(y)logits = model(x)loss = criteon(logits, y)optimizer.zero_grad()loss.backward()optimizer.step()global_step += 1if step %2 == 0:print('epoch:[{}/{}]\tloss:{}'.format(step, epoch, loss.item()))vis.line([loss.item()], [global_step], win='loss', update='append')# evaluatemodel.eval()val_acc = evaluate(model, val_loader)print('epoch:[{}]\t accuracy:{}'.format(epoch, val_acc))vis.line([val_acc], [global_step], win='val_acc', update='append')if val_acc > best_acc:torch.save(model.state_dict(), 'best_model.pth')best_acc = val_accbest_epoch = epochprint('save model....')# testbest_model = ResNet()best_model.load_state_dict(torch.load('best_model.pth'))best_acc = evaluate(best_model, test_loader)print('best acc:', best_acc, 'best epoch:', best_epoch)if __name__ == '__main__':train()输出:
...
epoch:[10/1]    loss:0.32095620036125183
epoch:[12/1]    loss:0.8680893182754517
epoch:[14/1]    loss:0.5944045782089233
epoch:[16/1]    loss:0.8467034101486206
epoch:[18/1]    loss:0.442536860704422
epoch:[20/1]    loss:0.5616939663887024
total correct: 192.0
epoch:[1]        accuracy:0.8240343347639485
save model....
epoch:[0/2]     loss:0.5097338557243347
...

训练过程的visdom:
loss曲线

acc曲线;

迁移学习

图像数据集与源域数据集存在比较多的重合的话(或者分布相似)比如ImageNet,那么可以使用源域数据集训练好的模型来辅助现在的特定任务,即将在A任务上训练好一个分类器,然后transfer到B任务上去;在B任务上叫微调,finetuning;

from torchvision.models import resnet18
class Flatten(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x):return x.flatten(1)trained_model = resnet18(pretrained=True).to(device)
model = nn.Sequential(*list(trained_model.children())[:-1],  #取resnet18前17层, 该层输出为[b,512,1,1]Flatten(),nn.Linear(512,5)
)输出:
...
epoch:[0/0]     loss:1.823586344718933
epoch:[2/0]     loss:0.30410656332969666
epoch:[4/0]     loss:0.7876781821250916
epoch:[6/0]     loss:0.8662126660346985
epoch:[8/0]     loss:0.5194013714790344
epoch:[10/0]    loss:0.390007346868515
...

部分visdom可视化:
loss曲线

acc曲线

pytorch基础(九)- 自定义数据集训练模型 和 迁移学习相关推荐

  1. 【Pytorch神经网络实战案例】24 基于迁移学习识别多种鸟类(CUB-200数据集)

    1 迁移学习 在实际开发中,常会使用迁移学习将预训练模型中的特征提取能力转移到自己的模型中. 1.1 迁移学习定义 迁移学习指将在一个任务上训练完成的模型进行简单的修改,再用另一个任务的数据继续训练, ...

  2. pytorch卷积神经网络_资源|卷积神经网络迁移学习pytorch实战推荐

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 一.资源简介 这次给大家推荐一篇关于卷积神经网络迁移学习的实战资料,卷积神经网络迁移学 ...

  3. PyTorch 1.0 中文官方教程:迁移学习教程

    译者:片刻 作者: Sasank Chilamkurthy 在本教程中,您将学习如何使用迁移学习来训练您的网络.您可以在 cs231n 笔记 上关于迁移学习的信息 引用这些笔记: 在实践中,很少有人从 ...

  4. 【yolo】yolov3的pytorch版本保存自定义数据集训练好的权重,并载入自己的模型

    多次试验终于测出来了!!很高兴,结果截图: 数据集是来自网上的,代码原型是github一个大概五千多star的pytorch-yolov3,但原代码并没有载入自己的模型进行训测试阶段,然后parser ...

  5. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  6. 【Pytorch实战6】一个完整的分类案例:迁移学习分类蚂蚁和蜜蜂(Res18,VGG16)

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch官方文档 本文是采用pytorch进行迁移学习的实战演练,实战目的是为了进一步学习和熟悉pyt ...

  7. 使用PyTorch进行迁移学习

    概述 迁移学习可以改变你建立机器学习和深度学习模型的方式 了解如何使用PyTorch进行迁移学习,以及如何将其与使用预训练的模型联系起来 我们将使用真实世界的数据集,并比较使用卷积神经网络(CNNs) ...

  8. 第1周学习笔记:深度学习和pytorch基础

    目录 一 视频学习 1.绪论 2.深度学习概述 二 代码学习 1.Pytorch基础练习 2.螺旋数据分类 一 视频学习 1.绪论 人工智能(Artificial Intelligence):使一部机 ...

  9. 迁移学习篇之如何迁移经典CNN网络-附迁移学习Alexnet,VGG,Googlenet,Resnet详细代码注释和方法-pytorch

    鸽了好久的迁移学习篇学习终于打算更新,这次我们来学习一个机器学习中经典常用的简单快速提高网络指标的trick,迁移学习,迁移学习本身是机器学习中的一个trick,但是近些年在深度学习中应用广泛.之前我 ...

最新文章

  1. FPGA之道(66)代码中的约束信息(三)存储器以及寄存器的相关约束
  2. WPF 控件库——仿制Windows10的进度条
  3. STL源码剖析 第八章 配接器
  4. 如何通过三视图判断立方体个数_装机小白看过来:如何通过显卡参数来判断高端低端?...
  5. Cannot get a connection, pool exhausted, cause: ValidateObject failed
  6. 用EasyRecovery怎么恢复电脑中已删除的视频
  7. python 字符串处理_python 数据清洗之字符串处理
  8. 让SQL2000的查询分析器能够直接编辑SQL2005的视图或存储过程
  9. 战旗html5播放器为什么卡顿,视频站启用html5播放器
  10. JavaScript形而上的For循环中的Break
  11. cacti监控H3C交换机
  12. 11月合资SUV销量:日系车统治榜单 大众产品攻势“拳意渐乱”
  13. 牛牛的猜球游戏(前缀和+逆交换)
  14. 如何把wps随机数据固定_WPS Excel:巧用随机函数rand和randbetween生成各种数据
  15. vscode win10笔记本 蓝屏_联想拯救者Win10蓝屏0xc000000d的解决办法
  16. 企业三层架构、冗余、STP生成树协议总结
  17. IDEA使用archetype创建Maven项目(只有两个archetype)
  18. Windows 10 开机不一会出现MEMORY_MANAGEMENT蓝屏
  19. python截长图_利用 Python + Selenium 实现对页面的指定元素截图(可截长图元素)
  20. 前端js获取系统更新刷新页面

热门文章

  1. css如何使文字抖动,CSS美化:实现抖音彩色文字抖动效果
  2. 【论文笔记】CondConv: Conditionally Parameterized Convolutions for Efficient Inference
  3. 扫雷c语言完整源代码,C语言扫雷源码
  4. Cain 不能显示外接网卡问题解决
  5. 深入理解微信二维码扫码登录的原理
  6. 学习日记5-C语言函数的应用
  7. pandas 报警告:A value is trying to be set on a copy of a slice from a DataFrame
  8. linux硬件命令大全,Linux硬件信息命令大全
  9. 关于ORA-12505, TNS:listener does not currently know of SID given in connect descript的一个解决思路
  10. 【合作伙伴大练兵-安全】NGFW盒式防火墙问题排查和维护