Pytorch深度学习(五):加载数据集以及mini-batch的使用

  • 参考B站课程:《PyTorch深度学习实践》完结合集
  • 传送门:《PyTorch深度学习实践》完结合集

一、预备知识

  • Dataset是一个抽象函数,不能直接实例化,所以我们要创建一个自己类,继承Dataset
    继承Dataset后我们必须实现三个函数:
    init()是初始化函数,之后我们可以提供数据集路径进行数据的加载
    getitem()帮助我们通过索引找到某个样本
    len()帮助我们返回数据集大小
class DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0]self.xdata = torch.from_numpy(xy[:, :-1])self.ydata = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.xdata[index], self.ydata[index]def __len__(self):return self.len
  • DataLoader为数据进行分组,batch_size是一个组中有多少个样本,shuffle表示要不要对样本进行随机排列。一般来说,训练集我们随机排列,测试集不需要。num_workers表示我们可以用多少进程并行的运算,由于我的版本原因(cuda不好使),只能选择num_workers=0,一般可以写num_workers=2,进行并行运算算提高速度。
dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)


示意图中选择了batch_size=2,于是取每两组样本为mini-batch。而在程序中,我们选取了batch_size=32,于是取每32个样本为一个mini-batch,最后mini-batch根据具体的样本总数决定其包含的样本数量。

二、程序实现

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as pltclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0]self.xdata = torch.from_numpy(xy[:, :-1])self.ydata = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.xdata[index], self.ydata[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)xtest = dataset.xdata
ytest = dataset.ydataclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8,6)self.linear2 = torch.nn.Linear(6,4)self.linear3 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)costlist = []
acclist = []
# 若使用的是非linux系统,则以下循环部分需要封装
for epoch in range(10000):l = 0for i, data in enumerate(train_loader, 0):# 1. Prepare datainputs, labels = data# 2. Forwardypred = model(inputs)loss = criterion(ypred, labels)l += loss.item()#print(epoch, i, loss.item())# 3. Backwardoptimizer.zero_grad()loss.backward()# 4. Updata optimizer.step()costlist.append(l / len(inputs))# 每迭代1000次测试一次精确度if epoch % 1000 == 999:ypredtest = model(xtest)ypredlabel = torch.where(ypredtest>0.5, torch.tensor([1]), torch.tensor([0]))acc = torch.eq(ypredlabel, ytest).sum().item() / ytest.size(0)acclist.append(acc)print('the accuracy of testdataset:', acc)plt.figure(figsize=(10,4))
plt.subplot(1, 2, 1)
plt.plot(range(10000), costlist)
plt.title('error of mini-batch')
plt.xlabel('epoch')
plt.ylabel('loss')plt.subplot(1, 2, 2)
plt.plot(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])*1000 , acclist)
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('the accuracy of test dataset')plt.show()
  • 输出结果
the accuracy of testdataset: 0.8234519104084321
the accuracy of testdataset: 0.8313570487483531
the accuracy of testdataset: 0.855072463768116
the accuracy of testdataset: 0.8629776021080369
the accuracy of testdataset: 0.8669301712779973
the accuracy of testdataset: 0.8682476943346509
the accuracy of testdataset: 0.8722002635046113
the accuracy of testdataset: 0.8748353096179183
the accuracy of testdataset: 0.8761528326745718
the accuracy of testdataset: 0.8774703557312253
  • 输出图片


由于我的计算机的性能的限制,且mini-batch的使用对于计算力的耗费更大,所以我们只计算了10000步,更长的步数需要花的时间更多,预计迭代步数突破10w后,会得到一个准确率更高的好结果。

Pytorch深度学习(五):加载数据集以及mini-batch的使用相关推荐

  1. NVIDIA GEFORCE 2080 / 2080 SUPER / 2080 Ti + CUDA Toolkit 8.0 深度学习模型加载速度慢

    NVIDIA GEFORCE 2080 / 2080 SUPER / 2080 Ti + CUDA Toolkit 8.0 深度学习模型加载速度慢 (卡顿) GEFORCE RTX 2080 / GE ...

  2. 【深度学习-数据加载优化-训练速度提升一倍】

    1,介绍 数据加载 深度学习的训练,简单的说就是将数据切分成batch,丢入模型中,并计算loss训练.其中比较重要的一环是数据打batch部分(数据加载部分). 训练时间优化: 深度学习训练往往需要 ...

  3. 如何将深度学习模型加载到android环境中

    承接上一篇的内容,考虑如何将深度学习的模型加载到android app中 文章目录 前言 一.使用工具 二.使用步骤 1.模型格式的转换 2.配置文件修改 3. 应用程序 前言 将图片学习的模型加载到 ...

  4. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  5. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  6. pytorch中的数据加载(dataset基类,以及pytorch自带数据集)

    目录 pytorch中的数据加载 模型中使用数据加载器的目的 数据集类 Dataset基类介绍 数据加载案例 数据加载器类 pytorch自带的数据集 torchvision.datasets MIN ...

  7. 【学习系列7】Pytorch中的数据加载

    目录 1. 模型中使用数据加载器的目的 2. 数据集类 3. 迭代数据集 1. 模型中使用数据加载器的目的 在前面的线性回归横型中,我们使用的数据很少,所以直接把全部数据放到锁型中去使用. 但是在深度 ...

  8. Pytorch加载数据集的方式总结

    Pytorch加载数据集的方式总结 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加 ...

  9. 【PyTorch深度学习项目实战100例目录】项目详解 + 数据集 + 完整源码

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

最新文章

  1. 树形结构在关系数据库中的设计
  2. react引入多个图片_重新引入React:v16之后的每个React更新都已揭开神秘面纱。
  3. 走火入魔 | 暑期电子设计课程学生们的作品
  4. Makefile_04:Makefile变量初了解
  5. MySql数据库查询表信息/列信息(列ID/列名/数据类型/长度/精度/是否可以为null/默认值/是否自增/是否是主键/列描述)...
  6. java int64如何定义_java – 具有两个int属性的自定义类的hashCode是什么?
  7. 仿iReader-按menu键弹出PopupWindow布局界面
  8. new to python什么意思_Python中__new__的作用
  9. 【讨论】对技术的掌握到底应该又多深?
  10. 2020八年级计算机会考计划,初二下学期学习计划2020
  11. 模板题——前缀和与差分
  12. 干货!假新闻检测:观察新闻本身,更要观察它所在的新闻环境
  13. 重构分析21: 被拒绝的遗赠(Refused Bequest)
  14. 市场调研-全球与中国汽车零部件涂层市场现状及未来发展趋势
  15. SDUT-3337 计算长方体、四棱锥的表面积和体积(JAVA*)
  16. 20个2013年最值得关注的网页设计趋势
  17. Aria2远程下载方案部署(CentOS7+Aria+AriaNG+Nginx)
  18. win7系统如何添加打印机服务器,怎样如何添加打印机驱动步骤
  19. jeesite集群和负载均衡配置
  20. java双音频文件分频_分频电路作用,怎么来理解二分频电路?

热门文章

  1. 数字游戏ABCD*E=DCBA-第11届蓝桥杯Scratch选拔赛真题精选
  2. Python3,11行代码解密摩斯电码,真実はいつもひとつ。
  3. 美国诚实签经验贴汇总
  4. 流程图用什么软件做?这篇文章告诉你(内附详细教程)
  5. 心脏滴血(CVE-2014-0160)
  6. Vue的props的三种写法
  7. 阿里云服务器CentOS搭建
  8. IMDB数据集allow_pickle=False问题
  9. YOLOV3 网络结构学习笔记
  10. ARTIX-7 XC7A35T实验项目之 串口发送