参考

3.12 权重衰减

本节介绍应对过拟合的常用方法

3.12.1 方法

正则化通过为模型损失函数添加惩罚项使学出的模型参数更小,是应对过拟合的常用手段。

3.12.2 高维线性回归实验

import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2ln_train, n_test, num_inputs = 20, 100, 200
true_w, true_b = torch.ones(num_inputs, 1) * 0.01, 0.05features = torch.randn((n_train + n_test, num_inputs))
labels = torch.matmul(features, true_w) + true_b
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]

3.13.3 从零开始实现

3.12.3.1 初始化模型参数

def init_params():w = torch.randn((num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

3.12.3.2 定义L2范数惩罚项

def l2_penalty(w):return (w**2).sum() / 2

3.12.3.3 定义训练和测试

batch_size, num_epochs, lr = 1, 100, 0.003
net, loss = d2l.linreg, d2l.squared_lossdataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)def fit_and_plot(lambd):w, b = init_params()train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:# 添加了L2范数惩罚项l = loss(net(X, w, b), y) + lambd * l2_penalty(w)l = l.sum()if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()d2l.sgd([w, b], lr, batch_size)train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('L2 norm of w:', w.norm().item())

3.12.3.4 观察过拟合

fit_and_plot(lambd=0)

3.12.3.5 使用权重衰减

fit_and_plot(lambd=5)

3.12.4 简洁实现

def fit_and_plot_pytorch(wd):# 对权重参数衰减。权重名称一般以weight结尾net = nn.Linear(num_inputs, 1)nn.init.normal_(net.weight, mean=0, std=1)nn.init.normal_(net.bias, mean=0 , std=1)optimizer_w = torch.optim.SGD(params=[net.weight], lr= lr, weight_decay=wd)  # 对权重进行衰减optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)  # 对偏差不进行衰减train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:l = loss(net(X), y).mean()optimizer_w.zero_grad()optimizer_b.zero_grad()l.backward()# 对两个optimizer实例分别调用step函数,从而分别更新权重和偏差optimizer_w.step()optimizer_b.step()train_ls.append(loss(net(train_features), train_labels).mean().item())test_ls.append(loss(net(test_features), test_labels).mean().item())d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('L2 norm of w:', net.weight.data.norm().item())fit_and_plot_pytorch(0)fit_and_plot_pytorch(3)


[pytorch、学习] - 3.12 权重衰减相关推荐

  1. (pytorch-深度学习系列)pytorch避免过拟合-权重衰减的实现-学习笔记

    pytorch避免过拟合-权重衰减的实现 首先学习基本的概念背景 L0范数是指向量中非0的元素的个数:(L0范数难优化求解) L1范数是指向量中各个元素绝对值之和: L2范数是指向量各元素的平方和然后 ...

  2. Pytorch 正则化方法(权重衰减和Dropout)

    正则化方法(权重衰退和Dropout) 正则化方法和以前学过的正则表达式没有任何关系! 花书 p141 说到: 能显式地减少测试误差(可能会以增大训练误差为代价)的方法都被称为正则化. 0. 环境介绍 ...

  3. pytorch学习笔记(十二):权重衰减

    文章目录 1. 方法 2. 高维线性回归实验 3. 从零开始实现 3.1 初始化模型参数 3.2 定义L2L_2L2​范数惩罚项 3.3 定义训练和测试 3.4 观察过拟合 3.5 使用权重衰减 4. ...

  4. 深度学习的权重衰减是什么_【深度学习理论】一文搞透Dropout、L1L2正则化/权重衰减...

    前言 本文主要内容--一文搞透深度学习中的正则化概念,常用正则化方法介绍,重点介绍Dropout的概念和代码实现.L1-norm/L2-norm的概念.L1/L2正则化的概念和代码实现- 要是文章看完 ...

  5. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  6. 深度学习的权重衰减是什么_权重衰减和L2正则化是一个意思吗?它们只是在某些条件下等价...

    权重衰减== L2正则化? 神经网络是很好的函数逼近器和特征提取器,但有时它们的权值过于专门化而导致过度拟合.这就是正则化概念出现的地方,我们将讨论这一概念,以及被错误地认为相同的两种主要权重正则化技 ...

  7. 深度学习PyTorch笔记(12):线性神经网络——softmax回归

    深度学习PyTorch笔记(12):线性神经网络--softmax回归 6 线性神经网络--softmax回归 6.1 softmax回归 6.1.1 概念 6.1.2 softmax运算 6.2 图 ...

  8. 深度学习:权重衰减(weight decay)与学习率衰减(learning rate decay)

    正则化方法:防止过拟合,提高泛化能力 避免过拟合的方法有很多:early stopping.数据集扩增(Data augmentation).正则化(Regularization)包括L1.L2(L2 ...

  9. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

最新文章

  1. Dos批处理常用命令大全入门
  2. 计算机图形学基础考试题,计算机图形学基础复习题
  3. Java异常处理——try-with-resource 语法糖
  4. [itint5]棋盘漫步
  5. Android Weekly Notes Issue #220
  6. Windows 10下,anaconda (conda) 虚拟环境的创建,jupyter notebook如何使用虚拟环境
  7. Android7.1启动系统App必须配置加密
  8. 再谈和字体有关的几个问题
  9. SharePoint2010内容类型剖析(三)
  10. 用C#绘图实现动画出现卡屏(运行慢)问题的解决办法
  11. SQL Server常用的字符串/日期/系统函数
  12. CCNA 折扣号申请流程(新版)
  13. 前端根据后端数据生成表格 行列合并 指定表头
  14. 自动化测试之Appium
  15. iOS开发 - 获取网关IP,运营商,位置,可判断是在国内还是国外
  16. 前端基础:CSS 3
  17. 理解和应用持续集成-Tekton
  18. C语言:L1-039 古风排版 (20 分)
  19. 移动网站性能优化:网页加载技术概览
  20. 微服务 - Hystrix 熔断器

热门文章

  1. 基于android 定位系统,基于Android平台定位系统设计和实现
  2. java s1_转!!Java 基础面试题的剖析: short s1=1;s1 = s1 +1 报错? s1+=1 呢
  3. GPU Gems1 - 26 OpenEXR图像文件格式与HDR(The OpenEXR Image File Format and HDR)
  4. php自动运维,运维自动化之使用PHP+MYSQL+SHELL打造私有监控系统(五)
  5. 中控ecs700 mysql_浙大中控ECS700工程指导手册.pdf
  6. iOS手势操作简介(四)
  7. NOIP2005普及组第3题 采药 (背包问题)
  8. 今天试了一下iscroll
  9. ASP.NET跨页面传值技巧总结
  10. Django Tips