PyTorch框架:(4)如何去构建数据
接PyTorch框架:(3)
1、最基本的方法
(1)使用模块
模块1:TensorDataset、模块2:DataLoader
自己去构造数据集,然后一个batch一个batch的取数据,自己去写构造数据太麻烦,可以自动让其把数据源给我们构建好,这两个模块就是来帮我们完成这个事的。
第一步把x_train和y_train传进去,使用TensorDataset自动的帮我们组件dataset即(train_ds);
DataLoader是得搭配一下,先把数据转化为TensorDataset所支持的格式,然后采用DataLoader读进来,DataLoader的意思就是你把数据交给我,然后你告诉我一个batch_size有多少,然后你要取数据的时候我就帮你一个batch一个batch的取数据,这样方便一些。shuffle=True表示要不要重新洗牌;
(2)定义一个get_data方法,需要传进来当前的数据集,后边做了一个return,就是按照一个Batch取数据就完事了;
(3)训练函数
自己定义一个训练方法,def fit方法,实际的去执行训练的操作。传进来的参数:
steps:一共迭代多少次。
model:就是定义的model,就是自己写个类,把model传进来。
loss_func:使用的f.中的损失。
opt:优化器是什么。
train_dl:实际数据传进来。
valid_dl:实际数据传进来。
Batch Normalization和Dropout在训练的时候一般都会加这两项,让模型过拟合的更低;在测试的时候一般就不加这两个东西了。所以为了有这两个区分,如果此时是训练,那么在训练的时候加上model.train();下边不是训练就是走一次前向传播,看一下对于当前模型来说他的一个效果,他的损失等于多少,把损失拿过来,我也不需要进行参数更新,不需要计算梯度,也不需要训练的过程,所以这一块我再额外的指定一下,这块不需要加Batch Normalization和Dropout,他不是一个训练的过程,所以在前边加上model.eval()。
所以见到这两个就是表示:model.train()强调的是你的训练过程,把该加的加进去;model.eval()强调的是测试过程,只需要得到结果,不需要把没用的都加进去。
loss_batch做的事情:如果你传进来一个优化器,优化器求梯度,求完梯度更新,更新完之后置0,然后返回结果。这里不光计算一个loss值还要去计算他实际的梯度值是多少,要进行参数的更新。
上述相当于把每个模块都准备好了,实际训练模型的时候不用把每个函数都也在一个sell当中,下面三行就搞定了:
第一步:拿到数据getdata。
第二步:拿到模型和优化器。(模型就是自己的类Mnist_NN)
第三步:执行fit函数。(fit函数的第三个参数表示损失函数是如何计算的,在损失函数计算当中还加入了梯度的更新,第四个使用什么样的优化器去更新我当前的结果)
2、复杂的方法
暂定
PyTorch框架:(4)如何去构建数据相关推荐
- 银行股价预测——基于pytorch框架RNN神经网络
银行股价预测--基于pytorch框架RNN神经网络 任务目标 数据来源 完整代码 流程分析 1.导包 2.读入数据并做预处理 3.构建单隐藏层Rnn模型 4.设计超参数,训练模型 5.加载模型,绘图 ...
- PyTorch框架学习八——PyTorch数据读取机制(简述)
PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...
- PyTorch框架:(3)使用PyTorch框架构构建神经网络分类任务
目录 0.背景 1.分类任务介绍: 2.网络架构 3.手写网络 3.1.读取数据集 3.2.查看数据集 3.3将x和y转换成tensor的格式 3.4.定义model 0.背景 其实分类和回归本质上没 ...
- PyTorch框架学习九——网络模型的构建
PyTorch框架学习九--网络模型的构建 一.概述 二.nn.Module 三.模型容器Container 1.nn.Sequential 2.nn.ModuleList 3.nn.ModuleDi ...
- 达芬奇 - 构建数据查询API的框架
达芬奇 - 基于"Serverless"的数据查询API框架 此文背景 我们要解决什么样的问题? 系统要求 系统设计 访问控制列表及其使用 过滤器以及其语法 Serverless ...
- pytorch框架(计算机视觉)一.数据增强
pytorch框架中一些代码的解释 预处理函数 class torchvision.transforms.Compose(transforms) data_transforms = {'train': ...
- PyTorch框架:(1)基本处理操作
目录 1.PyTorch框架介绍 2.安装Pytorch 2.1.CPU版本的安装命令: 2.2.GPU版本的安装命令: 2.2.1.安装CUDA 3.基本使用方法 4.Pytorch中的自动求导机制 ...
- PyTorch框架学习二十——模型微调(Finetune)
PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...
- PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)
PyTorch框架学习十--基础网络层(卷积.转置卷积.池化.反池化.线性.激活函数) 一.卷积层 二.转置卷积层 三.池化层 1.最大池化nn.MaxPool2d 2.平均池化nn.AvgPool2 ...
最新文章
- python编程用户登陆c_django实现用户登陆功能详解
- 杨清彦:《像三国》游戏3D动效制作经验分享
- PAT-B 1015. 德才论(同PAT 1062. Talent and Virtue)
- php+java+框架整合_ThinkPhP+Apache+PHPstorm整合框架流程图解
- SQL Server on Ubuntu——Ubuntu上的SQL Server(全截图)
- [LeetCode]235.Lowest Common Ancestor of a Binary Search Tree
- 图说Netty服务端启动过程
- python生成器表达式_python 生成器和生成器表达式
- linux安装telnet工具下载,Linux下安装telnet的方法
- android 选座系统,android 影院选座
- C语言面试必问的经典问题(纯”gan“货)
- Transformer 权重共享
- MSP430F149实现超声波测距并通过串口和PC机通信进行显示
- 云MAS - MT-提交状态码
- mysql group concat 去重,MySQL group_concat() 函数用法
- C语言-小黄鸭☞循环结构while
- 4 RRC Measurement -- 配置
- ybt1283:登山
- 麦克纳姆轮底盘-正反向运动学-里程估计
- jumpstart-6.10.3安装指南
热门文章
- Python的Xpath介绍和语法详解
- HTTP/HTTPS抓包工具-Fiddler
- [C] [字节跳动] [编程题] 手串
- Python 判断当前数值的类型(比如x=10 获取x的类型)
- Python 2x 中list 里面的中文打印效果乱码
- 洛谷P1092 虫食算
- 【Gamma】“北航社团帮”展示博客
- 2022-2028年中国重卡行业投资分析及前景预测报告
- C++ 笔记(25)— 理解 C++ 中的头文件和源文件的作用
- Session原理、安全以及最基本的Express和Redis实现