文章目录

  • 1.导入必要模块
  • 2.超参数设置
  • 3.数据准备
  • 4.打印部分加载的数据
  • 5.模型建立
  • 6.训练

1.导入必要模块

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

2.超参数设置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   #设备input_size = 784      #输入大小
hidden_size = 500     #隐层大小
num_classes = 10      #输出大小
num_epochs = 2        #迭代次数
batch_size = 100      #批量大小
learning_rate = 0.001   #学习率

3.数据准备

train_dataset = torchvision.datasets.MNIST(root='./data',       #下载训练数据train = True,transform=transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.MNIST(root='./data',        #下载测试数据train = False,transform=transforms.ToTensor())train_loader = torch.utils.data.DataLoader(dataset=train_dataset,     #制作DataLoaderbatch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

4.打印部分加载的数据

examples = iter(test_loader)
example_data, example_targets = examples.next()for i in range(6):plt.subplot(2, 3, i+1)plt.imshow(example_data[i][0], cmap='gray')
plt.show()

5.模型建立

class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()self.input_size = input_sizeself.l1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.l2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.l1(x)out = self.relu(out)out = self.l2(out)return out

6.训练

model = NeuralNet(input_size, hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
n_total_steps = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, 28*28).to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1)%100==0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss:{loss.item():.4f}')with torch.no_grad():n_correct = 0n_samples = 0for images, labels in test_loader:images = images.reshape(-1, 28*28).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)n_samples += labels.size(0)n_correct += (predicted == labels).sum().item()
acc = 100.0*n_correct / n_samples
print(f'Accuracy of the network on the 1000 test images:{acc}%')

Pytorch专题实战——前馈神经网络(Feed-Forward Neural Network)相关推荐

  1. keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)

    keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...

  2. keras构建前馈神经网络(feedforward neural network)进行多分类模型训练学习

    keras构建前馈神经网络(feedforward neural network)进行多分类模型训练学习 前馈神经网络(feedforward neural network)是一种最简单的神经网络,各 ...

  3. keras构建前馈神经网络(feedforward neural network)进行分类模型构建并加入L2正则化

    keras构建前馈神经网络(feedforward neural network)进行分类模型构建并加入L2正则化 正则化(Regularization)是机器学习中一种常用的技术,其主要目的是控制模 ...

  4. keras构建前馈神经网络(feedforward neural network)进行回归模型构建和学习

    keras构建前馈神经网络(feedforward neural network)进行回归模型构建和学习 我们不必在"回归"一词上费太多脑筋.英国著名统计学家弗朗西斯·高尔顿(Fr ...

  5. 深度学习实验1:pytorch实践与前馈神经网络

    深度学习实验1:pytorch实践与前馈神经网络 1.pytorch基本操作 1.使用

  6. 深度学习笔记(四)——循环神经网络(Recurrent Neural Network, RNN)

    目录 一.RNN简介 (一).简介 (二).RNN处理任务示例--以NER为例 二.模型提出 (一).基本RNN结构 (二).RNN展开结构 三.RNN的结构变化 (一).N to N结构RNN模型 ...

  7. 卷积神经网络(Convolutional Neural Network,CNN)

    卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现. 它包括卷积层(con ...

  8. 类脑运算--脉冲神经网络(Spiking Neural Network)发展现状

    类脑运算–脉冲神经网络(Spiking Neural Network)发展现状 前一段时间忙于博士论文的攥写和答辩, 抱歉拖更 继上一章: 类脑运算–脉冲神经网络(Spiking Neural Net ...

  9. R语言高级算法之人工神经网络(Artificial Neural Network)

    1.人工神经网络原理分析: 神经网络是一种运算模型,由大量的节点(或称神经元)和之间的相互连接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function).每两个节点间 ...

最新文章

  1. 【java】增强for循环的简单使用(遍历数组)
  2. 因果解释能够对规则进行解释吗?
  3. 1. python 字符串简介与常用函数
  4. 2016重庆计算机一级考试题型,重庆计算机一级考试真题2016年最新(笔试+上机).doc...
  5. Netty入门篇-从双向通信开始
  6. Spring实战(十三)Spring事务
  7. java smp_什么是SMP系统
  8. java中多叉树(tree)的生成与显示
  9. linux下spark的python编辑_Linux下搭建Spark 的 Python 编程环境的方法
  10. 使用Json出现java.lang.NoClassDefFoundError解决方法
  11. 【转】离婚男人给女孩的恋爱忠告
  12. 由三目运算符想出的PHP改进建议
  13. python动态数据类型_[python学习手册-笔记]004.动态类型
  14. 【74系列芯片的Verilog重现(一)】------74HC00
  15. mac安装绿联USB转以太网驱动
  16. java数据结构——哈希表
  17. 接口各项性能测试指标
  18. angular RxJs
  19. 《linux多线程服务端编程》---- C++基础前奏
  20. 【FFmpeg+Qt开发】转码流程 H.264 转(mov、mp4、avi、flv)等视频格式 示例详解

热门文章

  1. qt中dll缺失以及无法启动程序的正确解决方法
  2. linux计算圆周率程序,科学网—[转载]关于Linux中使用bc命令计算圆周率(π):可以计算上千位或上万位,顺便评测CPU的计算能力 - 张成岗的博文...
  3. php查找二维数组值,根据二维数组某个字段的值查找数组
  4. 学会asp后再学php,九天学会ASP 之 第二天
  5. 同步现象 心理学_男生是不是更容易从失恋中走出来?心理学:失恋后悲伤,男女不同...
  6. 深入大数据安全分析(1):为什么需要大数据安全分析?
  7. java安卓如何实现定义接口
  8. linux发布微软消息队列,消息队列RabbitMQ入门与5种模式详解
  9. Apache Qpid:一个AMQP的开源实现
  10. linux网络安装mysql_linux系统安装mysql