Pytorch(七) --加载数据集
主要用到了Pytorch中的Dataset和DataLoadder这两个方法,其中Dataset是抽象类,不能实例化对象,只能继承用于构造数据集,DataLoader是帮助加载数据的,可以做shuffle、batch_size,能拿Mini-batch进行训练。
代码如下:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoaderclass DiabetesDataset(Dataset):def __init__(self,filepath):xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)self.len = xy.shape[0]self.x_data=torch.from_numpy(xy[:,:-1])self.y_data = torch.from_numpy(xy[:,[-1]])def __getitem__(self,index):return self.x_data[index],self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('E:\\tmp\\.keras\\datasets\\diabetes.csv\\diabetes.csv')
train_loader = DataLoader(dataset=dataset,batch_size = 32,shuffle = True,num_workers=0
)#定义模型
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6)self.linear2 = torch.nn.Linear(6,4)self.linear3 = torch.nn.Linear(4,1)self.relu = torch.nn.ReLU()self.sigmoid = torch.nn.Sigmoid()def forward(self,x):x = self.relu(self.linear1(x))x = self.relu(self.linear2(x))x = self.sigmoid(self.linear3(x))return x
model = Model()#构建优化器和损失函数
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
#训练
if __name__ == '__main__':for epoch in range(1000):for i,data in enumerate(train_loader,0):inputs,labels = datay_pred = model(inputs)loss=criterion(y_pred,labels)print(epoch,i,loss.item())optimizer.zero_grad()loss.backward()optimizer.step()
努力加油a啊
Pytorch(七) --加载数据集相关推荐
- 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)
目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...
- pytorch 入门学习加载数据集-8
pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...
- Pytorch深度学习(五):加载数据集以及mini-batch的使用
Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...
- Pytorch加载数据集的方式总结
Pytorch加载数据集的方式总结 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加 ...
- pytorch创建自己的Dataset加载数据集
文章目录 创建一个类并继承torch.utils.data.dataset.Datase类 创建__getitem__方法 加载数据集 创建一个类并继承torch.utils.data.dataset ...
- PyTorch数据加载处理
PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...
- pytorch模型加载测试_使用Pytorch实现物体检测(Faster R-CNN)
在本示例中,介绍一种two-stage算法(Faster R-CNN),将目标区域检测和类别识别分为两个任务进行物体检测.本示例采用PyTorch引擎进行模型构建. 如果您已熟练使用Notebook和 ...
- PyTorch数据加载器
We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...
- python_torch_加载数据集_构建模型_构建训练循环_保存和调用训练好的模型
以下代码均来自bilibili:[适用于初学者的Pytorch编程教学] 以下为完整代码,复制即可运行. import torch import time import json import tor ...
最新文章
- 读博无门就业碰壁,孤独当了7个月“民科”后,我的论文中了顶会
- [Microsoft][ODBC SQL Server Driver][SQl Server]参数数据类型 text 对于 replace 函数的参数 1 无效。...
- WingIDE 5的安装与破解方法
- 14.线程安全?线程不安全?可重入函数?不可重入函数?
- 服务器上如何安装两个php网站,服务器安装两个php版本吗
- magento更新产品状态报错
- java中的传参是什么意思_如果作为参数传递,“字符串…参数”是什么意思?...
- 评分卡模型开发(五)--定性指标筛选
- 基础知识:计算机网络--《趣谈网络协议》读书笔记
- java.sql.SQLException: The server time zone value ' й ʱ ' is unrecognized or represents more tha
- Android 单个指定蓝牙设备通讯流程
- 项目启动失败解决方法
- yum下载速度慢解决,提速飞起来
- 手指头肌腱损伤鸿蒙训练,手指肌腱损伤恢复方法有哪些
- FFT 快速傅里叶变换 初探
- 怎么提升淘宝网店的转化率
- 【网络互联技术】(三) 网络互联基础。
- Oracle密码过期策略
- Proteus 8.9 SP2 专业版一体化快速安装教程(配包)
- DG发表声明称,账号被盗了。对此,黑客表示:这个锅我们不背!
热门文章
- php 安装测试程序,PHPUnit安装及使用示例
- mac使用被动ftp模式(pasv)_ftp主动模式和被动模式
- java測試動態方法_java反射学习
- python传输大文件_python之socket运用之传输大文件
- javascript:history.go()和history.back()的区别
- .NET CORE 2.1 导出excel文件的两种方法
- hadoop配置2.6.1 centos7
- spring源码分析,聊聊PropertyPlaceholderConfigurer
- Mac 安装redis
- Lin总线应用层代码