torch.nn是专门为神经网络设计的模块化接口。nn构建于 Autograd之上,可用来定义和运行神经网络。
nn.functional,这个包中包含了神经网络中使用的一些常用函数,这些函数的特点是,不具有可学习的参数(如ReLU,pool,DropOut等),这些函数可以放在构造函数中,也可以不放,但是这里建议不放。

定义一个网络

PyTorch中已经为我们准备好了现成的网络模型,只要继承nn.Module,并实现它的forward方法,PyTorch会根据autograd,自动实现backward函数。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MLP(nn.Module):def __init__(self, user_num, user_dim, layers=[32, 16, 8]):super(MLP, self).__init__()  # 子类函数必须在构造函数中执行父类构造函数self.user_Embedding = nn.Embedding(user_num, user_dim)self.mlp = nn.Sequential() for id in range(1, len(layers)):  # 这样可以实现MLP层数和每层神经单元数的自动调整self.mlp.add_module("Linear_layer_%d" % id, nn.Linear(layers[id - 1], layers[id]))self.mlp.add_module("Relu_layer_%d" % id, nn.ReLU(inplace=True))self.predict = nn.Sequential(nn.Linear(layers[-1], 1),nn.Sigmoid(),)def forward(self, x):user = self.user_Embedding(x)user = self.mlp(user)score = self.predict(user)return scoremodel = MLP(1000, 64)
print(model)
MLP((user_Embedding): Embedding(1000, 64)(mlp): Sequential((Linear_layer_1): Linear(in_features=32, out_features=16, bias=True)(Relu_layer_1): ReLU(inplace=True)(Linear_layer_2): Linear(in_features=16, out_features=8, bias=True)(Relu_layer_2): ReLU(inplace=True))(predict): Sequential((0): Linear(in_features=8, out_features=1, bias=True)(1): Sigmoid())
)
for parameters in model.parameters():print(parameters)
Parameter containing:
tensor([[ 0.4192, -1.0525,  1.4208,  0.5376,  2.1371,  0.7074,  0.1017,  0.9701, 1.2824, -0.0436],[-0.6374,  0.0153, -0.1862, -0.6061,  0.5522, -1.1526,  0.3913,  0.3103,-0.1055,  0.6098],[-0.0367, -0.9573, -0.5106, -1.2440,  1.2201, -0.5424,  0.2045,  0.2208,-0.7557, -0.7811],[ 0.5457,  0.3586,  0.9871, -0.2117,  1.0885,  1.7162, -0.2125,  0.2652,-0.3262,  0.3047],[ 0.1039,  0.8132,  0.6638,  0.2618,  0.8552,  0.8300,  0.2349,  1.8830,-0.5149, -1.0468]], requires_grad=True)
Parameter containing:
tensor([[-0.2395,  0.1461, -0.0161,  0.0267, -0.0353,  0.2085,  0.0046, -0.1572],[ 0.2267,  0.0129, -0.3296, -0.2270,  0.2268,  0.1771, -0.0992,  0.2148],[ 0.1906,  0.1896, -0.2703, -0.3506,  0.0248,  0.1949, -0.3117,  0.0721],[-0.3197,  0.2782, -0.1553,  0.2509,  0.0279,  0.2040, -0.1478,  0.2943]],requires_grad=True)
Parameter containing:
tensor([ 0.0808, -0.3252, -0.0015, -0.0666], requires_grad=True)
Parameter containing:
tensor([[-0.3243,  0.4393, -0.2430,  0.4330]], requires_grad=True)
Parameter containing:
tensor([-0.0739], requires_grad=True)
for name,parameters in model.named_parameters():print(name,':',parameters.size())
user_Embedding.weight : torch.Size([5, 10])
mlp.Linear_layer_1.weight : torch.Size([4, 8])
mlp.Linear_layer_1.bias : torch.Size([4])
predict.0.weight : torch.Size([1, 4])
predict.0.bias : torch.Size([1])

Pytorch 实现 MLP相关推荐

  1. 使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负

    使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负 1. 数据集 百度网盘链接,提取码:q79p 数据集文件格式为CSV.数据集包含了大约5万场英雄联盟钻石排位赛前15分钟的数据集合,总 ...

  2. pytorch 实现MLP(多层感知机)

    pytorch 实现多层感知机,主要使用torch.nn.Linear(in_features,out_features),因为torch.nn.Linear是全连接的层,就代表MLP的全连接层 本文 ...

  3. PyTorch之MLP

    一 .关于Graphviz 的问题 首先手动下载对应的包并安装,添加环境变量,如果仍然不行,考虑如下方法 graphviz.backend.execute.ExecutableNotFound: fa ...

  4. PyTorch入门(二)搭建MLP模型实现分类任务

      本文是PyTorch入门的第二篇文章,后续将会持续更新,作为PyTorch系列文章.   本文将会介绍如何使用PyTorch来搭建简单的MLP(Multi-layer Perceptron,多层感 ...

  5. Pytorch深度学习入门与实战一--全连接神经网络

    全连接神经网络在分类和回归问题中都非常有效,本节对全连接神经网及神经元结构进行介绍,并使用Pytorch针对分类和回归问题分别介绍全连接神经网络的搭建,训练及可视化相关方面的程序实现. 1.全连接神经 ...

  6. 第五章 全连接神经网络

    第五章 全连接神经网络 1.1 全连接神经网络 人工神经网络(Artificial Neural Network)可以对一组输入信号和一组输出信号之间的关系进行建模,是机器学习和认知科学中的一种模仿生 ...

  7. 深度学习(从零开始)

    1.线性回归 y=wx+b 2.学习网站:python 入门基础 python入门的120个基础练习(一) - 知乎 sklearn官方   https://sklearn.apachecn.org/ ...

  8. pytorch学习笔记(十):MLP

    文章目录 1. 隐藏层 2. 激活函数 2.1 ReLU函数 2.2 sigmoid函数 2.3 tanh函数 3 多层感知机 4. 代码实现MLP 4.1 获取和读取数据 4.2 定义模型参数 4. ...

  9. Pytorch:全连接神经网络-MLP回归

    Pytorch: 全连接神经网络-解决 Boston 房价回归问题 Copyright: Jingmin Wei, Pattern Recognition and Intelligent System ...

最新文章

  1. 上海交大发布 MedMNIST 医学图像分析数据集 新基准
  2. html渐变不兼容,CSS3实现文字渐变效果,兼容性最强系列!
  3. 初中数学分几个模块_【初中数学】8大模块61个必考易错知识点!
  4. c++排序函数对二维数组排序_JS骚操作之数组快速排序
  5. 数据库计划中的14个才略
  6. WildFly 8.2.0.Final版本–更改的快速概述
  7. 计算机网络课程计划,计算机网络教学计划2017
  8. Android实现3D旋转效果
  9. java符号引用 直接引用_Java -- JVM的符号引用和直接引用
  10. yii2 html 跳转,阐述在Yii2上实现跳转提示页
  11. html通用的排班方法,呼叫中心排班的两种主要方法
  12. 利用锁机制解决商品表和库存表并发问题
  13. 一文详解YOLOX算法实现血细胞检测
  14. Kafka 设计与原理详解(一)
  15. 《图解算法》第10章之 k最近邻算法
  16. 如何快速的开发直播App
  17. 遭遇Trojan.PSW.ZhengTu,Trojan.PSW.OnlineGames,Trojan.PSW.ZhuXian.b等
  18. MATLAB求解一阶RC电路和二阶RLC电路
  19. 售后服务场景智能调度解决方案
  20. layui使用模板渲染数据

热门文章

  1. Java内存泄露原因详解
  2. 华为突然宣布,对物联网下手了!
  3. 一句话输出没有结束符的字符串
  4. vmware虚拟机中ubuntu上网问题
  5. 51单片机——DS18B20
  6. Orange-Classification,Regression
  7. 中缀表达式转换为前缀及后缀表达式并求值【摘】
  8. mysql索引使增删变慢_mysql优化(四)–索引
  9. ipconfig不是内部或外部_晶振有什么作用,如何选择合适的晶振,为什么有时候用内部晶振?...
  10. Joi验证模块的使用