多分类问题softmax的分类器

为什么要探索多分类

之前我们在处理糖尿病数据集的时候我们只是有两种分类,但是很多情况的数据集不只有两种,例如MNIST数据集就是手写数字的数据集有10种不同的标签。所以我们必须有处理多种分类标签的能力。

探索多分类

是否还可以使用二分类的操作?

当然还是可以使用二分类的方法来解决这个问题,某分类设置位p=1其他全部p=0就可以了,还是使用交叉熵损失函数来处理。

这里我们要注意到,我们的样本必须是只有一个选择的,所以我们的输出数据当中只能有一个输出的数据比较大,要对其他形成抑制,或者描述为所有的输出的和必须是1。但是我们使用上面的方法并不能满足,甚至所有的分类的输出都是0.8或0.9这种。
如果我们这样做呢?将十种输出的最后一种转化为1减去其他输出,这样是否可行呢?实际上也并不行,因为我们如果这样处理将会导致十种输出的结果过程并不相同,导致系统的并行能力下降,使得整体的效率下降。

那该怎么实现?

所以我们的输出应当是一个分布,我们前面的一些层还是可以使用sigmoid来做一个层,最后我们要用一个特殊层来完成一个将原有输出转化位一个分布的操作。
我们看一下我们需要做到什么内容:
1.输出的内容和是一个1 。
2.输出的内容都在0-1之间。
我们可以想到第一个内容比较好实现我们只需要让输出是一个分数就行了。分母设置成一样的内容,分子的和等于这个分母。这是十分容易的。
之后我们再来分析第二个,想要让输出是0-1,在我们已经做到第一个要求的情况下,我们想做到第二个,其实只需要所有的输出全部都是大于0的就可以完成了。这样我们就想到了一个东西指数函数,指数函数的值永远是大于0的。这样问题就解决了。
所以我们只需要使用一个softmax层,这个层的运算如下:


这样就可以完成我们的需求。
我们来看一个实际的例子就更好理解了:

这样我们就理解了softmax层

那么损失函数又该怎么做

我们实际上还是需要使用交叉熵损失函数,我们看一下交叉熵损失函数,其实是什么情况,我们知道交叉熵损失函数是根据其中的概率进行计算的,因为我们是打标签,除了1的就全是0所以我们需要计算的内容其实可以简化:
所以其实loss函数就发生了变化:

整体的情况就变成了如下的情况:

这个其实就是NLL损失函数:

这个是理论上的情况,但是我们在pytorch当中其实并没有这么复杂,我们可以直接使用交叉熵损失函数来将从softmax直接一下子全部包含了。注意我们最后一层是不需要使用激活的,最后一层其实是交叉熵损失函数直接替我们完成了这个操作。

这里一个关键是一定要区分清楚NLL损失和交叉熵损失。

图像张量的问题

如果一个灰度图就是一个单通道的图,我们平时看到的彩色图像其实是三通道(Channel)的RGB,所以我们在表示图像张量的时候一般是whc就是(宽度*高 *通道)
为什么我们要将数据转化位N(0,1)因为这种标准分布的数据对于神经网络来说是最好的,对于神经网络来说训练效果是最好的

什么时候进行transform的问题

我们实际上是将transform定义在了datasets的位置。使用的下面的代码:

trian_set = datasets.MNIST(root='./data',train=True,download=True,transform=transform)

我们注意这样子的一个问题,这个MNIST数据集毕竟是一个图片,其实占地方不小的,所以我们不能直接将其读入内存,所以我们不能读进来再transform,需要每次从文件读进来一个之后再进行transform。

构建神经网络

因为我们使用的是全连接神经网络,所以我们最后输入的一个数据应当是一个矩阵(二维的)但是我们这里的输入实际上是一个四维的输入,数据量*通道数*宽*高,所以我们要转化一下,这里就需要使用view函数了,具体使用下面的函数

x= x.view(-1,784)

之后我们画一个图,写模型就好写了

代码实现

import numpy
import torch
from torchvision import transforms
from torchvision import datasets
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader#这里我们要理解这个from的作用,使用一个from#transeforms 主要是用来做图像处理
# optim主要是包含优化器的
#首先明白下面这个东西是个什么东西?
#这个东西是个转化器,可以对输入的内容做我们规定好的操作,
# 具体操作就是下面[]中定义的。
#这里注意一个细节transforms和transform的区别,带s的表示是一片,不带s的是一个,我们实例化的一个
batch_size=64
transform= transforms.Compose([#这个首先转化为一个张量transforms.ToTensor(),#为了更好的学习效果我们需要一个标准化的过程,前一个参数是均值后一个是方差# 这两个参数都是MNIST数据集使用的,如果是自己的数据集这个要算一下。transforms.Normalize((0.1307,),(0.3081,))
])
#这里其实和之前的一样的,只是加了一个transform=的参数
trian_set = datasets.MNIST(root='./data',train=True,download=True,transform=transform)
trian_loader= DataLoader(trian_set,shuffle=True,batch_size=batch_size)
test_set =datasets.MNIST(root='./data',train=False,download=True,transform=transform)
test_loader= DataLoader(test_set,shuffle=True,batch_size=batch_size)
class myModol(torch.nn.Module):def __init__(self):super(myModol,self).__init__()self.l1=torch.nn.Linear(784,512)self.l2=torch.nn.Linear(512,256)self.l3=torch.nn.Linear(256,128)self.l4=torch.nn.Linear(128,64)self.l5=torch.nn.Linear(64,10)def forward(self,x):#前面的-1代表着看着情况进行变化。#这里我们注意我们的这里输入的784不是随便输入的,# 是有实际意义的,MNIST的数据集的输入是1*28*28的,# 所以我们想要转化为二维的时候,一定要尊重原有的实际意义x=x.view(-1,784)x=F.relu(self.l1(x))x=F.relu(self.l2(x))x=F.relu(self.l3(x))x=F.relu(self.l4(x))return self.l5(x)#因为我们最后使用的是交叉熵损失函数已经将激活层抱进去了不用再单独写了
model=myModol()
#使用一个交叉熵损失函数
criterion=torch.nn.CrossEntropyLoss()
#因为我们计算量已经比较大了所以要使用一个带有冲量的优化器。
opminster= optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
def train(epoch):train_loss=0.0for batch_idx,data in enumerate(trian_loader,0):m_input,target=dataopminster.zero_grad()outputs= model(m_input)loss=criterion(outputs,target)loss.backward()opminster.step()train_loss+=loss.item()if batch_idx%300==299:#每300次才输出一次,另外因为我们都是从0开始计数所以我们需要+1print('[%d,%5d]loss:%.3f'%(epoch+1,batch_idx+1,train_loss/300))train_loss=0
def test():#定义几个计数变量correct=0total=0#因为我们在计算损失的时候是不需要进行梯度计算的,# 所以这里取消梯度计算来增加速度。with torch.no_grad():for data in test_loader:images,targets=dataoutputs=model(images)#这里我们注意我们在输出的时候是一个最大值,一个最大值角标,我们只需要最大值角标,# 这里我们注意dim=1的这个问题,dim=0代表每列找一个,dim=1代表每行找。_,predicted=torch.max(outputs.data,dim=1)#这里我们注意一个问题就是返回的形式是torch.Size,# 这个玩意是一个元组,这里是(行数,列数)所以我们要取出来第0个。total+=targets.size(0)#我们注意这里张量比较的使用问题correct+=(predicted==targets).sum().item()print('acc on test set:%d%%'%(100*correct/total))
if __name__=='__main__':for epoch in range(10):train(epoch)test()
#但是我们这里的全连接层得到的准确度并不高,因为我们是将图上的所有信息都全部利用了
#其实决定图片的数字到底是多少并不是由全部的情况决定的。

pytorch的多分类问题相关推荐

  1. LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数

    在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...

  2. pytorch实现文本分类_使用变形金刚进行文本分类(Pytorch实现)

    pytorch实现文本分类 'Attention Is All You Need' "注意力就是你所需要的" New deep learning models are introd ...

  3. Pytorch搭建常见分类网络模型------VGG、Googlenet、ResNet50 、MobileNetV2(4)

    接上一节内容:Pytorch搭建常见分类网络模型------VGG.Googlenet.ResNet50 .MobileNetV2(3)_一只小小的土拨鼠的博客-CSDN博客 mobilenet系列: ...

  4. pytorch对MNIST分类

    深度学习 基础知识和各种网络结构实战 ... pytorch对MNIST分类 深度学习 前言 一.导入第三方库 二.下载MNIST数据集 三.创建神经网络模型 四.训练数据集 五.测试 完整代码 总结 ...

  5. Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练

    pytorch进行CIFAR-10分类(4)训练 我的系列博文: Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理 Pytorch打怪路(一)pyt ...

  6. pytorch bert文本分类_一起读Bert文本分类代码 (pytorch篇 四)

    Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了.这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码exa ...

  7. pytorch 三维点分类_用于RGBD语义分割的三维图神经网络(2017ICCV,已开源)

    3D Graph Neural Networks for RGBD Semantic Segmentation(2017ICCV, citation:78) 开源地址:https://github.c ...

  8. pytorch focalloss多分类 单分类

    代码还没测: focal_loss 多类别和二分类 Pytorch代码实现 本文链接:https://blog.csdn.net/qq_33278884/article/details/9157217 ...

  9. pytorch 三维点分类_三维点云分类与分割-PointNet

    PointNet是对点云数据直接进行学习的开山之作, 这里结合PointNet-Pytorch代码,对PointNet网络结构与其思想进行阐述和分析. 点云数据的特性: 点云数据不同于图像数据,他有三 ...

最新文章

  1. MPB:北大口腔陈峰、陈智滨等-口腔常见微生物的培养方法
  2. 使用PyTorch训练图像分类器
  3. reactjs redux chrome扩展插件
  4. 排序算法——各算法性能
  5. php select socket
  6. std::string格式化输入输出
  7. 用Python标准库turtle画一头金牛,祝您新年牛气冲天!
  8. 卷积神经网络 – CNN
  9. 【学习笔记】Unreal(虚幻)4引擎入门(三)
  10. 基于YOLOv5的血细胞识别和计数
  11. GNOME-Shell-Extensions开发经验(一)Hello,world!
  12. 在网页右下角添加一个卡通动漫人物
  13. 7-7 选民投票 (20分)(不区分大小写投票)
  14. ANSYS公开课圆满落幕
  15. 人类简史-读书笔记之历史演变图
  16. Python自动抢红包,从此再也不会错过微信红包了!
  17. 【译】Unity3D Shader 新手教程(1/6)
  18. APP 跳转微信小程序和回调
  19. 展望未来「编程之路起始篇」
  20. android studio重装后直接,【原创】重装Windows系统后Android studio无需重装,直接迁移...

热门文章

  1. 洛谷——P2256 一中校运会之百米跑
  2. 简单用数组模拟顺序栈(c++)
  3. Java数据结构——有序链表
  4. jquery技巧(持续更新。。)
  5. Lab_2 OSPF
  6. 主动触发被动模式从而挟持无线客户端 – Passive Karma Attack
  7. 计算机组成原理第3章-存储系统
  8. 删除win10自带的旧版edge浏览器(亲测有效)
  9. 用ajax更新div,如何使用ajax和jquery更新特定的div
  10. php删除记录前的判断弹窗,thinkPHP删除前弹出确认框的简单实现方法