目录

  • 1、代码实现
  • 2、踩过的坑

1、代码实现

步骤

  1. 获得数据
  2. 建立逻辑回归模型
  3. 定义损失函数
  4. 计算损失函数
  5. 求解梯度
  6. 梯度更新
  7. 预测测试集
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoaderinput_size = 784  # 输入到逻辑回归模型中的输入大小
num_classes = 10  # 分类的类别个数
num_epochs = 10  # 迭代次数
batch_size = 50  # 批量训练个数
learning_rate = 0.01  # 学习率# 下载训练数据和测试数据
train_dataset = dataset.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dataset.MNIST(root='./data',train=False, transform=transforms.ToTensor)# 使用DataLoader形成批处理文件
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 创建逻辑回归类模型  (sigmoid(wx+b))
class LogisticRegression(nn.Module):def __init__(self,input_size,num_classes):super(LogisticRegression,self).__init__()self.linear = nn.Linear(input_size,num_classes)self.sigmoid = nn.Sigmoid()def forward(self, x):out = self.linear(x)out = self.sigmoid(out)return out# 设定模型参数
model = LogisticRegression(input_size, num_classes)
# 定义损失函数,分类任务,使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化算法,随机梯度下降,lr为学习率,获得模型需要更新的参数值
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)# 使用训练数据训练模型
for epoch in range(num_epochs):# 批量数据进行模型训练for i, (images, labels) in enumerate(train_loader):# 需要将数据转换为张量Variableimages = Variable(images.view(-1, 28*28))labels = Variable(labels)# 梯度更新前需要进行梯度清零optimizer.zero_grad()# 获得模型的训练数据结果outputs = model(images)# 计算损失函数用于计算梯度loss = criterion(outputs, labels)# 计算梯度loss.backward()# 进行梯度更新optimizer.step()# 每隔一段时间输出一个训练结果if (i+1) % 100 == 0:print('Epoch:[%d %d], Step:[%d/%d], Loss: %.4f' % (epoch+1,num_epochs,i+1,len(train_dataset)//batch_size,loss.item()))# 训练好的模型预测测试数据集
correct = 0
total = 0
for images, labels in test_loader:images = Variable(images.view(-1, 28*28))  # 形式为(batch_size,28*28)outputs = model(images)_,predicts = torch.max(outputs.data,1)  # _输出的是最大概率的值,predicts输出的是最大概率值所在位置,max()函数中的1表示维度,意思是计算某一行的最大值total += labels.size(0)correct += (predicts==labels).sum()print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

2、踩过的坑

  1. 在代码中下载训练数据和测试数据的时候,两段代码是有区别的:
train_dataset = dataset.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dataset.MNIST(root='./data',train=False, transform=transforms.ToTensor)

第一段代码中多了一个download=True,这个的作用是,如果为True,则从Internet下载数据集并将其存放在根目录中。如果数据已经下载,则不会再次下载。

在第二段代码中没有加download=True,加了的话在使用测试数据进行预测的时候会报错。

代码中transform=transforms.ToTensor()的作用是将PIL图像转换为Tensor,同时已经进行归一化处理。

  1. 代码中设置损失函数:
criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, labels)

一开始的时候直接使用:

loss = nn.CrossEntropyLoss()
loss = loss(outputs, labels)

这样也会报错,因此需要将loss改为criterion。

用Pytorch实现逻辑回归分类相关推荐

  1. python利用什么写模板_Python利用逻辑回归分类实现模板

    Logistic Regression Classifier逻辑回归主要思想就是用最大似然概率方法构建出方程,为最大化方程,利用牛顿梯度上升求解方程参数. 优点:计算代价不高,易于理解和实现. 缺点: ...

  2. spark java 逻辑回归_逻辑回归分类技术分享,使用Java和Spark区分垃圾邮件

    原标题:逻辑回归分类技术分享,使用Java和Spark区分垃圾邮件 由于最近的工作原因,小鸟很久没给大家分享技术了.今天小鸟就给大家介绍一种比较火的机器学习算法,逻辑回归分类算法. 回归是一种监督式学 ...

  3. R语言使用逻辑回归分类算法

    R语言使用逻辑回归分类算法 逻辑回归属于概率统计的分类算法模型的算法,是根据一个或者多个特征进行类别标号预测.在R语言中可以通过调用logit函数执行逻辑回归分类算法并预测输出概率.通过调用glm函数 ...

  4. python机器学习基础05——sklearn之逻辑回归+分类评价指标

    文章目录 逻辑回归 逻辑回归的损失函数 逻辑回归API 分类模型的评价指标 混淆矩阵 准确率 召回率(较多被使用) 精确率 f1-score:精确率和召回率的调和平均数 AUC 逻辑回归 逻辑回归是经 ...

  5. 数据挖掘—逻辑回归分类—信用卡欺诈分析

    文章目录 1.分析目的: 2.掌握要点: 3.构建逻辑回归分类器 4.模型评估指标 5.精确度和召回率(不平衡数据衡量指标) 6.案例分析: 1.分析目的: 信用卡欺诈的危害性大,如何通过遗忘的交易数 ...

  6. 树模型与线性模型的区别 决策树分类和逻辑回归分类的区别 【总结】

    树模型与线性模型的区别在于: (一)树模型 ①树模型产生可视化的分类规则,可以通过图表表达简单直观,逐个特征进行处理,更加接近人的决策方式   ②产生的模型可以抽取规则易于理解,即解释性比线性模型强. ...

  7. java基础巩固-宇宙第一AiYWM:为了维持生计,编程语言番外篇之机器学习(项目预测模块总结:线性回归算法、逻辑回归分类算法)~整起

    机器学习 一.机器学习常见算法(未完待续...) 1.算法一:线性回归算法:找一条完美的直线,完美拟合所有的点,使得直线与点的误差最小 2.算法二:逻辑回归分类算法 3.算法三:贝叶斯分类算法 4.算 ...

  8. 编程实践-逻辑回归分类算法--马的疝气病症分类

    #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 逻辑回归分类 @dataset:马的疝气病症数据集Horse Col ...

  9. 逻辑回归分类鸢尾花和红酒等级

    逻辑回归分类鸢尾花和红酒等级 源代码以及训练数据和测试数据已上传:https://download.csdn.net/download/j__max/10816259 一.实验准备 1.实验内容和目的 ...

最新文章

  1. [置顶] 浅析objc的消息机制
  2. Educational Codeforces Round 9 A. Grandma Laura and Apples 水题
  3. php运行汇编,php脚本的执行过程(编译与执行相分离)
  4. 怎样编写测试类测试分支_编写干净的测试–天堂中的麻烦
  5. 笨办法学 Python · 续 练习 5:`cat`
  6. 家里安装20M宽带,为什么看视频都不卡,但一玩游戏就卡的要死?
  7. python绘图背景透明_如何在 Matplotlib 中更改绘图背景
  8. Hibernate深入浅出(六)事务2——锁locking
  9. mysql2008 精简版_精简版 SqlServer2008 的安装和使用
  10. 用C#通过sql语句操作Sqlserver数据库教程
  11. 如何自制自平衡云台基于mpu6050,arduino输出三维倾斜角度的方法(含源码,库)
  12. tensorflow学习笔记(八):LSTM手写体(MNIST)识别
  13. F. Asya And Kittens
  14. win xp出现“安装程序包的语言不受支持”的解决
  15. 微软 2022 新 bug:大量程序员连夜加班!
  16. Linux下软连接的创建和删除
  17. 解决: Specifically, your app violates Section 3.2(f) of the PLA, which states:
  18. 不打不相识,苹果偷学微信代码
  19. 欧朋浏览器的移动互联网变局
  20. 旧android 4 平板,1099的小米平板4可能真的刷新了我心目中的安卓平板

热门文章

  1. 硬件磁盘阵列还是软件磁盘阵列
  2. 软件正版,我们是缺钱还是缺意识
  3. Linux命令之zip命令
  4. 面试精讲之面试考点及大厂真题 - 分布式专栏 10 Redis雪崩,穿透,击穿三连问
  5. 容器编排技术 -- Kubernetes 为 Namespace 设置最小和最大内存限制
  6. Oracle 优化和性能调整
  7. life game c语言,c++生命游戏源码
  8. 个人猜测一下《黑神话:悟空》的部分剧情
  9. C#LeetCode刷题-双指针
  10. FreeCodeCamp Caesars密码项目的演练