Pytorch专题实战——前馈神经网络(Feed-Forward Neural Network)
文章目录
- 1.导入必要模块
- 2.超参数设置
- 3.数据准备
- 4.打印部分加载的数据
- 5.模型建立
- 6.训练
1.导入必要模块
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
2.超参数设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #设备input_size = 784 #输入大小
hidden_size = 500 #隐层大小
num_classes = 10 #输出大小
num_epochs = 2 #迭代次数
batch_size = 100 #批量大小
learning_rate = 0.001 #学习率
3.数据准备
train_dataset = torchvision.datasets.MNIST(root='./data', #下载训练数据train = True,transform=transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', #下载测试数据train = False,transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(dataset=train_dataset, #制作DataLoaderbatch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
4.打印部分加载的数据
examples = iter(test_loader)
example_data, example_targets = examples.next()for i in range(6):plt.subplot(2, 3, i+1)plt.imshow(example_data[i][0], cmap='gray')
plt.show()
5.模型建立
class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()self.input_size = input_sizeself.l1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.l2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.l1(x)out = self.relu(out)out = self.l2(out)return out
6.训练
model = NeuralNet(input_size, hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
n_total_steps = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, 28*28).to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1)%100==0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss:{loss.item():.4f}')with torch.no_grad():n_correct = 0n_samples = 0for images, labels in test_loader:images = images.reshape(-1, 28*28).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)n_samples += labels.size(0)n_correct += (predicted == labels).sum().item()
acc = 100.0*n_correct / n_samples
print(f'Accuracy of the network on the 1000 test images:{acc}%')
Pytorch专题实战——前馈神经网络(Feed-Forward Neural Network)相关推荐
- keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)
keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...
- keras构建前馈神经网络(feedforward neural network)进行多分类模型训练学习
keras构建前馈神经网络(feedforward neural network)进行多分类模型训练学习 前馈神经网络(feedforward neural network)是一种最简单的神经网络,各 ...
- keras构建前馈神经网络(feedforward neural network)进行分类模型构建并加入L2正则化
keras构建前馈神经网络(feedforward neural network)进行分类模型构建并加入L2正则化 正则化(Regularization)是机器学习中一种常用的技术,其主要目的是控制模 ...
- keras构建前馈神经网络(feedforward neural network)进行回归模型构建和学习
keras构建前馈神经网络(feedforward neural network)进行回归模型构建和学习 我们不必在"回归"一词上费太多脑筋.英国著名统计学家弗朗西斯·高尔顿(Fr ...
- 深度学习实验1:pytorch实践与前馈神经网络
深度学习实验1:pytorch实践与前馈神经网络 1.pytorch基本操作 1.使用
- 深度学习笔记(四)——循环神经网络(Recurrent Neural Network, RNN)
目录 一.RNN简介 (一).简介 (二).RNN处理任务示例--以NER为例 二.模型提出 (一).基本RNN结构 (二).RNN展开结构 三.RNN的结构变化 (一).N to N结构RNN模型 ...
- 卷积神经网络(Convolutional Neural Network,CNN)
卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现. 它包括卷积层(con ...
- 类脑运算--脉冲神经网络(Spiking Neural Network)发展现状
类脑运算–脉冲神经网络(Spiking Neural Network)发展现状 前一段时间忙于博士论文的攥写和答辩, 抱歉拖更 继上一章: 类脑运算–脉冲神经网络(Spiking Neural Net ...
- R语言高级算法之人工神经网络(Artificial Neural Network)
1.人工神经网络原理分析: 神经网络是一种运算模型,由大量的节点(或称神经元)和之间的相互连接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function).每两个节点间 ...
最新文章
- 【java】增强for循环的简单使用(遍历数组)
- 因果解释能够对规则进行解释吗?
- 1. python 字符串简介与常用函数
- 2016重庆计算机一级考试题型,重庆计算机一级考试真题2016年最新(笔试+上机).doc...
- Netty入门篇-从双向通信开始
- Spring实战(十三)Spring事务
- java smp_什么是SMP系统
- java中多叉树(tree)的生成与显示
- linux下spark的python编辑_Linux下搭建Spark 的 Python 编程环境的方法
- 使用Json出现java.lang.NoClassDefFoundError解决方法
- 【转】离婚男人给女孩的恋爱忠告
- 由三目运算符想出的PHP改进建议
- python动态数据类型_[python学习手册-笔记]004.动态类型
- 【74系列芯片的Verilog重现(一)】------74HC00
- mac安装绿联USB转以太网驱动
- java数据结构——哈希表
- 接口各项性能测试指标
- angular RxJs
- 《linux多线程服务端编程》---- C++基础前奏
- 【FFmpeg+Qt开发】转码流程 H.264 转(mov、mp4、avi、flv)等视频格式 示例详解
热门文章
- qt中dll缺失以及无法启动程序的正确解决方法
- linux计算圆周率程序,科学网—[转载]关于Linux中使用bc命令计算圆周率(π):可以计算上千位或上万位,顺便评测CPU的计算能力 - 张成岗的博文...
- php查找二维数组值,根据二维数组某个字段的值查找数组
- 学会asp后再学php,九天学会ASP 之 第二天
- 同步现象 心理学_男生是不是更容易从失恋中走出来?心理学:失恋后悲伤,男女不同...
- 深入大数据安全分析(1):为什么需要大数据安全分析?
- java安卓如何实现定义接口
- linux发布微软消息队列,消息队列RabbitMQ入门与5种模式详解
- Apache Qpid:一个AMQP的开源实现
- linux网络安装mysql_linux系统安装mysql