pytorch使用torch.nn.Sequential构建网络
以一个线性回归的例子为例:
全部代码
import torch
import numpy as npdef get_x_y():x = np.random.randint(0, 50, 300)y_values = 2 * x + 21x = np.array(x, dtype=np.float32)y = np.array(y_values, dtype=np.float32)x = x.reshape(-1, 1)y = y.reshape(-1, 1)return x, yif __name__ == '__main__':train_x, train_y = get_x_y()input_size = train_x.shape[1] # 输入的维度,只是1维output_size = 1 # 输出的个数batch_size = 16 # 每个 batch 的数量my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, output_size),)cost = torch.nn.MSELoss(reduction='mean') # 使用MSE作为损失函数optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001) # 优化器# 训练网络losses = []for i in range(1000):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(train_x), batch_size):end = start + batch_size if start + batch_size < len(train_x) else len(train_x)xx = torch.tensor(train_x[start:end], dtype=torch.float, requires_grad=True)yy = torch.tensor(train_y[start:end], dtype=torch.float, requires_grad=True)prediction = my_nn(xx) # 自己的网络模型(x数据),会被认为做前向传播loss = cost(prediction, yy) # 计算损失optimizer.zero_grad() # 清零优化器!!!!一定要记得loss.backward(retain_graph=True) # 反向传播,retain_graph表示是否重复执行操作,在循环中需要设置为Trueoptimizer.step() # 更新参数batch_loss.append(loss.data.numpy())# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))
pytorch使用torch.nn.Sequential构建网络相关推荐
- pytorch torch.nn.Sequential(* args)(嘎哈用的?构建神经网络用的?)
class torch.nn.Sequential(* args) 一个时序容器.Modules 会以他们传入的顺序被添加到容器中.当然,也可以传入一个OrderedDict. 为了更容易的理解如何使 ...
- PyTorch 笔记(16)— torch.nn.Sequential、torch.nn.Linear、torch.nn.RelU
PyTorch 中的 torch.nn 包提供了很多与实现神经网络中的具体功能相关的类,这些类涵盖了深度神经网络模型在搭建和参数优化过程中的常用内容,比如神经网络中的卷积层.池化层.全连接层这类层次构 ...
- 【深度学习】torch.nn.Sequential方法介绍
torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中. 另外,也可以传入一个有序模块. 作用:Sequential除了本身可以用来定义模型之 ...
- 【torch.nn.Sequential】序列容器的介绍和使用
文章目录 torch.nn.Sequential 简单介绍 构建实例 参数列表 字典 基本操作 参考 torch.nn.Sequential 简单介绍 nn.Sequential是一个有序的容器,该类 ...
- Pytorch中torch.nn.Softmax的dim参数含义
自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义
- Pytorch 之torch.nn初探
第1关:torch.nn.Module 本关要求利用nn.Linear()声明一个线性模型 l,并构建一个变量 net 由三个l序列构成. import torch import torch.nn a ...
- 【PyTorch】torch.nn.Transformer解读与应用
nn.TransformerEncoderLayer 这个类是transformer encoder的组成部分,代表encoder的一个层,而encoder就是将transformerEncoderL ...
- 【Pytorch】torch.nn.Conv1d()理解与使用
官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html?highlight=nn%20conv1d#torch.nn.C ...
- Pytorch之torch.nn.functional.pad函数详解
torch.nn.functional.pad是PyTorch内置的矩阵填充函数 (1).torch.nn.functional.pad函数详细描述如下: torch.nn.functional.pa ...
最新文章
- 使用VS2008开发及部署Excel AddIn 心得
- Android Bluetooth 文件接收路径修改方法
- 《leetcode》first-missing-positive
- 银行喜欢全额还款的客户,还是喜欢最低还款客户?--编辑
- 事件捕获(capture)和冒泡事件(Bubble)
- python线程同步锁_[python] 线程间同步之Lock RLock
- (转)C#中Split用法
- java arraylist6_java 集合 ArrayList
- 动态修改log4net设置
- 装ubuntu_系统安装_win10下安装Ubuntu后,启动时无win10选项的解决办法。
- Unity 自学与进阶必会目录
- mysql批量插入跟更新_mysql批量插入以及批量更新
- 安装PdaNet以连接Android设备
- 华硕主板固态硬盘不识别_华硕主板固态硬盘识别不出来怎么办
- GLASS 产品使用(一)
- 《游戏系统设计四》游戏资源系统太复杂? 啥?你不会?一步一步带你分析并实现,源码直接拿走
- 【资源下载】分享个嵌入式开发的入门教程(包含视频)
- android 太阳系布局,Solar Walk太阳系模型软件-三维太阳系模型下载2.4.49安卓版-西西软件下载...
- matlab使用invfreqs出错,MATLAB实验2016剖析.doc
- 解决IDEA中使用git插件提交代码乱码问题