分类-Classification

  • 示例代码

两大基础任务之一的分类任务,根据 datatargets进行训练,之后对目标进行预判。
依旧是一样,根据视频来编写的代码,会使你看视频更加容易,都有注释,哈哈哈哈哈

示例代码

import torch as t
import torch.nn.functional as tnf
import matplotlib.pyplot as plt# 生成一个全是1的二维tensor,每个维度参数为(x,y)
n_data = t.ones(100, 2)
# 参数含义为(均值,标准差),根据这个正态分布随机生成和上面相同类型的tensor
x0 = t.normal(2*n_data, 1)
# 生成一个全是0的1维tensor,维度参数为100
y0 = t.zeros(100)
x1 = t.normal(-2*n_data, 1)
y1 = t.ones(100)
# 根据dim参数将两个tensor拼接起来,这里的话就是将两个x——tensor在一维进行拼接(200,2)
x = t.cat((x0, x1), 0).type(t.FloatTensor)
# y是一维tensor,所以直接拼接就好了
y = t.cat((y0, y1), ).type(t.LongTensor)# x_tensor是一个二维张量,这里转化为numpy二维数组,取一维所有数据,二维中的第一个量为x轴,第二个量为y轴
# plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
# plt.show()# unsqueeze大概可以理解为将一维tensor转化为二维tensor
# 最主要的是如果听信弹幕屁话,这里不用这个东西,直接把一维tensor代入就会报错mismatch
# x = t.unsqueeze(t.linspace(-1, 1, 100), dim=1)
# y = x.pow(2) + 0.2 * t.rand(x.size())# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()class Net(t.nn.Module):def __init__(self, n_feature, n_hidden, n_output):# 继承父类初始方法super().__init__()# 创建一层hidden神经网络,向后接收n_features个输入,由n_hidden个神经元组成self.hidden = t.nn.Linear(n_feature, n_hidden)# 创建一层predict神经网络,向后接收n_hidden个输入,由n_output个神经元组成self.predict = t.nn.Linear(n_hidden, n_output)# 需要注意的是,这里只是定义,并不代表搭建,并没有构建好连接。def forward(self, x):# x为输入的data# 神经元之间的传递是线性的,非线性是由激励函数实现的# 让数据先进入hidden层进行处理,然后使用激励函数relu进行非线性处理x = tnf.relu(self.hidden(x))# 再让其进入predict层进行处理x = self.predict(x)return x# 运用多分类,这次输入是两个特征,输出为二分类,[0,1]和[1,0]
net = Net(2, 10, 2)plt.ion()
plt.show()# 使用optimizer来进行优化,一般选择SGD(随机梯度下降)
# 参数理解分别为(神经网络里面所有参数,学习效率)
# 学习效率一般小于1,差不多可以理解为梯度下降的步长,大了就会错失一些细节点
optimizer = t.optim.SGD(net.parameters(), lr=0.02)
# 计算误差,以下为均方差的方法
# loss_func = t.nn.MSELoss()
loss_func = t.nn.CrossEntropyLoss()for i in range(500):# 预测----计算丢失函数----梯度清零----向后传播----优化器优化# 至于为什么要梯度清零,我现在也没搞懂,这啥玩意。out = net(x)# prediction = tnf.softmax(prediction)# 这里y应该是一个一维张量,但是prediction应该是二维的,为什么能够计算?loss = loss_func(out, y)optimizer.zero_grad()loss.backward()optimizer.step()if i % 2 == 0:plt.cla()# max函数会返回tensor里面关于最大值的信息,返回两个,第一个为最大值的值,第二个为最大值的索引# max函数里面参数为(tensor, 0/1) 0代表每列最大值, 1代表每行最大值# softmax会产生warning,更新后需要设置softmax的对象维度。prediction = t.max(tnf.softmax(out, dim=1), 1)[1]pre_y = prediction.data.numpy()target_y = y.data.numpy()plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pre_y, s=100, lw=0, cmap='RdYlGn')accuracy = sum(pre_y==target_y)/200plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})plt.pause(0.1)# plt.scatter(x.data.numpy(), y.data.numpy())# plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)# plt.text(0.5, 0, 'Loss=%.4f' % loss.data, fontdict={'size': 20, 'color': 'red'})# plt.pause(0.1)plt.ioff()
plt.show()

从零学习PyTorch(5)----整个天空都是灰蒙蒙的相关推荐

  1. 从零学习pytorch 第2课 Dataset类

    课程目录(在更新,喜欢加个关注点个赞呗): 从零学习pytorch 第1课 搭建一个超简单的网络 从零学习pytorch 第1.5课 训练集.验证集和测试集的作用 从零学习pytorch 第2课 Da ...

  2. 从零学习pytorch 第1课 搭建一个超简单的网络

    课程目录(在更新,喜欢加个关注点个赞呗): 从零学习pytorch 第1课 搭建一个超简单的网络 从零学习pytorch 第1.5课 训练集.验证集和测试集的作用 从零学习pytorch 第2课 Da ...

  3. 纽约大学深度学习PyTorch课程笔记(自用)Week6

    纽约大学深度学习PyTorch课程笔记Week6 Week 6 6.1 卷积网络的应用 6.1.1 邮政编码识别器 使用CNN进行识别 6.1.2 人脸检测 一个多尺度人脸检测系统 6.1.3 语义分 ...

  4. pytorch 训练过程acc_深度学习Pytorch实现分类模型

    今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...

  5. 伯禹公益AI《动手学深度学习PyTorch版》Task 07 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 07 学习笔记 Task 07:优化算法进阶:word2vec:词嵌入进阶 微信昵称:WarmIce 优化算法进阶 emmmm,讲实 ...

  6. 伯禹公益AI《动手学深度学习PyTorch版》Task 03 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 03 学习笔记 Task 03:过拟合.欠拟合及其解决方案:梯度消失.梯度爆炸:循环神经网络进阶 微信昵称:WarmIce 过拟合. ...

  7. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  8. 纽约大学深度学习PyTorch课程笔记(自用)Week3

    纽约大学深度学习PyTorch课程笔记Week3 Week 3 3.1 神经网络参数变换可视化及卷积的基本概念 3.1.1 神经网络的可视化 3.1.2 参数变换 一个简单的参数变换:权重共享 超网络 ...

  9. 【分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人(8/10)】

    分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人 (8/10) 在进行本章的数学推导前,有必要先粗浅的介绍一下,笔者在广泛查找 ...

最新文章

  1. web工程中spring+ibatis的单元测试--转载
  2. vscode解决java无法输入(scanner)问题
  3. Quartz.net 开源job调度框架(一)
  4. Mybatis映射文件!CDATA[[]] 转义问题
  5. 【机器视觉】 until算子
  6. android 中 Proguard 和JNI 相关
  7. python函数-基础知识
  8. 微服务之springCloud-docker-feign配置(五)
  9. Unity3D 4.x怎样实现动画的Ping Pong效果
  10. NIO Channel Scatter/Gather 管道Pipe类
  11. oracle索引实验报告,Oracle之索引(Index)实例讲解
  12. linux禁止root用户远程登录,linux禁止root用户远程登录
  13. Python使用pytesseract进行验证码图像识别
  14. 手把手教你一个321MB的视频,如何压缩到300MB以内?
  15. c语言运行太短怎么毡筒,C语言程序设计 最简单的C程序设计.ppt
  16. 游戏制作人谈10大开发经验
  17. ElementUI Collapse 折叠面板
  18. RadStudio 10.3.3 Rio (Delphi C++ Builder)及TMS TAdvStringGrid控件安装方法
  19. 嫡权法赋权法_Python实现客观赋权法
  20. gin-vue-admin 使用docker容器中的数据库

热门文章

  1. echart折线图连线不显示问题总结
  2. 【渝粤题库】广东开放大学 管理学原理 形成性考核
  3. Java基础篇--继承(inherit),多态(Polymorphism)
  4. 微信零钱提现还要手续费?不存在的
  5. 第十一届蓝桥杯大赛软件类省赛第二场 C/C++ 大学 B 组
  6. 医疗器械经营与服务类毕业论文文献有哪些?
  7. 关于pip 的依赖项解析器当前未考虑安装的所有包。此行为是以下依赖项冲突的根源。
  8. 川崎机器人 AS语言基础运动指令表
  9. 2022重庆幼教产业展览会|高科技玩具益智解压玩具博览会
  10. 计算机视觉 马尔_基于视觉AI的智能车牌识别相机,识别更精准功能更强大