《PyTorch深度学习实践》学习笔记 【2】

学习资源:
《PyTorch深度学习实践》完结合集

二、线性模型

2.1 概念:

2.1.1 数据集和测试集

​ 数据集拿到后一般划分为两部分,训练集和测试集,然后使用训练集的数据来训练模型,用测试集上的误差作为最终模型在应对现实场景中的泛化误差。

​ 一般来说,测试集在训练的时候是不能偷看的。

我们可以使用训练集的数据来训练模型,然后用测试集上的误差作为最终模型在应对现实场景中的泛化误差。有了测试集,我们想要验证模型的最终效果,只需将训练好的模型在测试集上计算误差,即可认为此误差即为泛化误差的近似,我们只需让我们训练好的模型在测试集上的误差最小即可。

​ 为了使得模型在现实生活中更有效,我们要使用的数据集要尽可能真实。

2.1.2 过拟合与泛化

下面拿小猫图像识别做例子,说明一下过拟合泛化的概念;

过拟合: 在训练集上匹配度很好,但是太过了,把噪声什么的也学进来了。

泛化能力: 对于没见过的图像也能进行识别,这是我们所需要的。

2.1.3 开发集

有时候无法看到测试集,我们又人为地把数据集划分一部分出来作为验证评估,称为“开发集”。

2.1.4 监督学习和非监督学习

有监督学习方法必须要有训练集与测试样本。在训练集中找规律,而对测试样本使用这种规律。而非监督学习没有训练集,只有一组数据,在该组数据集内寻找规律。

有监督学习的方法就是识别事物,识别的结果表现在给待识别数据加上了标签。因此训练样本集必须由带标签的样本组成。而非监督学习方法只有要分析的数据集的本身,预先没有什么标签。如果发现数据集呈现某种聚集性,则可按自然的聚集性分类,但不予以某种预先分类标签对上号为目的。

2.2 线性回归

2.2.1 线性模型

如 y= kx+b ,我们训练的结果就是k和b的值

2.2.2 损失函数

  • 误差函数

  • 平均平方误差(MSE)

  • 损失函数的值越小,代表拟合的效果越好。

2.3 课上实验【1】

课上代码:

import numpy as np
import matplotlib.pyplot as plt;
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]#线性模型
def forward(x):return x * w#损失函数
def loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)#迭代取值,计算每个w取值下的x,y,y_pred,loss_val
w_list = []
mse_list = []
for w in np.arange(0.0, 4.1, 0.1):print('w=', w)l_sum = 0for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)loss_val = loss(x_val, y_val)l_sum += loss_valprint('\t', x_val, y_val, y_pred_val, loss_val)print('MSE=', l_sum / 3)w_list.append(w)mse_list.append(l_sum / 3)##画图
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

结果:

2.4 作业【1】:

参考 Matplotlib3D作图-plot_surface(), .contourf(), plt.colorbar()

代码:

import numpy as np
import matplotlib.pyplot as plt;
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cmx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]#线性模型
def forward(x,w,b):return x * w+ b#损失函数
def loss(x, y,w,b):y_pred = forward(x,w,b)return (y_pred - y) * (y_pred - y)def mse(w,b):l_sum = 0for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val,w,b)loss_val = loss(x_val, y_val,w,b)l_sum += loss_valprint('\t', x_val, y_val, y_pred_val, loss_val)print('MSE=', l_sum / 3)return  l_sum/3#迭代取值,计算每个w取值下的x,y,y_pred,loss_val
mse_list = []##画图##定义网格化数据
b_list=np.arange(-30,30,0.1)
w_list=np.arange(-30,30,0.1);##生成网格化数据
xx, yy = np.meshgrid(b_list, w_list,sparse=False, indexing='xy')##每个点的对应高度
zz=mse(xx,yy)fig = plt.figure()
ax = Axes3D(fig)
ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, cmap=cm.viridis)
plt.show()

结果:

《PyTorch深度学习实践》学习笔记 【2】相关推荐

  1. PyTorch深度学习实践概论笔记9-SoftMax分类器

    上一讲PyTorch深度学习实践概论笔记8-加载数据集中,主要介绍了Dataset 和 DataLoader是加载数据的两个工具类.这一讲介绍多分类问题如何解决,一般会用到SoftMax分类器. 0 ...

  2. 深度学习框架Pytorch入门与实践——读书笔记

    2 快速入门 2.1 安装和配置 pip install torch pip install torchvision#IPython魔术命令 import torch as t a=t.Tensor( ...

  3. PyTorch——深度神经网络的写作笔记

    1 致谢 感谢Facebook的开发者的辛苦和努力- (给Google只有两个字"呵呵") 2 深度神经网络的搭建 2.1 Module的添加 使用nn.Module.add_mo ...

  4. numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践

    <<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗  读完<<深度学习框架PyTorc ...

  5. Pytorch:NLP 迁移学习、NLP中的标准数据集、NLP中的常用预训练模型、加载和使用预训练模型、huggingface的transfomers微调脚本文件

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) run_glue.py微调脚本代码 python命令执行run ...

  6. 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)

    从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...

  7. 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归

    刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...

  8. 【PyTorch深度学习实践 | 刘二大人】B站视频教程笔记

    资料 [参考:<PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili] [参考 分类专栏:PyTorch 深度学习实践_错错莫的博客-CSDN博客] 全[参考 分类专栏:PyT ...

  9. 笔记|(b站)刘二大人:pytorch深度学习实践(代码详细笔记,适合零基础)

    pytorch深度学习实践 笔记中的代码是根据b站刘二大人的课程所做的笔记,代码每一行都有注释方便理解,可以配套刘二大人视频一同使用. 用PyTorch实现线性回归 # 1.算预测值 # 2.算los ...

最新文章

  1. 编程基础 垃圾回收_为什么我回收编程问题
  2. 配置 influxDB 鉴权及 HTTP API 写数据的方法
  3. 一个原生态ajax过程,提交表单的例子
  4. 《数据结构》c语言版学习笔记——单链表结构(线性表的链式存储结构Part1)
  5. 安卓中bundle的使用
  6. Linux 命令之 usermod -- 用于修改用户的基本信息
  7. vSphere HA 原理与配置
  8. 【POJ - 1486】Sorting Slides(思维建图,二分图求必须边,关建边,图论)
  9. 一个ubuntu server下的oracle10g简单生产库全库备份脚本
  10. Struts2国际化——完整实例代码
  11. 修改linux系统时间的方法(date命令)
  12. c语言报数函数问题,[编程入门]报数问题-题解(C语言代码)
  13. Wampserver查看php配置信息
  14. springboot内存占用大_《SpringBoot整合redis、Scheduled/quartz定时任务》
  15. android崩解日志,android – 使用rxJava2和改造的UndeliverableException
  16. [免费专栏] 车联网基础理论之车联网安全常见术语科普
  17. win10c 系统语言 英文,Win10英文版系统下中文软件显示为问号的解决方法
  18. vbscript On Error语句
  19. 2017年计算机ppt考试试题,2017年职称计算机考试(PPT练习题大全)
  20. iOS客户端实现 XMPP协议的步骤

热门文章

  1. python之程序判断季节
  2. 走在团队的前沿(9)---面向交付的团队建设
  3. rtk定位权限_采用GPS-RTK定位方法进行控制测量的技术要求
  4. 【百度经验中excel VBA凑发票程序修改版】
  5. Golang 控制台百行代码贪吃蛇小游戏
  6. 倍福---TwinCAT3Ads通讯
  7. 计算机应用基础教师带教方案,论文实施方案范文
  8. Vue CLI3搭建的项目中,如何给文件夹起别名?
  9. 移动端——swipe特效之图片时间轴
  10. (转)android多国语言适配