主要用到了Pytorch中的Dataset和DataLoadder这两个方法,其中Dataset是抽象类,不能实例化对象,只能继承用于构造数据集,DataLoader是帮助加载数据的,可以做shuffle、batch_size,能拿Mini-batch进行训练。
代码如下:

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoaderclass DiabetesDataset(Dataset):def __init__(self,filepath):xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)self.len = xy.shape[0]self.x_data=torch.from_numpy(xy[:,:-1])self.y_data = torch.from_numpy(xy[:,[-1]])def __getitem__(self,index):return self.x_data[index],self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('E:\\tmp\\.keras\\datasets\\diabetes.csv\\diabetes.csv')
train_loader = DataLoader(dataset=dataset,batch_size = 32,shuffle = True,num_workers=0
)#定义模型
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6)self.linear2 = torch.nn.Linear(6,4)self.linear3 = torch.nn.Linear(4,1)self.relu = torch.nn.ReLU()self.sigmoid = torch.nn.Sigmoid()def forward(self,x):x = self.relu(self.linear1(x))x = self.relu(self.linear2(x))x = self.sigmoid(self.linear3(x))return x
model = Model()#构建优化器和损失函数
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
#训练
if __name__ == '__main__':for epoch in range(1000):for i,data in enumerate(train_loader,0):inputs,labels = datay_pred = model(inputs)loss=criterion(y_pred,labels)print(epoch,i,loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

努力加油a啊

Pytorch(七) --加载数据集相关推荐

  1. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  2. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  3. Pytorch深度学习(五):加载数据集以及mini-batch的使用

    Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...

  4. Pytorch加载数据集的方式总结

    Pytorch加载数据集的方式总结 一.自己重写定义(Dataset.DataLoader) 二.用Pytorch自带的类(ImageFolder.datasets.DataLoader) 2.1 加 ...

  5. pytorch创建自己的Dataset加载数据集

    文章目录 创建一个类并继承torch.utils.data.dataset.Datase类 创建__getitem__方法 加载数据集 创建一个类并继承torch.utils.data.dataset ...

  6. PyTorch数据加载处理

    PyTorch数据加载处理 PyTorch提供了许多工具来简化和希望数据加载,使代码更具可读性. 1.下载安装包 • scikit-image:用于图像的IO和变换 • pandas:用于更容易地进行 ...

  7. pytorch模型加载测试_使用Pytorch实现物体检测(Faster R-CNN)

    在本示例中,介绍一种two-stage算法(Faster R-CNN),将目标区域检测和类别识别分为两个任务进行物体检测.本示例采用PyTorch引擎进行模型构建. 如果您已熟练使用Notebook和 ...

  8. PyTorch数据加载器

    We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...

  9. python_torch_加载数据集_构建模型_构建训练循环_保存和调用训练好的模型

    以下代码均来自bilibili:[适用于初学者的Pytorch编程教学] 以下为完整代码,复制即可运行. import torch import time import json import tor ...

最新文章

  1. 读博无门就业碰壁,孤独当了7个月“民科”后,我的论文中了顶会
  2. [Microsoft][ODBC SQL Server Driver][SQl Server]参数数据类型 text 对于 replace 函数的参数 1 无效。...
  3. WingIDE 5的安装与破解方法
  4. 14.线程安全?线程不安全?可重入函数?不可重入函数?
  5. 服务器上如何安装两个php网站,服务器安装两个php版本吗
  6. magento更新产品状态报错
  7. java中的传参是什么意思_如果作为参数传递,“字符串…参数”是什么意思?...
  8. 评分卡模型开发(五)--定性指标筛选
  9. 基础知识:计算机网络--《趣谈网络协议》读书笔记
  10. java.sql.SQLException: The server time zone value ' й ׼ʱ ' is unrecognized or represents more tha
  11. Android 单个指定蓝牙设备通讯流程
  12. 项目启动失败解决方法
  13. yum下载速度慢解决,提速飞起来
  14. 手指头肌腱损伤鸿蒙训练,手指肌腱损伤恢复方法有哪些
  15. FFT 快速傅里叶变换 初探
  16. 怎么提升淘宝网店的转化率
  17. 【网络互联技术】(三) 网络互联基础。
  18. Oracle密码过期策略
  19. Proteus 8.9 SP2 专业版一体化快速安装教程(配包)
  20. DG发表声明称,账号被盗了。对此,黑客表示:这个锅我们不背!

热门文章

  1. php 安装测试程序,PHPUnit安装及使用示例
  2. mac使用被动ftp模式(pasv)_ftp主动模式和被动模式
  3. java測試動態方法_java反射学习
  4. python传输大文件_python之socket运用之传输大文件
  5. javascript:history.go()和history.back()的区别
  6. .NET CORE 2.1 导出excel文件的两种方法
  7. hadoop配置2.6.1 centos7
  8. spring源码分析,聊聊PropertyPlaceholderConfigurer
  9. Mac 安装redis
  10. Lin总线应用层代码