【PyTorch深度学习实践】08_Softmax分类器(多分类)
文章目录
- 1.Softmax层
- 1.1softmax的函数表示
- 1.2 损失函数
- 2. 代码实现
1.Softmax层
当需要多分类的时候,会输出一个分布,这些分布需要满足P(y = i) >=0 和 所有的P值加起来=1
,使用softmax可以实现。
要注意的是,softmax本质上和sigmoid一样也是一个激活函数。
sigmoid用于二分类,softmax用于多分类。
1.1softmax的函数表示
示例
1.2 损失函数
关于代码中的ToTensor
2. 代码实现
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optimbatch_size = 64
transform = transforms.Compose([transforms.ToTensor(), # 转为张量# 归一化,切换到01分布进行训练(神经网络更适用),两个值分别是均值和方差,用于进行分布转换transforms.Normalize((0.1307, ), (0.3081, )) # 注意这里说的不是取值范围,而是以0为均值,1为标准差的分布
])train_dataset = datasets.MNIST(root='/Users/yahoo/Downloads',train=True,download=False,transform=transform
)train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size
)test_dataset = datasets.MNIST(root='/Users/yahoo/Downloads',train=True,download=False,transform=transform
)test_loader = DataLoader(test_dataset,shuffle=True,batch_size=batch_size
)class Net(torch.nn.Module):def __init__(self):super().__init__()self.l1 = torch.nn.Linear(784, 512) # 每层都是全连接self.l2 = torch.nn.Linear(512, 256)self.l3 = torch.nn.Linear(256, 128)self.l4 = torch.nn.Linear(128, 64)self.l5 = torch.nn.Linear(64, 10)def forward(self,x):x = x.view(-1,784)x = F.relu(self.l1(x))x = F.relu(self.l2(x))x = F.relu(self.l3(x))x = F.relu(self.l4(x))return self.l5(x) # 最后一层不激活model = Net()criterion = torch.nn.CrossEntropyLoss() # 打包好的交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) # 带冲量效果更好,可以冲破局部最小值,尽可能找到全局最优解def train(epoch):running_loss = 0.0for batch_idx,data in enumerate(train_loader, 0):inputs, target = dataoptimizer.zero_grad() # 每轮先置0# forward + backward + updateoutputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx+1, running_loss / 300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad(): # 表明不用计算梯度for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, dim=1) # 横向为第一个维度,意指从每行中找出最大值及其下标total += labels.size(0)correct += (predicted == labels).sum().item() # 比较下标与预测值的结果是否接近,求和即是看我们猜对了多少个print('Accuracy on test set: %d %%' % (100 * correct / total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()
输出结果
[1, 300] loss: 2.213
[1, 600] loss: 0.892
[1, 900] loss: 0.446
Accuracy on test set: 89 %
[2, 300] loss: 0.313
[2, 600] loss: 0.268
[2, 900] loss: 0.224
Accuracy on test set: 94 %
···························
[9, 300] loss: 0.044
[9, 600] loss: 0.040
[9, 900] loss: 0.043
Accuracy on test set: 99 %
[10, 300] loss: 0.035
[10, 600] loss: 0.033
[10, 900] loss: 0.033
Accuracy on test set: 99 %
【PyTorch深度学习实践】08_Softmax分类器(多分类)相关推荐
- PyTorch深度学习实践概论笔记9-SoftMax分类器
上一讲PyTorch深度学习实践概论笔记8-加载数据集中,主要介绍了Dataset 和 DataLoader是加载数据的两个工具类.这一讲介绍多分类问题如何解决,一般会用到SoftMax分类器. 0 ...
- 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)
从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...
- 《PyTorch深度学习实践》
[<PyTorch深度学习实践>完结合集] https://www.bilibili.com/video/BV1Y7411d7Ys/?share_source=copy_web&v ...
- 【刘二大人】PyTorch深度学习实践
文章目录 一.overview 1 机器学习 二.Linear_Model(线性模型) 1 例子引入 三.Gradient_Descent(梯度下降法) 1 梯度下降 2 梯度下降与随机梯度下降(SG ...
- 【Pytorch深度学习实践】B站up刘二大人之BasicCNN Advanced CNN -代码理解与实现(9/9)
这是刘二大人系列课程笔记的 最后一个笔记了,介绍的是 BasicCNN 和 AdvancedCNN ,我做图像,所以后面的RNN我可能暂时不会花时间去了解了: 写在前面: 本节把基础个高级CNN放在一 ...
- PyTorch深度学习实践
根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...
- 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)
这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...
- 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)
文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...
- 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归
刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...
- 【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)
<PyTorch深度学习实践>-刘二大人 Otto Group Product Classification作业 将商品进行十分类,输入为93个特征10个类别的商品数据集,输出为预测数据集 ...
最新文章
- [转]Android横竖屏切换解决方案
- 深浅复制的的理解与区别
- c语言程序怎么颠倒数据,急求如何将下列C语言程序数据存储到文件中?
- CodeIgniter模型
- Spring抽取jdbc配置文件
- 一个好用的Visual Studio Code扩展 - Live Server,适用于前端小工具开发
- layui删除后刷新表格_LayUi前端框架删除数据缓存问题(解决删除后刷新页面内容又会显示问题)...
- Lucene.net 下载地址
- strong vs copy
- T-SQL命令性能比较– NOT IN与SQL NOT EXISTS与SQL LEFT JOIN与SQL EXCEPT
- FFmpeg之wav转mp3(二十四)
- jquery监听html状态,jquery监听页面刷新
- linux系统怎样挂载虚拟盘,linux 应用盘(从盘)挂载方法linux操作系统 -电脑资料...
- PowerDesigner(CDM—PDM—SQL脚本的转换流程) 随笔
- 梦飞苍穹c语言答案,梦飞仙途-楔子一 决战苍穹之巅-汤圆创作
- 更改WPS云文档保存位置
- 【Python】Matplotlib画图(七)——线的颜色、点的形状
- 我的世界刷猪人塔java版_我的世界速攻猪人塔详解 史上最牛的经验塔
- js 时间转东八区_js:固定与东八区服务器时间保持一致并且可选时间格式
- stm32f103c8t6的中文字库
热门文章
- 主要的CMS(内容管理系统)提供商
- MySQL全版本安装步骤
- 计算机找不到def,我打开计算机,发现缺少def驱动器. C盘发生了什么?如何解决def驱动器消失的问题?...
- 虚拟机ubuntu占用CPU过高
- 刚下载的xshell无法使用解决办法--实用--绝对靠谱
- 中小企业服务器虚拟化部署方案:规划备份和灾难恢复
- ffmpeg录制系统声音
- PID控制参数整定(调节方法)原理+图示+MATLAB调试
- CSR8670/8675 发射(TX SOURCE)一拖二 编码 格式APTX APTXLL APTXHD SBC
- 工具分享之截图软件Snipaste