今天我将为大家介绍逻辑回归的含义并展示Pytorch实现逻辑回归的方法,先我们来看看一个问题。

问题:

大家想必对MNIST数据集已经非常熟悉了吧?这个数据集被反复“咀嚼”,反复研究。今天我们将换个角度研究MNIST数据集。假设现在不使用卷积神经网络,又该使用什么方法来解决MNIST分类问题呢?

一、观察数据

在开始分析数据问题之前,我们需要了解最基本的数据对象。最好的方法就是访问官网去看一看数据的构成。官网地址如下:MNIST。

MNIST数据集包含四个部分:

  • Training set images: train-images-idx3-ubyte.gz (包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (包含 10,000 个标签)

根据官网的介绍,每张图像的大小为28*28,标签和图像的存储格式如下所示:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

可以使用以下方式读取数据集文件:

import os
import struct
import numpy as npdef load_mnist(path, kind='train'):"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind)with open(labels_path, 'rb') as lbpath:magic, n = struct.unpack('>II',lbpath.read(8))labels = np.fromfile(lbpath,dtype=np.uint8)with open(images_path, 'rb') as imgpath:magic, num, rows, cols = struct.unpack('>IIII',imgpath.read(16))images = np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels), 784)return images, labels

图像数据的形式如下所示,具体情况请参考详解 MNIST 数据集一文对MNIST数据集的介绍。

借助上面的代码,我们可以获得对数据的一个直观感受。在分析数据集时往往还需要关注一些细节问题,例如数据是否存在类别不平衡问题、数据的噪声情况、数据是否做过归一化处理等等。MNIST相对比较简单,对于上述数据问题,后面的课程中我们会结合更加复杂的数据集再做探讨。

二、分析问题

上一节课程中,我们讨论了线性回归问题,而这次的课程需要解决的是一个分类问题。在自己动脑筋想办法之前,我们先找到一个巨人,站在他的肩膀上看一看,弄清楚针对这一问题已有的办法是什么样的。一般情况下,每个开放的数据集都会给出基于此数据集的各项任务排名。下面这张表摘录了部分应用于MNIST数据集的方法,这些方法被分为如下7类:

  • Linear Classifiers
  • K-Nearest Neighbors
  • Boosted Stumps
  • Non-Linear Classifiers
  • SVMs
  • Neural Nets
  • Convolutional nets

这些方法并不是凭空出现或者一拍脑门就想出来的,分析研究其设计思路是一件非常有意思的事情。还记得我们在课程开始提出的问题吗?“假设现在不使用卷积神经网络,该使用什么方法来解决MNIST分类问题”。前人的工作已经给出了回答。

我们今天探讨的内容是“逻辑回归”。逻辑回归可以视作用回归的方法处理分类问题。结合我们上一节课程介绍的内容,回归问题可以用线性函数加以拟合。同样的逻辑回归问题也可以使用线性函数来处理。所以说用逻辑回归处理MNIST分类问题,可以理解为使用“Linear Classifiers”来处理该问题。但是回归问题里可以用“距离”这种概念来作为优化目标。分类问题里又该选择哪一种优化目标呢?下面我用尽可能简单的语言来介绍人们设计分类问题优化目标的思路。

思路一:使用分类误差作为优化目标

我们先来看看两张表格,这两张表格分别是两个线性模型进行分类预测的结果。

模型一
模型二

从分类结果来看,两个模型的分类误差是一致的。但是,稍微注意一下就会发现,模型一的性能优于模型二。但是分类误差显然不能区分出这一点。所以说,使用分类误差作为优化目标可能是不精确的。

思路二:使用均方误差作为优化目标

使用均方误差(MSE)来处理分类问题倒也不是不行,但是请朋友们注意一个问题。我们是在做分类问题,输出的结果是一个类别,类别是离散值,我们需要计算每个类别出现的概率并挑选概率最大的类别作为输出。在计算各个类别概率的时候,需要使用softmax函数。但是使用了softmax函数后,MSE的函数形状是非凸的,换句话说就是有许多局部的极值点。这样很难通过反向传播来进行优化。就像下图的小球滚到“山腰”就没法往下滚了。

思路三:使用交叉熵函数作为优化目标

交叉熵函数非常有意思,大家可以看看这篇文章熵与信息增益。今天主要逻辑回归的内容,熵与信息增益的内容会放到生成对抗网络部分进行介绍。这里我们直接看交叉熵函数的公式:

其中

是预测结果,
是ground truth。

那么根据公式,模型一中第一项的交叉熵误差的值是:

以此类推,模型一的平均交叉熵误差是:

模型二的平均交叉熵误差是:

所以在分类任务中,交叉熵函数可以比较好的度量不同分布之间的差别。

三、实验

这部分的代码比较简单,所以我就把代码放上来了,方便大家复制粘贴,完整代码请参见pytorch-tutorial。本节课程的代码非常类似于上一节课程介绍的线性回归代码,主要的差别在于本节课程采用了CrossEntropyLoss这一损失函数。

完整代码如下:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# Hyper-parameters
input_size = 784
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001# MNIST dataset (images and labels)
train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, transform=transforms.ToTensor())# Data loader (input pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# Logistic regression model
model = nn.Linear(input_size, num_classes)# Loss and optimizer
# nn.CrossEntropyLoss() computes softmax internally
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  # Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):# Reshape images to (batch_size, input_size)images = images.reshape(-1, 28*28)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, 28*28)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum()print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

运行结果如下:

四、小结与反思

不知道各位读者看完上述程序后有没有一丝丝疑惑,这个单层的网络到底学习到了什么东西?凭什么建立一个784到10的映射就能完成手写字体识别任务?

换个角度考虑一下,我们来想想784×10个映射关系中的一个具体的映射关系究竟学习到了什么东西?这个问题比较好回答,线性映射能学到的无非就是一个线性函数。这个线性函数能完成的任务也就是判断与这个线性函数对应的像素块对全局信息的贡献度。

那么多个线性映射是否就能完整的表达图像的全局信息呢?答案是:不一定。原因很简单,不同像素位置的点的贡献度应当是不同的,不同位置的像素之间存在一定的关联,这种联系在描述语义信息时是非常重要的。显然单层线性模型无法完成这项任务。我们将针对这一问题在后续课程中继续研究更加合适的处理方法。

逻辑回归代码_Pytorch教程(四):逻辑回归相关推荐

  1. python实现逻辑回归代码_python如何实现逻辑回归 python实现逻辑回归代码示例

    python如何实现逻辑回归?本篇文章小编给大家分享一下python实现逻辑回归代码示例,文章代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看. 代码如下: i ...

  2. python机器学习案例系列教程——逻辑分类/逻辑回归LR/一般线性回归(softmax回归)

    全栈工程师开发手册 (作者:栾鹏) python数据挖掘系列教程 线性函数.线性回归 参考:http://blog.csdn.net/luanpeng825485697/article/details ...

  3. 从原理到代码,轻松深入逻辑回归模型!

    整理 | Jane 出品 | AI科技大本营(ID:rgznai100) [导语]学习逻辑回归模型,今天的内容轻松带你从0到100!阿里巴巴达摩院算法专家.阿里巴巴技术发展专家.阿里巴巴数据架构师联合 ...

  4. python逻辑回归代码_Logistic 逻辑回归及 python 实现

    1. 引言 2. 数例 3. logistic 函数原理 4. 极大似然估计求出参数值 5. python 代码 1. 引言 Logistic 逻辑回归比较适合分类型因变量的回归,这种问题在现实中很多 ...

  5. 逻辑回归和决策树_结合逻辑回归和决策树

    逻辑回归和决策树 Logistic regression is one of the most used machine learning techniques. Its main advantage ...

  6. 逻辑回归(Logistic Regression, LR)又称为逻辑回归分析,是分类和预测算法中的一种。通过历史数据的表现对未来结果发生的概率进行预测。例如,我们可以将购买的概率设置为因变量,将用户的

    逻辑回归(Logistic Regression, LR)又称为逻辑回归分析,是分类和预测算法中的一种.通过历史数据的表现对未来结果发生的概率进行预测.例如,我们可以将购买的概率设置为因变量,将用户的 ...

  7. python逻辑回归模型建模步骤_Python逻辑回归——建模-评估模型

    学完线性回归,逻辑回归建模+评估模型的过程就相对好理解很多.其实就是换汤不换药. 逻辑回归不是回归算法,而是分类算法,准确来说,叫逻辑分类 逻辑分类本质上是二分分类,即分类结果标签只有两个 逻辑回归建 ...

  8. ai逻辑回归_人工智能中的逻辑是什么?

    ai逻辑回归 人工智能逻辑 (Logic in Artificial Intelligence) Logic, as per the definition of the Oxford dictiona ...

  9. 逻辑回归分类python实例_Python逻辑回归原理及实际案例应用

    前言 目录 1. 逻辑回归 2. 优缺点及优化问题 3. 实际案例应用 4. 总结 正文 在前面所介绍的线性回归, 岭回归和Lasso回归这三种回归模型中, 其输出变量均为连续型, 比如常见的线性回归 ...

最新文章

  1. Lite-HRNet
  2. excel执行INSERT和UPDATE操作语句
  3. iOS web与JS交互
  4. Qt线程之QRunnable的使用详解
  5. 《Linux命令行与shell脚本编程大全 第3版》高级Shell脚本编程---32
  6. SWIFT(Society for Worldwide Interbank Financial SWIFT Telecommunications---环球同业银行金融电讯协会)
  7. 枚举、位操作 CLR学习第十二课
  8. 带你了解 HBase 数据模型和 HBase 架构
  9. SAP License:SAP的国家会计科目表
  10. 基于 Eclipse 平台的代码生成技术
  11. 计算机颜色的概念,颜色空间
  12. 模式识别与机器学习---绪论
  13. android输入框边框距离,如何更改Android对话框边距(到屏幕边缘的距离)?
  14. Base16加密算法
  15. 如何在新的Apple TV遥控器上调整触摸灵敏度
  16. Python爬虫使用lxml模块爬取豆瓣读书排行榜并分析
  17. 利用canvas和vue_qrcodes生成带二维码头像的海报(一)
  18. android开发常用工具类、高仿客户端、附近厕所、验证码助手、相机图片处理等源码...
  19. Python基础—简介、变量、运算符
  20. 业务型团队如何提高人效

热门文章

  1. Linux下安装 boost 库
  2. Android消息机制学习笔记
  3. codevs 2075 yh女朋友的危机
  4. WEBSHELL跳板REDUH使用说明
  5. PHP实现XML传输
  6. 开源依旧:再次分享一个进销存系统
  7. mysql主从结构主数据库_mysql主从结构主数据库
  8. php缓存数据到本地缓存,本地缓存localStorage的使用方法
  9. C语言指针(就做个笔记)
  10. “;“分号空语句的使用