名字叫做回归实际是个多分类算法

在一个全连接层之后连接softmax得到属于每个类的概率,softmax就是自己除以所有的和,使得所有项加起来等于1

使用是FashionMNIST数据集

import torch
from torch import nn
import torchvision
from torch.utils import data
from torchvision import transforms
from numpy import mean''' 1.下载数据集 '''
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=False)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=False)
# for i in mnist_train和mnist_test中  i是一个图片和标签的元组batch_size = 256
# num_workers 用几个线程来读数据
train_iter = data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_iter = data.DataLoader(mnist_test, batch_size=batch_size, shuffle=True)''' 2.定义模型 '''
# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)############### *****特别注意******** ###############
#PyTorch的Cross Entropy Loss与其它框架的不同,因为PyTorch中该损失函数其实自带了“nn.LogSoftmax”与“nn.NLLLoss”两个方法。
# 因此,在PyTorch的Cross Entropy Loss之前请勿再使用Softmax方法!
# loss = nn.CrossEntropyLoss(reduction='none')
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
num_epochs = 10''' 3.训练模型 '''
for i in range(num_epochs):L1 = []for X, y in train_iter:l = loss(net(X), y)L1.append(l.item())trainer.zero_grad()l.backward()trainer.step()L2 = []for X, y in test_iter:l = loss(net(X), y)L2.append(l.item())print(f'epoch: {i+1}, 训练集loss: {mean(L1)}  ,测试集loss: {mean(L2)}')

开始时在 nn.Linear 之后加了一个 nn.softmax ,得到的损失奇奇怪怪,查了一下才知道:PyTorch的Cross Entropy Loss与其它框架的不同,因为PyTorch中该损失函数其实自带了“nn.LogSoftmax”与“nn.NLLLoss”两个方法。 因此,在PyTorch的Cross Entropy Loss之前请勿再使用Softmax方法!

训练结果:

2 softmax回归实现相关推荐

  1. Softmax 回归 vs. k 个二元分类器

    如果你在开发一个音乐分类的应用,需要对k种类型的音乐进行识别,那么是选择使用 softmax 分类器呢,还是使用 logistic 回归算法建立 k 个独立的二元分类器呢? 这一选择取决于你的类别之间 ...

  2. 【深度学习】基于Pytorch的softmax回归问题辨析和应用(一)

    [深度学习]基于Pytorch的softmax回归问题辨析和应用(一) 文章目录 1 概述 2 网络结构 3 softmax运算 4 仿射变换 5 对数似然 6 图像分类数据集 7 数据预处理 8 总 ...

  3. 【深度学习】基于Pytorch的softmax回归问题辨析和应用(二)

    [深度学习]基于Pytorch的softmax回归问题辨析和应用(二) 文章目录1 softmax回归的实现1.1 初始化模型参数1.2 Softmax的实现1.3 优化器1.4 训练 2 多分类问题 ...

  4. Softmax回归——logistic回归模型在多分类问题上的推广

    Softmax回归 Contents [hide] 1 简介 2 代价函数 3 Softmax回归模型参数化的特点 4 权重衰减 5 Softmax回归与Logistic 回归的关系 6 Softma ...

  5. DeepLearning tutorial(1)Softmax回归原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43157801 DeepLearning tutorial(1)Softmax回归原理简介 ...

  6. Logistic and Softmax Regression (逻辑回归和Softmax回归)

    1. 简介 逻辑回归和Softmax回归是两个基础的分类模型,虽然听名字以为是回归模型,但实际我觉得他们也有一定的关系.逻辑回归,Softmax回归以及线性回归都是基于线性模型,它们固定的非线性的基函 ...

  7. 简单探索MNIST(Softmax回归和两层CNN)-Tensorflow学习

    简述 这次是在看<21个项目玩转深度学习>那本书的第一章节后做的笔记. 这段时间,打算把TensorFlow再补补,提升一下技术水平~ 希望我能坚持下来,抽空把这本书刷下来吧~ 导入数据 ...

  8. 2.3.3 Softmax回归介绍

    Softmax回归 到现在我们的分类问题都只能识别0和1的问题,比如说是猫或者不是猫,那么更复杂的问题怎么办呢?Softmax回归就能让你在多分类中识别是哪一个分类,而不只是二分类中识别. 如图所示, ...

  9. 3.8 Softmax 回归-深度学习第二课《改善深层神经网络》-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 3.7 测试时的 Batch Norm 回到目录 3.9 训练一个 Softmax 分类器 Softmax 回归 (Softmax Regression) 到目前为止,我们讲 ...

  10. 添加softmax层_PyTorch入门之100行代码实现softmax回归分类

    本文首发于公众号[拇指笔记] 1. 使用pytorch实现softmax回归模型 使用pytorch可以更加便利的实现softmax回归模型. 1.1 获取和读取数据 读取小批量数据的方法: 首先是获 ...

最新文章

  1. 安卓手机兼容_重磅:鸿蒙OS2.0手机开发者Beta版发布,能兼容安卓
  2. elasticsearch 结构化搜索_在案例中实战基于range filter来进行范围过滤
  3. java线程池有哪几种,真香系列
  4. jquery实现截取pc图片_如何优雅的对网页截取长图
  5. bfs:01迷宫(洛谷P1141)
  6. 控件注册 - 利用资源文件将dll、ocx打包进exe文件(转)
  7. 如何用多线程方式,提高rabbitmq消息处理效率?
  8. 从零打造一个程序员的mac
  9. STM32 AES 加解密流程梳理
  10. 通过Ajax进行POST提交JSON类型的数据到SpringMVC Controller的方法
  11. 一次性奖励300万?成都市武侯区促进文化产业发展系列政策影视产业专项政策出来了
  12. 学习计算机基础必读的4本经典入门书籍,自学编程必备书单!
  13. You have not concluded your merge (MERGE_HEAD exists).
  14. 面试时被问“你的缺点是什么?”,这么答就对了
  15. CCS_3.3.83.20的安装步骤
  16. 【web自动化测试】
  17. java编写活期储蓄帐目管理_活期储蓄账目管理系统
  18. 图像处理方向常用网站
  19. 远程命令/代码执行漏洞(RCE)总结
  20. 剑指offer--20.顺时针打印矩阵

热门文章

  1. 数学建模国赛:python机器学习基础之数据归一化、去除空值
  2. 23、数据结构中的绝代双骄(2)链表
  3. 那么如何让你的 JS 写得更漂亮?
  4. python 数据分析(六)astype('category')按类别分组 + 分组聚合操作 + 透视表 + 交叉表 + excel表的数据处理
  5. mvc 当中 [ValidateAntiForgeryToken] 的作用及用法
  6. UNIAPP实战项目笔记45 订单页面布局完成和数据渲染
  7. 【优化算法】多目标蝗虫优化算法(MOGOA)
  8. 物联网通信技术|课堂笔记week2-1|Linux网络管理基本命令
  9. open函数的参数说明
  10. 文本分类的特征提取算法