Pytorch深度学习(五):加载数据集以及mini-batch的使用
Pytorch深度学习(五):加载数据集以及mini-batch的使用
- 参考B站课程:《PyTorch深度学习实践》完结合集
- 传送门:《PyTorch深度学习实践》完结合集
一、预备知识
- Dataset是一个抽象函数,不能直接实例化,所以我们要创建一个自己类,继承Dataset
继承Dataset后我们必须实现三个函数:
init()是初始化函数,之后我们可以提供数据集路径进行数据的加载
getitem()帮助我们通过索引找到某个样本
len()帮助我们返回数据集大小
class DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0]self.xdata = torch.from_numpy(xy[:, :-1])self.ydata = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.xdata[index], self.ydata[index]def __len__(self):return self.len
- 用DataLoader为数据进行分组,batch_size是一个组中有多少个样本,shuffle表示要不要对样本进行随机排列。一般来说,训练集我们随机排列,测试集不需要。num_workers表示我们可以用多少进程并行的运算,由于我的版本原因(cuda不好使),只能选择num_workers=0,一般可以写num_workers=2,进行并行运算算提高速度。
dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)
示意图中选择了batch_size=2,于是取每两组样本为mini-batch。而在程序中,我们选取了batch_size=32,于是取每32个样本为一个mini-batch,最后mini-batch根据具体的样本总数决定其包含的样本数量。
二、程序实现
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as pltclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0]self.xdata = torch.from_numpy(xy[:, :-1])self.ydata = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.xdata[index], self.ydata[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)xtest = dataset.xdata
ytest = dataset.ydataclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8,6)self.linear2 = torch.nn.Linear(6,4)self.linear3 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)costlist = []
acclist = []
# 若使用的是非linux系统,则以下循环部分需要封装
for epoch in range(10000):l = 0for i, data in enumerate(train_loader, 0):# 1. Prepare datainputs, labels = data# 2. Forwardypred = model(inputs)loss = criterion(ypred, labels)l += loss.item()#print(epoch, i, loss.item())# 3. Backwardoptimizer.zero_grad()loss.backward()# 4. Updata optimizer.step()costlist.append(l / len(inputs))# 每迭代1000次测试一次精确度if epoch % 1000 == 999:ypredtest = model(xtest)ypredlabel = torch.where(ypredtest>0.5, torch.tensor([1]), torch.tensor([0]))acc = torch.eq(ypredlabel, ytest).sum().item() / ytest.size(0)acclist.append(acc)print('the accuracy of testdataset:', acc)plt.figure(figsize=(10,4))
plt.subplot(1, 2, 1)
plt.plot(range(10000), costlist)
plt.title('error of mini-batch')
plt.xlabel('epoch')
plt.ylabel('loss')plt.subplot(1, 2, 2)
plt.plot(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])*1000 , acclist)
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('the accuracy of test dataset')plt.show()
- 输出结果
the accuracy of testdataset: 0.8234519104084321
the accuracy of testdataset: 0.8313570487483531
the accuracy of testdataset: 0.855072463768116
the accuracy of testdataset: 0.8629776021080369
the accuracy of testdataset: 0.8669301712779973
the accuracy of testdataset: 0.8682476943346509
the accuracy of testdataset: 0.8722002635046113
the accuracy of testdataset: 0.8748353096179183
the accuracy of testdataset: 0.8761528326745718
the accuracy of testdataset: 0.8774703557312253
- 输出图片
由于我的计算机的性能的限制,且mini-batch的使用对于计算力的耗费更大,所以我们只计算了10000步,更长的步数需要花的时间更多,预计迭代步数突破10w后,会得到一个准确率更高的好结果。
Pytorch深度学习(五):加载数据集以及mini-batch的使用相关推荐
- NVIDIA GEFORCE 2080 / 2080 SUPER / 2080 Ti + CUDA Toolkit 8.0 深度学习模型加载速度慢
NVIDIA GEFORCE 2080 / 2080 SUPER / 2080 Ti + CUDA Toolkit 8.0 深度学习模型加载速度慢 (卡顿) GEFORCE RTX 2080 / GE ...
- 【深度学习-数据加载优化-训练速度提升一倍】
1,介绍 数据加载 深度学习的训练,简单的说就是将数据切分成batch,丢入模型中,并计算loss训练.其中比较重要的一环是数据打batch部分(数据加载部分). 训练时间优化: 深度学习训练往往需要 ...
- 如何将深度学习模型加载到android环境中
承接上一篇的内容,考虑如何将深度学习的模型加载到android app中 文章目录 前言 一.使用工具 二.使用步骤 1.模型格式的转换 2.配置文件修改 3. 应用程序 前言 将图片学习的模型加载到 ...
- 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)
目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...
- pytorch 入门学习加载数据集-8
pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...
- pytorch中的数据加载(dataset基类,以及pytorch自带数据集)
目录 pytorch中的数据加载 模型中使用数据加载器的目的 数据集类 Dataset基类介绍 数据加载案例 数据加载器类 pytorch自带的数据集 torchvision.datasets MIN ...
- 【学习系列7】Pytorch中的数据加载
目录 1. 模型中使用数据加载器的目的 2. 数据集类 3. 迭代数据集 1. 模型中使用数据加载器的目的 在前面的线性回归横型中,我们使用的数据很少,所以直接把全部数据放到锁型中去使用. 但是在深度 ...
- Pytorch加载数据集的方式总结
Pytorch加载数据集的方式总结 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加 ...
- 【PyTorch深度学习项目实战100例目录】项目详解 + 数据集 + 完整源码
前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...
最新文章
- 树形结构在关系数据库中的设计
- react引入多个图片_重新引入React:v16之后的每个React更新都已揭开神秘面纱。
- 走火入魔 | 暑期电子设计课程学生们的作品
- Makefile_04:Makefile变量初了解
- MySql数据库查询表信息/列信息(列ID/列名/数据类型/长度/精度/是否可以为null/默认值/是否自增/是否是主键/列描述)...
- java int64如何定义_java – 具有两个int属性的自定义类的hashCode是什么?
- 仿iReader-按menu键弹出PopupWindow布局界面
- new to python什么意思_Python中__new__的作用
- 【讨论】对技术的掌握到底应该又多深?
- 2020八年级计算机会考计划,初二下学期学习计划2020
- 模板题——前缀和与差分
- 干货!假新闻检测:观察新闻本身,更要观察它所在的新闻环境
- 重构分析21: 被拒绝的遗赠(Refused Bequest)
- 市场调研-全球与中国汽车零部件涂层市场现状及未来发展趋势
- SDUT-3337 计算长方体、四棱锥的表面积和体积(JAVA*)
- 20个2013年最值得关注的网页设计趋势
- Aria2远程下载方案部署(CentOS7+Aria+AriaNG+Nginx)
- win7系统如何添加打印机服务器,怎样如何添加打印机驱动步骤
- jeesite集群和负载均衡配置
- java双音频文件分频_分频电路作用,怎么来理解二分频电路?