【小白学习PyTorch教程】七、基于乳腺癌数据集构建Logistic 二分类模型
「@Author:Runsen」
在逻辑回归中预测的目标变量不是连续的,而是离散的。可以应用逻辑回归的一个示例是电子邮件分类:标识为垃圾邮件或非垃圾邮件。图片分类、文字分类都属于这一类。
在这篇博客中,将学习如何在 PyTorch 中实现逻辑回归。
1. 数据集加载
在这里,我将使用来自 sklearn 库的乳腺癌数据集。这是一个简单的二元类分类数据集。从 sklearn.datasets 模块加载。接下来,可以使用内置函数从数据集中提取 X 和 Y,代码如下所示。
from sklearn import datasets
breast_cancer=datasets.load_breast_cancer()
x,y=breast_cancer.data,breast_cancer.target
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test= train_test_split(x,y,test_size=0.2)
在上面的代码中,测试大小表示要用作测试数据集的数据的比例。因此,80% 用于训练,20% 用于测试。
2. 预处理
由于这是一个分类问题,一个好的预处理步骤是应用标准的缩放器变换。
scaler=sklearn.preprocessing.StandardScaler()
x_train=scaler.fit_transform(x_train)
x_test=scaler.fit_transform(x_test)
现在,在使用Logistic 模型之前,还有最后一个关键的数据处理步骤。在Pytorch 需要使用张量。因此,我们使用“torch.from_numpy()”
方法将所有四个数据转换为张量。
在此之前将数据类型转换为 float32
很重要。可以使用“astype()”函数来做到这一点。
import numpy as np
import torch
x_train=torch.from_numpy(x_train.astype(np.float32))
x_test=torch.from_numpy(x_test.astype(np.float32))
y_train=torch.from_numpy(y_train.astype(np.float32))
y_test=torch.from_numpy(y_test.astype(np.float32))
我们知道 y 必须采用列张量而不是行张量的形式。因此,使用代码中所示的view
操作执行此更改。对 y_test 也做同样的操作。
y_train=y_train.view(y_train.shape[0],1)
y_test=y_test.view(y_test.shape[0],1)
预处理步骤完成,您可以继续进行模型构建。
3. 模型搭建
现在,我们已准备好输入数据。让我们看看如何在 PyTorch 中编写用于逻辑回归的自定义模型。第一步是用模型名称定义一个类。这个类应该派生torch.nn.Module
。
在类内部,我们有__init__
函数和 forward
函数。
class Logistic_Reg_model(torch.nn.Module):def __init__(self,no_input_features):super(Logistic_Reg_model,self).__init__()self.layer1=torch.nn.Linear(no_input_features,20)self.layer2=torch.nn.Linear(20,1)def forward(self,x):y_predicted=self.layer1(x)y_predicted=torch.sigmoid(self.layer2(y_predicted))return y_predicted
在__init__
方法中,必须在模型中定义所需的层。在这里,使用线性层,可以从 torch.nn 模块声明。需要为图层指定任何名称,例如本例中的“layer1”。所以,我已经声明了 2 个线性层。
语法为:torch.nn.Linear(in_features, out_features, bias=True)
接下来,也要有“forward()
”函数,负责执行前向传递/传播。输入通过之前定义的 2 个层。此外,第二层的输出通过一个称为 sigmoid
的激活函数。
激活函数用于捕捉线性数据中的复杂关系。在这种情况下,我们使用 sigmoid 激活函数。
在这种情况下,我们选择 sigmoid 函数的原因是它会将值限制为(0 到 1)。下面是 sigmoid 函数的图形及其公式
4. 训练和优化
定义类后,初始化模型。
model=Logistic_Reg_model(n_features)
现在,需要定义损失函数和优化算法。在 Pytorch 中,可以通过简单的步骤选择并导入所需的损失函数和优化算法。在这里,选择 BCE 作为我们的损失标准。
BCE代表二元交叉熵损失。它通常用于二元分类示例。值得注意的一点是,当使用 BCE 损失函数时,节点的输出应该在(0-1)之间。我们需要为此使用适当的激活函数。
对于优化器,选择 SGD 或随机梯度下降。SGD 算法,通常用作优化器。还有其他优化器,如 Adam、lars 等。
优化算法有一个称为学习率的参数。这基本上决定了算法接近局部最小值的速率,此时损失最小。这个值很关键。
因为如果学习率值太高,算法可能会突然出现并错过局部最小值。如果它太小,则会花费大量时间并且可能无法收敛。因此,学习率“lr”是一个超参数,应该微调到最佳值。
criterion=torch.nn.BCELoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
接下来,决定 epoch 的数量,然后编写训练循环。
number_of_epochs=100
for epoch in range(number_of_epochs):y_prediction=model(x_train)loss=criterion(y_prediction,y_train)loss.backward()optimizer.step()optimizer.zero_grad()if (epoch+1)%10 == 0:print('epoch:', epoch+1,',loss=',loss.item())
如果发生了第一次前向传播。接下来,计算损失。当loss.backward()
被调用时,它计算损失相对于(层的)权重的梯度。然后通过调用optimizer.step()更新权重。之后,必须为下一次迭代清空权重。因此调用 zero_grad
()方法。
计算准确度
with torch.no_grad():y_pred=model(x_test)y_pred_class=y_pred.round()accuracy=(y_pred_class.eq(y_test).sum())/float(y_test.shape[0])print(accuracy.item())# 0.92105
使用torch.no_grad(),目的是基跳过权重的梯度计算。所以,我在这个循环中写的任何内容都不会导致权重发生变化,因此不会干扰反向传播过程。
往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》课件合集
本站qq群851320808,加入微信群请扫码:
【小白学习PyTorch教程】七、基于乳腺癌数据集构建Logistic 二分类模型相关推荐
- 【小白学习PyTorch教程】十六、在多标签分类任务上 微调BERT模型
@Author:Runsen BERT模型在NLP各项任务中大杀四方,那么我们如何使用这一利器来为我们日常的NLP任务来服务呢?首先介绍使用BERT做文本多标签分类任务. 文本多标签分类是常见的NLP ...
- R语言构建文本分类模型:文本数据预处理、构建词袋模型(bag of words)、构建xgboost文本分类模型、基于自定义函数构建xgboost文本分类模型
R语言构建文本分类模型:文本数据预处理.构建词袋模型(bag of words).构建xgboost文本分类模型.基于自定义函数构建xgboost文本分类模型 目录
- 【小白学习PyTorch教程】十、基于大型电影评论数据集训练第一个LSTM模型
「@Author:Runsen」 本博客对原始IMDB数据集进行预处理,建立一个简单的深层神经网络模型,对给定数据进行情感分析. 数据集下载 here. 原始数据集,没有进行处理here. impor ...
- 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度...
「@Author:Runsen」 上次基于CIFAR-10 数据集,使用PyTorch构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. imp ...
- 【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext
@Author:Runsen 对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext. 之前使用 torchDataLoader类直接加载图像并将其转换为张量. ...
- 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度
@Author:Runsen 上次基于CIFAR-10 数据集,使用PyTorch 构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. im ...
- 【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类
「@Author:Runsen」 上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类. ResNet是 Residual Networks 的缩写,是一种经典的神经网络,用作许多计算 ...
- 【小白学习PyTorch教程】五、在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据
「@Author:Runsen」 有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难. 因此,唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作.对此,PyTor ...
- 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始构建图像分类模型...
「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...
最新文章
- Java中的访问权限
- 叮~ 量子位欢迎你加入AI群聊
- Gold Balanced Lineup - poj 3274 (hash)
- Linux磁盘格式化和挂载
- C指针原理(8)-C内嵌汇编
- 《Java多线程编程核心技术》读后感(十五)
- 如何处理Express异常?
- 劳力埃大学计算机科学,劳里埃大学计算机科学本科.pdf
- 5个最佳网络安全监控工具、 你知道哪些
- 有窗体的闭合导线计算程序(C#)
- 矩阵乘法 c/c++代码
- 计算机装系统找不到硬盘分区,解决安装系统找不到硬盘的问题(图文)
- 什么叫定向广告?定向传播有哪些好处
- Ontology对接资源整理
- Excel数据分析入门-数据透视表
- NY145 聪明的小柯
- 如何设计透明的png图标
- Kafka常用命令收录
- CSS3画布Canvas知识点
- 实验08 软件设计模式及应用
热门文章
- 一个自定义 HBase Filter -“通过RowKeys来高性能获取数据”
- 深入理解计算机系统(2.6)------整数的运算
- Linux安装/卸载软件教程
- python调用matlab环境配置、非常详细!!!_Python调用Matlab2014b引擎
- GraphPad Prism 的统计显著性报告中*或**或**的含义是什么?
- 25接口之间的单继承
- pycharm中无法安装scipy、imread、GDAL等库
- 转载:opencv错误rect错误
- JavaSE(二十)——面向对象的概念及三个基本特征
- Cadence Orcad元器件位号重排与原理图页序号重排