目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上3.6节
此节功能为:线性回归的从零开始实现
由于次节相对复杂,代码注释量较多

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.6_softmax-regression-scratch
# 3.6节
#注释:黄文俊
#邮箱:hurri_cane@qq.com
from _ast import Globalimport torch
import torchvision
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l
import matplotlib.pyplot as pltdef softmax(X):# print("X",X.size())X_exp = X.exp()# print("X_exp",X_exp[0])partition = X_exp.sum(dim=1, keepdim=True)# a = X_exp / partition# print(a[0])return X_exp / partition  # 这里应用了广播机制#softmax返回的每张图像被划分为不同类别的概率分布#比如:第一张图像:在0-9类中的概率分布可以为[0.1000, 0.1056, 0.1070, 0.0940, 0.0996, 0.1088, 0.0993, 0.0953, 0.0863,0.1042]def net(X):#输入进来的X的size是[256, 1, 28, 28]# print("叉乘前的X",X.size())# view_path= X.view((-1, num_inputs))#X.view((-1, num_inputs))的size是[256, 784]# print(view_path.size())return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)#torch.mm(X.view((-1, num_inputs)), W) + b求出来的结果size为:[256, 10]#其表征的意义是:输入的256张图像为各自分类(一共10种)的概率
'''
torch.mm(X.view((-1, num_inputs)), W) + b解读:
X是一个由256个28*28的图像矩阵构成的张量
将X转换为'''#交叉熵损失函数
def cross_entropy(y_hat, y):#y是真是类别分布size为[256],y_hat是每个类别的预测概率分布[256, 10]#y_hat.gather(1, y.view(-1, 1))表示通过y来索引y_hat对于的概率并转换为列为1的矩阵(即竖直排列)#以y的第一个元素为例,假设值为5,则对应该图像真是分类为5类,则提取y_hat中类为5的概率#因为y_hat中概率分布都是从0-9依次排列的,所以y_hat中类为5的概率即是y_hat中第一个图像对于的第5个元素的值(注意python中顺序是从0开始算的)# print(y_hat.gather(1, y.view(-1, 1)).size())return - torch.log(y_hat.gather(1, y.view(-1, 1)))#分类准确率函数
def accuracy(y_hat, y):return (y_hat.argmax(dim=1) == y).float().mean().item()# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:      #获得图像矩阵;y获得标签值# view_path = net(X)acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()     #acc_sum为正确分类的个数之和n += y.shape[0]return acc_sum / n# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):# 调试时计算次数times_sum = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:#X的size为[256, 1, 28, 28];y的size为[256]# print(y.size())y_hat = net(X)l = loss(y_hat, y).sum()'''损失函数l为交叉熵函数最小化交叉熵损失函数等价于最大化训练数据集所有标签类别的联合预测概率。loss(y_hat, y)返回的是每张图像的交叉熵值,为了反映整体情况需要对其求和“.sum() ”'''# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()times_sum += 1if optimizer is None:d2l.sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回归的简洁实现”一节将用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))print(times_sum)batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
'''
train_iter, test_iter都是从fashion_mnist数据库中读取来的小批量(256个)数据集
以train_iter为例,其第一个元素train_iter[0]便包含图像和标签信息,可以用feature, label = train_iter[0]来分别赋给feature和label
feature的大小为[深度*高度*宽度],如:1*28*28;label常以整型数字存在,不同的数字表示所属不同的标签
'''
num_inputs = 784
num_outputs = 10W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
# W 的size为[784,10]
# print(W.size())
b = torch.zeros(num_outputs, dtype=torch.float)#打开模型参数梯度
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)print(evaluate_accuracy(test_iter, net))num_epochs, lr = 5, 0.1
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)print("*"*30)
#训练完成后,现在就可以演示如何对图像进行分类了。给定一系列图像(第三行图像输出),我们比较一下它们的真实标签(第一行文本输出)和模型预测结果(第二行文本输出)。
X, y = iter(test_iter).next()
#不断使⽤next()函数来获取test_iter的下⼀条数据true_labels = d2l.get_fashion_mnist_labels(y.detach().numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).detach().numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])print("*"*30)

《动手学深度学习》(PyTorch版)代码注释 - 3 【Softmaxs_regression_with_zero】相关推荐

  1. 伯禹公益AI《动手学深度学习PyTorch版》Task 04 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 04 学习笔记 Task 04:机器翻译及相关技术:注意力机制与Seq2seq模型:Transformer 微信昵称:WarmIce ...

  2. 伯禹公益AI《动手学深度学习PyTorch版》Task 07 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 07 学习笔记 Task 07:优化算法进阶:word2vec:词嵌入进阶 微信昵称:WarmIce 优化算法进阶 emmmm,讲实 ...

  3. 伯禹公益AI《动手学深度学习PyTorch版》Task 03 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 03 学习笔记 Task 03:过拟合.欠拟合及其解决方案:梯度消失.梯度爆炸:循环神经网络进阶 微信昵称:WarmIce 过拟合. ...

  4. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  5. 【动手学深度学习PyTorch版】12 卷积层

    上一篇移步[动手学深度学习PyTorch版]11 使用GPU_水w的博客-CSDN博客 目录 一.卷积层 1.1从全连接到卷积 ◼ 回顾单隐藏层MLP ◼ Waldo在哪里? ◼ 原则1-平移不变性 ...

  6. 【动手学深度学习PyTorch版】27 数据增强

    上一篇请移步[动手学深度学习PyTorch版]23 深度学习硬件CPU 和 GPU_水w的博客-CSDN博客 目录 一.数据增强 1.1 数据增强(主要是关于图像增强) ◼ CES上的真实的故事 ◼ ...

  7. 【动手学深度学习PyTorch版】13 卷积层的填充和步幅

    上一篇移步[动手学深度学习PyTorch版]12 卷积层_水w的博客-CSDN博客 目录 一.卷积层的填充和步幅 1.1 填充 1.2 步幅 1.3 总结 二.代码实现填充和步幅(使用框架) 一.卷积 ...

  8. 【动手学深度学习PyTorch版】23 深度学习硬件CPU 和 GPU

    上一篇请移步[动手学深度学习PyTorch版]22续 ResNet为什么能训练出1000层的模型_水w的博客-CSDN博客 目录 一.深度学习硬件CPU 和 GPU 1.1 深度学习硬件 ◼ 计算机构 ...

  9. 【动手学深度学习PyTorch版】15 池化层

    上一篇请移步[动手学深度学习PyTorch版]14 卷积层里的多输入多输出通道_水w的博客-CSDN博客 目录 一.池化层 1.1 池化层 ◼池化层原因 ◼ 二维最大池化 1.2 填充.步幅与多个通道 ...

  10. 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...

最新文章

  1. C#设计技巧总结 网上转贴
  2. mysql 导入一个数据库_mysql导入一个数据库
  3. Mysql实战:主从同步
  4. 树莓派 rfid_树莓派工控机做Modbus RTU主站读取RFID数据
  5. word文档老是出现这个提示-----“发现二义性的名称:TmpDDE”错误
  6. c++设计模式:单例模式
  7. 开发工具:收集12 个顶级 Bug 跟踪工具,值得收藏!
  8. java接口文件定义类_Java入门笔记(四)类、包和接口
  9. Shank的大步小步算法(Shank‘s Baby-Step-Giant-Step Algorithm)
  10. 创业,白手起家需要些什么?
  11. java 内存泄漏问题_Java内存泄露的理解与解决(转)
  12. 注意力机制介绍(attention)
  13. c++实现查询天气预报
  14. Windows异常学习笔记(三)—— VEHSEH
  15. tensorflow如何使用tensorboard将图片文件events.out.tfevents.1618410161.DESKTOP-CLCBFNS展示出来
  16. [高数][高昆轮][高等数学上][第一章-函数与极限]01.映射与极限
  17. 西门子1200PLC的MODBUS_RTU轮询程序
  18. 微信小程序装修解决方案ppt_微信小程序开发教程.ppt
  19. 外媒:忘掉微软 Win11 吧
  20. c语言 结构体指针做函数参数

热门文章

  1. win10计算机控制面板在哪里,教您win10控制面板在哪
  2. 网络服务器配置管理综合实训项目心得体会,服务器的配置与管理实训报告.doc...
  3. Linux_加密和安全详细介绍
  4. 转贴自圣骑士wind:Google Maps Android API V2的使用及问题解决
  5. php faker,Laravel的Faker的使用
  6. RK3399平台开发系列讲解(USB网卡)5.48、USBNET的CDC link on/off 消息
  7. Ansible详解(一)
  8. 过去的Tony老师你爱理不理,现在的Tony老师你高攀不起
  9. Angular启动项目时报错
  10. 用了半年的时间,把python学到了能出书的程度