文章目录

  • 1.Softmax层
    • 1.1softmax的函数表示
    • 1.2 损失函数
  • 2. 代码实现

1.Softmax层

当需要多分类的时候,会输出一个分布,这些分布需要满足P(y = i) >=0 和 所有的P值加起来=1,使用softmax可以实现。

要注意的是,softmax本质上和sigmoid一样也是一个激活函数。
sigmoid用于二分类,softmax用于多分类。

1.1softmax的函数表示


示例

1.2 损失函数



关于代码中的ToTensor

2. 代码实现

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optimbatch_size = 64
transform = transforms.Compose([transforms.ToTensor(),  # 转为张量# 归一化,切换到01分布进行训练(神经网络更适用),两个值分别是均值和方差,用于进行分布转换transforms.Normalize((0.1307, ), (0.3081, ))  # 注意这里说的不是取值范围,而是以0为均值,1为标准差的分布
])train_dataset = datasets.MNIST(root='/Users/yahoo/Downloads',train=True,download=False,transform=transform
)train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size
)test_dataset = datasets.MNIST(root='/Users/yahoo/Downloads',train=True,download=False,transform=transform
)test_loader = DataLoader(test_dataset,shuffle=True,batch_size=batch_size
)class Net(torch.nn.Module):def __init__(self):super().__init__()self.l1 = torch.nn.Linear(784, 512)  # 每层都是全连接self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self,x):x = x.view(-1,784)x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x)        # 最后一层不激活model = Net()criterion = torch.nn.CrossEntropyLoss()  # 打包好的交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) # 带冲量效果更好,可以冲破局部最小值,尽可能找到全局最优解def train(epoch):running_loss = 0.0for batch_idx,data in enumerate(train_loader, 0):inputs, target = dataoptimizer.zero_grad() # 每轮先置0# forward + backward + updateoutputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx+1, running_loss / 300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():  # 表明不用计算梯度for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1)  # 横向为第一个维度,意指从每行中找出最大值及其下标total += labels.size(0)correct += (predicted == labels).sum().item()  # 比较下标与预测值的结果是否接近,求和即是看我们猜对了多少个print('Accuracy on test set: %d %%' % (100 * correct / total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()

输出结果

[1,   300] loss: 2.213
[1,   600] loss: 0.892
[1,   900] loss: 0.446
Accuracy on test set: 89 %
[2,   300] loss: 0.313
[2,   600] loss: 0.268
[2,   900] loss: 0.224
Accuracy on test set: 94 %
···························
[9,   300] loss: 0.044
[9,   600] loss: 0.040
[9,   900] loss: 0.043
Accuracy on test set: 99 %
[10,   300] loss: 0.035
[10,   600] loss: 0.033
[10,   900] loss: 0.033
Accuracy on test set: 99 %

【PyTorch深度学习实践】08_Softmax分类器(多分类)相关推荐

  1. PyTorch深度学习实践概论笔记9-SoftMax分类器

    上一讲PyTorch深度学习实践概论笔记8-加载数据集中,主要介绍了Dataset 和 DataLoader是加载数据的两个工具类.这一讲介绍多分类问题如何解决,一般会用到SoftMax分类器. 0 ...

  2. 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)

    从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...

  3. 《PyTorch深度学习实践》

    [<PyTorch深度学习实践>完结合集] https://www.bilibili.com/video/BV1Y7411d7Ys/?share_source=copy_web&v ...

  4. 【刘二大人】PyTorch深度学习实践

    文章目录 一.overview 1 机器学习 二.Linear_Model(线性模型) 1 例子引入 三.Gradient_Descent(梯度下降法) 1 梯度下降 2 梯度下降与随机梯度下降(SG ...

  5. 【Pytorch深度学习实践】B站up刘二大人之BasicCNN Advanced CNN -代码理解与实现(9/9)

    这是刘二大人系列课程笔记的 最后一个笔记了,介绍的是 BasicCNN 和 AdvancedCNN ,我做图像,所以后面的RNN我可能暂时不会花时间去了解了: 写在前面: 本节把基础个高级CNN放在一 ...

  6. PyTorch深度学习实践

    根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...

  7. 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)

    这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...

  8. 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)

    文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...

  9. 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归

    刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...

  10. 【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)

    <PyTorch深度学习实践>-刘二大人 Otto Group Product Classification作业 将商品进行十分类,输入为93个特征10个类别的商品数据集,输出为预测数据集 ...

最新文章

  1. [转]Android横竖屏切换解决方案
  2. 深浅复制的的理解与区别
  3. c语言程序怎么颠倒数据,急求如何将下列C语言程序数据存储到文件中?
  4. CodeIgniter模型
  5. Spring抽取jdbc配置文件
  6. 一个好用的Visual Studio Code扩展 - Live Server,适用于前端小工具开发
  7. layui删除后刷新表格_LayUi前端框架删除数据缓存问题(解决删除后刷新页面内容又会显示问题)...
  8. Lucene.net 下载地址
  9. strong vs copy
  10. T-SQL命令性能比较– NOT IN与SQL NOT EXISTS与SQL LEFT JOIN与SQL EXCEPT
  11. FFmpeg之wav转mp3(二十四)
  12. jquery监听html状态,jquery监听页面刷新
  13. linux系统怎样挂载虚拟盘,linux 应用盘(从盘)挂载方法linux操作系统 -电脑资料...
  14. PowerDesigner(CDM—PDM—SQL脚本的转换流程) 随笔
  15. 梦飞苍穹c语言答案,梦飞仙途-楔子一  决战苍穹之巅-汤圆创作
  16. 更改WPS云文档保存位置
  17. 【Python】Matplotlib画图(七)——线的颜色、点的形状
  18. 我的世界刷猪人塔java版_我的世界速攻猪人塔详解 史上最牛的经验塔
  19. js 时间转东八区_js:固定与东八区服务器时间保持一致并且可选时间格式
  20. stm32f103c8t6的中文字库

热门文章

  1. 主要的CMS(内容管理系统)提供商
  2. MySQL全版本安装步骤
  3. 计算机找不到def,我打开计算机,发现缺少def驱动器. C盘发生了什么?如何解决def驱动器消失的问题?...
  4. 虚拟机ubuntu占用CPU过高
  5. 刚下载的xshell无法使用解决办法--实用--绝对靠谱
  6. 中小企业服务器虚拟化部署方案:规划备份和灾难恢复
  7. ffmpeg录制系统声音
  8. PID控制参数整定(调节方法)原理+图示+MATLAB调试
  9. CSR8670/8675 发射(TX SOURCE)一拖二 编码 格式APTX APTXLL APTXHD SBC
  10. 工具分享之截图软件Snipaste