pytorch 实现多层感知机,主要使用torch.nn.Linear(in_features,out_features),因为torch.nn.Linear是全连接的层,就代表MLP的全连接层

本文实例MNIST数据,输入层28×28=784个节点,2个隐含层,隐含层各100个,输出层10个节点

开发平台,windows 7平台,python 3.8.5,anaconda3 ,torch版本1.8.1+cpu

#minist 用MLP实现,MLP也是使用pytorch实现的
import torchvision
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.optim as optim
import timetransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])#下载数据集
data_train = datasets.MNIST(root = "..//data//",transform=transform,train = True,download = True)data_test = datasets.MNIST(root="..//data//",transform = transform,train = False,download = True)
#装载数据
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size = 64,shuffle = True)data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size = 64,shuffle = True)num_i=28*28 #输入层节点数
num_h=100   #隐含层节点数
num_o=10    #输出层节点数
batch_size=64class Model(torch.nn.Module):def __init__(self,num_i,num_h,num_o):super(Model,self).__init__()self.linear1=torch.nn.Linear(num_i,num_h)self.relu=torch.nn.ReLU()self.linear2=torch.nn.Linear(num_h,num_h) #2个隐层self.relu2=torch.nn.ReLU()self.linear3=torch.nn.Linear(num_h,num_o)def forward(self, x):x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.relu2(x)x = self.linear3(x)return xmodel=Model(num_i,num_h,num_o)
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
epochs = 5
for epoch in range(epochs) :sum_loss=0train_correct=0for data in data_loader_train:inputs,labels=data #inputs 维度:[64,1,28,28]#     print(inputs.shape)inputs=torch.flatten(inputs,start_dim=1) #展平数据,转化为[64,784]#     print(inputs.shape)outputs=model(inputs)optimizer.zero_grad()loss=cost(outputs,labels)loss.backward()optimizer.step()_,id=torch.max(outputs.data,1)sum_loss+=loss.datatrain_correct+=torch.sum(id==labels.data)print('[%d,%d] loss:%.03f' % (epoch + 1, epochs, sum_loss / len(data_loader_train)))print('        correct:%.03f%%' % (100 * train_correct / len(data_train)))print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
model.eval()
test_correct = 0
for data in data_loader_test :inputs, lables = datainputs, lables = Variable(inputs).cpu(), Variable(lables).cpu()inputs=torch.flatten(inputs,start_dim=1) #展并数据outputs = model(inputs)_, id = torch.max(outputs.data, 1)test_correct += torch.sum(id == lables.data)
print("correct:%.3f%%" % (100 * test_correct / len(data_test )))

输出结果如下:

[1,5] loss:0.213correct:93.482%
2021-12-08 18:24:49
[2,5] loss:0.134correct:95.898%
2021-12-08 18:25:01
[3,5] loss:0.102correct:96.818%
2021-12-08 18:25:13
[4,5] loss:0.084correct:97.323%
2021-12-08 18:25:24
[5,5] loss:0.071correct:97.750%
2021-12-08 18:25:40
correct:96.820%

其中使用2个隐含层,torch.flatten(inputs,start_dim=1) 是将维度[64,1,28×28]展平为维度[64,784]的数据,便于训练。当然这个MNIST的分类也可以用scklearn的MLPRegressor实现。

参考资料:

1 Pytorch 学习(五):Pytorch 实现多层感知机(MLP)_RememberUrHeart的博客-CSDN博客_多层感知机pytorch

pytorch 实现MLP(多层感知机)相关推荐

  1. TensorFlow基础之模型建立与训练:线性回归、MLP多层感知机、卷积神经网络

    TensorFlow基础之模型建立与训练 模型建立与训练:简单的线性回归 MLP多层感知机 数据获取.预处理 模型搭建 训练与评估 卷积神经网络 高效建模 Keras Sequential高效建模 F ...

  2. 使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负

    使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负 1. 数据集 百度网盘链接,提取码:q79p 数据集文件格式为CSV.数据集包含了大约5万场英雄联盟钻石排位赛前15分钟的数据集合,总 ...

  3. MLP多层感知机 学习笔记

    cvpr2022的 mobileformer中用到了mlp多层感知机,就来学习一下 其实就是3个全连接层,前面两个加了bn,最后一层没有加bn. import timeimport torch fro ...

  4. 机器学习 | MATLAB实现MLP多层感知机newff参数设定(上)

    机器学习 | MATLAB实现MLP多层感知机newff参数设定(上) 目录 机器学习 | MATLAB实现MLP多层感知机newff参数设定(上) 基本介绍 程序设计 参考资料 基本介绍 newff ...

  5. 机器学习 | MATLAB实现MLP多层感知机newff参数设定(下)

    机器学习 | MATLAB实现MLP多层感知机newff参数设定(下) 目录 机器学习 | MATLAB实现MLP多层感知机newff参数设定(下) 基本介绍 程序设计 参考资料 基本介绍 newff ...

  6. 机器学习 | MATLAB实现MLP多层感知机模型设计

    机器学习 | MATLAB实现MLP多层感知机模型设计 目录 机器学习 | MATLAB实现MLP多层感知机模型设计 基本介绍 模型描述 模型设计 程序设计 学习总结 参考资料 基本介绍 多层感知器( ...

  7. 动手学深度学习(PyTorch实现)(五)--多层感知机

    多层感知机 1. 基本知识 2. 激活函数 2.1 ReLU函数 2.2 Sigmoid函数 2.3 tanh函数 2.4 关于激活函数的选择 3. PyTorch实现 3.1 导入相应的包 3.2 ...

  8. DeepLearning tutorial(3)MLP多层感知机原理简介+代码详解

    FROM:http://blog.csdn.net/u012162613/article/details/43221829 @author:wepon @blog:http://blog.csdn.n ...

  9. MLP多层感知机(人工神经网络)原理及代码实现

    一.多层感知机(MLP)原理简介 多层感知机(MLP,Multilayer Perceptron)也叫人工神经网络(ANN,Artificial Neural Network),除了输入输出层,它中间 ...

  10. 什么是深度学习?kears简介,深度学习常用的三大模型,MLP(多层感知机),CNN(卷积神经网络),RNN(循环神经网络)

    什么是深度学习? 简单理解深度学习就是人类容易做的事情,机器不容易完成的事情.(实例:人脸识别,这个例子很好的证明了这句话.假如你识别一个人 ,今天这个人长这个样子,明天脸上有一块伤口,我们人是不是还 ...

最新文章

  1. Q币才是腾讯真正的世界级产品
  2. OKR的本质是什么?目标如何制定?
  3. 2022年度BCI奖 |THE ANNUAL BCI AWARD
  4. 压缩感知(II) A Compressed Sense of Compressive Sensing (II)
  5. 清华大学计算机系毕业季博论 | 预荐未来的自己
  6. oracle 删除表中重复记录,并保留一条
  7. html的组织顺序是什么,css如何组织?
  8. avue下拉框中属性可以显示,但不能选中
  9. golang 锁的使用
  10. JS日期比较 2013-01-31大于2013-02-01
  11. Kubernetes中配置Pod的liveness和readiness探针
  12. 徐耀赐:道路安全——交通安全会议整理稿(1)
  13. 优化器-SQL语句分析与优化
  14. 我最欣赏的一句话:天道酬勤
  15. 漫画:什么是ConcurrentHashMap?
  16. 苹果ipa 安卓apk 和APPX 安全扫码和分析平台
  17. 转载 解密蓝牙mesh系列 | 第五篇 【好友(Friend)和低功耗节点(LPN)】【友谊(Friendship)参数】【友谊建立】【友谊(Friendship)消息传送】【安全性】【友谊终止】
  18. ms17010利用失败解决一则
  19. 后端质料springboot
  20. Java实现QQ窗口自动输入

热门文章

  1. windows查看端口号占用
  2. NSIS UI 美化类插件分享
  3. 前后台交互:跨域以及PHP与Ajax的配合使用
  4. python找数字程序_程序以查找Python中从1到N的所有缺失数字
  5. 利用python提取abaqus节点坐标的脚本_用于在Abaqus中提取结点力的Python程序
  6. 且用计算机语言怎么表示,用计算机语言表示算法.doc
  7. linux网络流量监测工具,linux下网络流量监控工具
  8. html等待图片全部加载,imgLoad等待图片资源加载完成后执行函数(图片预加载)...
  9. oracle 服务名丢失,win2003 oracle服务丢失后恢复的一个例子
  10. mysql第五章项目二_高性能MySQL笔记 第5章 创建高性能的索引