多层感知机的从零开始实现

  • 获取和读取数据
  • 定义模型参数
  • 定义激活函数
  • 定义模型
  • 定义损失函数
  • 训练模型
  • 小结

我们已经从上一节里了解了多层感知机的原理。下面,我们一起来动手实现一个多层感知机。首先导入实现所需的包或模块。

import torch
import numpy as np

获取和读取数据

这里继续使用Fashion-MNIST数据集。代码和之前softmax回归是一样的,我们将使用多层感知机对图像进行分类。

定义模型参数

我们在3(softmax回归的从零开始实现)里已经介绍了,Fashion-MNIST数据集中图像形状为 28×2828 \times 2828×28,类别数为10。本节中我们依然使用长度为 28×28=78428 \times 28 = 78428×28=784 的向量表示每一张图像。因此,输入个数为784,输出个数为10。实验中,我们设超参数隐藏单元个数为256。

num_inputs, num_outputs, num_hiddens = 784, 10, 256W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)), dtype=torch.float,requires_grad=True)
b1 = torch.zeros(num_hiddens, dtype=torch.float,requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens, num_outputs)), dtype=torch.float,requires_grad=True)
b2 = torch.zeros(num_outputs, dtype=torch.float,requires_grad=True)

定义激活函数

这里我们使用基础的max函数来实现ReLU,而非直接调用relu函数。

def relu(X):return torch.max(input=X, other=torch.tensor(0.0))

定义模型

同softmax回归一样,我们通过view函数将每张原始图像改成长度为num_inputs的向量。然后我们实现上一节中多层感知机的计算表达式。

def net(X):X = X.view((-1, num_inputs))H = relu(torch.matmul(X, W1) + b1)return torch.matmul(H, W2) + b2

定义损失函数

为了得到更好的数值稳定性,我们直接使用PyTorch提供的包括softmax运算和交叉熵损失计算的函数。

loss = torch.nn.CrossEntropyLoss()

训练模型

训练多层感知机的步骤和之前训练softmax回归的步骤没什么区别。

小结

  • 可以通过手动定义模型及其参数来实现简单的多层感知机。
  • 当多层感知机的层数较多时,本节的实现方法会显得较烦琐,例如在定义模型参数的时候。

深度学习pytorch--多层感知机(二)相关推荐

  1. 深度学习 — — PyTorch入门(二)

    在深度学习--PyTorch入门(一)中我们介绍了构建网络模型和加载数据的内容,本篇将继续介绍如何完成对模型的训练. 训练:更新网络权重 构建网络结构和加载完数据集之后,便可以开始进行网络权重的训练. ...

  2. 深度学习基础——多层感知机

    多层感知机(Multilayer Perceptron, MLP)是最简单的深度网络.本文回顾多层感知机的相关内容及一些基本概念术语. 多层感知机 为什么需要多层感知机 多层感知机是对线性回归的拓展和 ...

  3. 【动手学深度学习】多层感知机(MLP)

    1 多层感知机的从零开始实现 torch.nn 继续使用Fashion-MNIST图像分类数据集 导入需要的包 import torch from torch import nn from d2l i ...

  4. 动手学深度学习之多层感知机

    多层感知机 多层感知机的基本知识 深度学习主要关注多层模型.本节将以多层感知机(multilayer perceptron,MLP)为例,介绍多层神经网络的概念. 隐藏层 下图展示了一个多层感知机的神 ...

  5. [深度学习] (sklearn)多层感知机对葡萄酒的分类

    时间:2021年12月2日 from sklearn.datasets import load_wine from sklearn.model_selection import train_test_ ...

  6. 动手学深度学习(PyTorch实现)(十二)--批量归一化(BatchNormalization)

    批量归一化-BatchNormalization 1. 前言 2. 批量归一化的优势 3. BN算法介绍 4. PyTorch实现 4.1 导入相应的包 4.2 定义BN函数 4.3 定义BN类 5. ...

  7. 动手深度学习PyTorch(十二)word2vec

    独热编码 独热编码即 One-Hot 编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效.举个例子,假设我们有四个样 ...

  8. 【深度学习】基于Pytorch多层感知机的高级API实现和注意力机制(二)

    [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(二) 文章目录1 代码实现 2 训练误差和泛化误差 3 模型复杂性 4 多项式回归4.1 生成数据集4.2 对模型进行训练和测试4 ...

  9. 【深度学习】基于Pytorch多层感知机的高级API实现和注意力机制(三)

    [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(三) 文章目录 [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(三) 1 权重衰减 1.1 范数 1.2 L ...

最新文章

  1. 280. Wiggle Sort
  2. BZOJ.3277.串(广义后缀自动机)
  3. 4路组相连cache设计_移动图形处理器的纹理Cache设计
  4. my-innodb-heavy-4G.cnf 配置文件参数介绍
  5. java零碎要点001--深入理解JVM_Java的堆内存_栈内存_以及运行时数据区的作用
  6. 12大深度学习开源框架(caffe,tensorflow,pytorch,mxnet等)汇总详解
  7. win10双显卡怎么切换amd和英特尔_win10双显卡怎么切换
  8. 160304-01、mysql数据库插入速度和读取速度的调整记录
  9. PT建站源码(PT服务器原程序)汇总 by 乱世狂人
  10. c语言英文拼写检查器,c – 简单的拼写检查算法
  11. 西电计科操作系统实验
  12. element-UI el-dialog组件按ESC键关闭不了弹窗
  13. Typecho+Handsome主题美化
  14. 【人工智能】实验一:基于MLP的手写体字符识别
  15. 概率论与数理统计(3):二维随机变量及其分布
  16. 洛谷P4170 [CQOI2007]涂色 题解
  17. OMCI协议二层功能的模型选择
  18. VUE 中实现echarts中国地图 人口迁徙
  19. Win7运行命令的打开方法 Win7运行命令大全(45个)
  20. Ubuntu如何修改grub启动项

热门文章

  1. 嵌入式java基准测试_Java正则表达式库基准测试– 2015年
  2. 在Java中将时间单位转换为持续时间
  3. 如何集成和使用EclEmma插件来获得良好的Junit覆盖率
  4. JMetro 5.5版发布
  5. 使用混合多云每个人都应避免的3个陷阱(第1部分)
  6. Spring批处理CSV处理
  7. 渴望订阅– RxJava常见问题解答
  8. 消息队列概述[幻灯片]
  9. 使用Java 8在地图上流式传输
  10. primefaces_懒惰的JSF Primefaces数据表分页–第2部分