深度学习基础–SOFTMAX回归(单层神经网络)

最近在阅读一本书籍–Dive-into-DL-Pytorch(动手学深度学习),链接:https://github.com/newmonkey/Dive-into-DL-PyTorch,自身觉得受益匪浅,在此记录下自己的学习历程。

本篇主要记录关于SOFTMAX回归的知识。softmax回归和线性回归一样都属于单层神经网络;线性回归主要适用于回归问题,而softmax回归主要使用于分类问题。本文主要尝试对手写数字进行识别。sofemax函数又叫归一化指数函数。

1 收集数据集

我们通过torchvision的 torchvision.datasets 来下载这个手写数字识别数据集MNIST。可以获得60000个训练集样本数与10000个测试集样本数。

import torch
import torchvision
import numpy as npdef load_data_fashion_mnist(batch_size, root='本地地址url'):transform = torchvision.transforms.ToTensor()mnist_train = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)return train_iter,test_iterbatch_size=256
train_iter,test_iter=load_data_fashion_mnist(batch_size)

2 定义和初始化模型

每个样本的shape为[1,28,28],即通道数为1,高和宽都为为28像素的图像。故模型的输入向量的长度是784。softmax回归的输出层是⼀个全连接层,所以我们⽤⼀个线性模块就可以了。

import torch.nn as nn
num_inputs = 784
num_outputs = 10
class LinearNet(nn.Module):def __init__(self, num_inputs, num_outputs):super(LinearNet, self).__init__()self.linear = nn.Linear(num_inputs, num_outputs)def forward(self, x): # x shape: (batch, 1, 28, 28)y = self.linear(x.view(x.shape[0], -1))return ynet = LinearNet(num_inputs, num_outputs)
print(net)init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)

3 sofemax和交叉熵损失函数

PyTorch提供了⼀个包括softmax运算和交叉熵损失计算的函数。

loss=nn.CrossEntropyLoss()

4 定义优化算法

采用学习率为0.005的⼩批量随机梯度下降(SGD)为优化算法。

optimizer = torch.optim.SGD(net.parameters(), lr=0.005)

5 训练模型

迭代周期设置为10,模型训练。

num_epochs = 10def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.datadef evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / ndef train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回归的简洁实现”一节将用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train_ch3(net, train_iter, test_iter, loss, num_epochs,batch_size, None, None, optimizer)
#结果
#epoch 1, loss 0.0071, train acc 0.675, test acc 0.788
#epoch 2, loss 0.0050, train acc 0.795, test acc 0.821
#epoch 3, loss 0.0040, train acc 0.819, test acc 0.837
#epoch 4, loss 0.0034, train acc 0.832, test acc 0.846
#epoch 5, loss 0.0031, train acc 0.841, test acc 0.855
#epoch 6, loss 0.0028, train acc 0.848, test acc 0.860
#epoch 7, loss 0.0026, train acc 0.853, test acc 0.865
#epoch 8, loss 0.0025, train acc 0.857, test acc 0.868
#epoch 9, loss 0.0024, train acc 0.860, test acc 0.871
#epoch 10, loss 0.0023, train acc 0.863, test acc 0.874

6 预测

训练完成后,现在就可以演示如何对图像进⾏分类了。第⼀⾏为真实标签,第⼆⾏为预测标签,第三行为图像。

from IPython import display
import matplotlib.pyplot as plt
X, y = iter(test_iter).next()
def get_fashion_mnist_labels(labels):text_labels = ['0', '1', '2', '3', '4','5', '6', '7', '8', '9']return [text_labels[int(i)] for i in labels]def show_fashion_mnist(images, labels):#use_svg_display()display.display_svg()# 这⾥的_表示我们忽略(不使⽤)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels =get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels,pred_labels)]
show_fashion_mnist(X[0:20], titles[0:20])

预测结果展示:(第⼀⾏为真实标签,第⼆⾏为预测标签,第三行为图像)

END!

深度学习基础--SOFTMAX回归(单层神经网络)相关推荐

  1. 深度学习:Softmax回归

    在前面,我们介绍了线性回归模型的原理及实现.线性回归适合于预测连续值,而对于分类问题的离散值则束手无策.因此引出了本文所要介绍的softmax回归模型,该模型是针对多分类问题所提出的.下面我们将从so ...

  2. 【动手学深度学习】Softmax 回归 + 损失函数 + 图片分类数据集

    学习资料: 09 Softmax 回归 + 损失函数 + 图片分类数据集[动手学深度学习v2]_哔哩哔哩_bilibili torchvision.transforms.ToTensor详解 | 使用 ...

  3. 【深度学习基础】经典卷积神经网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导语 卷积神经网络(Convolutional Neural Ne ...

  4. 【动手学深度学习】softmax回归

    softmax回归 1.softmax回归基本概念 2.图像分类数据集流程图 3.softmax从零开始实现流程图 4.softmax回归的简洁实现 1.softmax回归基本概念 分类问题 独热编码 ...

  5. [动手学深度学习]02 softmax回归

    softmax回归 1. softmax回归 2. softmax操作 3. 最大似然估计 4. 损失函数 5. 梯度 6. 实现 6.1 从零实现softmax回归 6.2 简洁实现 7. 课后习题 ...

  6. 深度学习基础知识(一): 概述-神经网络起源和发展

    目录 1. 神经网络概述 1.1 概念和起源 1.2 神经网络基本结构 1.2.1 神经元数学模型 1.2.2 sigmoid激活函数 1.2.3 神经网络结构 1.3 我对神经网络的构建和流程理解 ...

  7. 【一】零基础入门深度学习:用numpy实现神经网络训练

    (给机器学习算法与Python学习加星标,提升AI技能) 作者 | 毕然 百度深度学习技术平台部主任架构师 内容来源 | 百度飞桨深度学习集训营 本文转自飞桨PaddlePaddle 本课程是百度官方 ...

  8. 独家思维导图!让你秒懂李宏毅2020深度学习(三)——深度学习基础(神经网络和反向传播部分)

    独家思维导图!让你秒懂李宏毅2020深度学习(三)--深度学习基础(神经网络和反向传播部分) 长文预警!!!前面两篇文章主要介绍了李宏毅视频中的机器学习部分,从这篇文章开始,我将介绍李宏毅视频中的深度 ...

  9. 深度学习导论(4)神经网络基础

    深度学习导论(4)神经网络基础 一. 训练深度学习模型的步骤 二. 线性层(或叫全链接层)(Linear layer(dense or fully connected layers)) 1. 定义一个 ...

最新文章

  1. 公司用的 MySQL 团队开发规范
  2. 《写给大家看的设计书:实例与创意(修订版)》—1你已经知道多少了?
  3. SAP UI5 初学者教程之三:开始接触第一个 SAP UI5 控件 试读版
  4. 基于Curator实现dubbo服务自动注册发现
  5. python 菜鸟:返回值_Python中的真实值和虚假值:详细介绍
  6. flutter DateTime 日期时间详细解析 Dart语言基础
  7. 关于 iOS 证书,你必须了解的知识
  8. CPU亲和性(affinity)sched_setaffinity() 和 sched_getaffinity()
  9. Dato for mac(自定义菜单栏日历)支持m1
  10. 在这个人人拥抱python的时代,R真的out了吗?
  11. python图片中文字识别
  12. UCenter+云市场?开源用户中心2.0时代即将开启
  13. “字符串的展开”【题解】
  14. 苹果中国应用商店改为人民币结算 可网银充值
  15. 基于STM32F4:多通道ADC采集,采用DMA的形式,亲测有效
  16. Virtual host / experienced an error on node rabbit@XX and may be inaccessible
  17. ToggleSwitch控件介绍
  18. 解决virus.vbs.wiritebin.a和Virus.Win32.Ramin.x病毒
  19. 近期活动盘点:数据院五周年系列活动之医疗专场、DeeCamp2019:实战AI 铸造定雨神针...
  20. mySQL 错误 3167 - The 'INFORMATION_SCHEMA.GLOBAL_STATUS' feature is disabled; see the document

热门文章

  1. 设计模式—单例模式(饿汉式、懒汉式)
  2. WPF--控件(代码讲解)
  3. 有道词典java下载手机版下载手机版_有道词典app下载_有道词典在线翻译下载安装手机版v9.08...
  4. 常用ruby gem
  5. 零基础入门UI设计必备实用技巧!
  6. 软件测试--【软件测试和bug】
  7. 关于Python not 及is None的有趣现象(两者的区别)
  8. 惠普服务器装系统无法识别u盘,惠普uefi bios无法识别u盘的解决方法
  9. http://windowsandroid.cn.uptodown.com/download
  10. sql server 获取本机的ip地址