分类问题的目标:预测输入数据点所对应的单一离散变量。

回归问题的目标:预测一个连续值而不是单一的标签。例如,温度预测、房价预测。

1.数据集:波士顿房价数据集

该数据集包含的数据点相对较少,只有506个,分为404个训练样本和102个测试样本。输入数据的每个特征都有不同的取值范围。

# 1.加载数据集
from keras.datasets import boston_housing
(train_data, train_targets), (test_data, test_targets) = boston_housing.load_data()

2.准备数据:数据标准化

对于取值范围差异很大的数据,我们采取的数据预处理的方法是数据标准化,即对于输入数据的每个特征,减去特征的平均值,再除以标准差,这样得到均值为0,标准差为1。用numpy很容易实现标准化。

# 数据集中的数据有不同的取值范围,且差异较大。
# 准备数据。数据标准化:对于输入数据矩阵中的列,减去特征平均值,再除以标准差。
mean = train_data.mean(axis=0)  # axis = 0 对列运算,变成一行,实际上是求每列均值
train_data -= mean  # 减去平均值
std = train_data.std(axis=0)
train_data /= std  # 除以标准差test_data -= mean
test_data /= std

3.构建网络

一般来说,训练数据越少,过拟合越严重,而较小的网络可以降低过拟合,所以我们使用一个非常小的网络。

# 构建神经网络
from keras import models
from keras import layersdef build_model():  # 因为需要将这个模型多次实例化,所以需要构建一个函数模型model = models.Sequential()model.add(layers.Dense(64, activation='relu', input_shape=(train_data.shape[1],)))model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(1))  # 没有激活函数,是一个线性层。限制输出范围,可以学到预测任意范围的值model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])return model"""
Q:为什么这个网络最后一层不使用激活函数?A:不使用激活函数的话这就是一个线性层。这是标量回归(标量回归是预测单一连续值的回归)的典型设置。添加激活函数将会限制输出范围。例如,如果向最后一层添加sigmoid激活函数,网络只能学会预测0~1范围内的值。这里最后一层是纯线性的,所以网络可以学会预测任意范围内的值
"""

4.K折交叉验证

由于数据点非常少,验证集也会非常少,验证分数可以会有很大波动,这样情况下,我们使用K折验证法。

K折交叉验证:K折交叉验证使用了无重复抽样技术的好处:每次迭代过程中每个样本点只有一次被划入训练集或测试集的机会。将可用数据划分为K个分区(K通常取4或5),实例化K个相同的模型,将每个模型在K-1个分区上训练,并在剩下的一个分区上进行评估。模型的验证分数等于K个验证分数的平均值。

a.如果训练集相对较小,则增大k值。

增大k值,在每次迭代过程中将会有更多的数据用于模型训练,能够得到最小偏差,同时算法时间延长。且训练块间高度相似,导致评价结果方差较高。

b.如果训练集相对较大,则减小k值。

减小k值,降低模型在不同的数据块上进行重复拟合的性能评估的计算成本,在平均性能的基础上获得模型的准确评估。

#  K折验证
import numpy as np
#  训练网络,用K折验证法对数据进行训练验证
k = 4
num_val_samples = len(train_data) // k  # 除以K商取整。把训练数据分成4份,每份是多少
num_epochs = 100  # 训练100次
all_scores = []  # 建立一个存放分数的列表
for i in range(k):  # i=0,1,2,3 循环4次print('processing fold #', i)val_data = train_data[i * num_val_samples: (i + 1) * num_val_samples]  # i=0,第一批数据,i=1第一批数据val_targets = train_targets[i * num_val_samples: (i + 1) * num_val_samples]partial_train_data = np.concatenate(  # 剩余的数据,其他所有分区的数据/ concatenate()能够一次完成多个数组的拼接。[train_data[:i * num_val_samples],train_data[(i + 1) * num_val_samples:]],axis=0)  # i=0,就是[1~最后]。i=1,就是合并[0~1]和[2~最后]partial_train_targets = np.concatenate(  # concatenate合并两个array数组,axis =0 ,纵向合并[train_targets[:i * num_val_samples],train_targets[(i + 1) * num_val_samples:]],axis=0)model = build_model()  #构建keras模型model.fit(partial_train_data, partial_train_targets,epochs=num_epochs, batch_size=1, verbose=0)  # 训练模型val_mse, val_mae = model.evaluate(val_data, val_targets, verbose=0)  # 在验证数据上评估模型all_scores.append(val_mae)  # 在列表末尾添加新对象,平均绝对误差
print(all_scores)
mean = np.mean(all_scores)
print(mean)

5.保存每折的验证结果

# 保存每折的验证结果
from keras import backend as K
# Some memory clean-upnum_epochs = 500
all_mae_histories = []
for i in range(k):print('processing fold #', i)# Prepare the validation data: data from partition # kval_data = train_data[i * num_val_samples: (i + 1) * num_val_samples]val_targets = train_targets[i * num_val_samples: (i + 1) * num_val_samples]# Prepare the training data: data from all other partitionspartial_train_data = np.concatenate([train_data[:i * num_val_samples],train_data[(i + 1) * num_val_samples:]],axis=0)partial_train_targets = np.concatenate([train_targets[:i * num_val_samples],train_targets[(i + 1) * num_val_samples:]],axis=0)# Build the Keras model (already compiled)model = build_model()# Train the model (in silent mode, verbose=0)history = model.fit(partial_train_data, partial_train_targets,validation_data=(val_data, val_targets),epochs=num_epochs, batch_size=1, verbose=0)# 就是不输出日志信息 ,进度条、loss、acc这些都不输出,verbose=0mae_history = history.history['val_mae']  # fit返回一个history对象,这个对象有一个history字典,all_mae_histories.append(mae_history)
print(all_mae_histories)# 计算所有轮次中的K折验证分数平均值。
average_mae_history = [np.mean([x[i] for x in all_mae_histories]) for i in range(num_epochs)]
print(average_mae_history)

6.绘制验证分数图像

import matplotlib.pyplot as plt
plt.plot(range(1,len(average_mae_history)+1),average_mae_history)
plt.xlabel('Epochs')
plt.ylabel('Validation MAE')
plt.show()

结果:


由于纵轴的范围较大,且数据方差相对较大,难以看清这张图的规律,我们重新绘制一张图。

a.删除前10个数据点,因为他们的取值范围与曲线上的其他点不同

b.将每个数据点替换为前面数据点的指数移动平均值,得到光滑的曲线。

# 纵轴范围太大,方差较大,把每个数据点替换成
def smooth_curve(points, factor=0.9):  # 数据点,权重系数smoothed_points = []  # 建立一个空的列表,用于存放光滑数据点for point in points:  # 遍历所有的数据点if smoothed_points:  # 如果列表中有数据,则执行下面步骤previous = smoothed_points[-1]smoothed_points.append(previous * factor + point * (1 - factor))#  指数移动平均值EMA,前一个数据点*加权系数+当前数据点*(1-加权系数)else:smoothed_points.append(point)  # append添加到列表中最后面return smoothed_pointssmooth_mae_history = smooth_curve(average_mae_history[10:])
# 输入K折验证平均值,删除前前10个取值范围与曲线不同的点
plt.plot(range(1, len(smooth_mae_history) + 1), smooth_mae_history)
plt.xlabel('Epochs')
plt.ylabel('Validation MAE')
plt.show()# 训练最终模型
model = build_model()
model.fit(train_data, train_targets, epochs=80, batch_size=16, verbose=0)
test_mes_score, test_mae_score = model.evaluate(test_data, test_targets)# 输出最终结果
print(test_mae_score)
# 2.509598970413208


EMA例子

import matplotlib.pyplot as pltpoints = [1, 5, 3, 9, 4]
def smooth_curve(points, factor=0.9):smoothed_points =[] # 数据点,权重系数for point in points:  # 遍历所有的数据点if smoothed_points:  # 如果列表中有数据,则执行下面步骤previous = smoothed_points[-1]smoothed_points.append(previous * factor + point * (1 - factor))#  指数移动平均值EMA,前一个数据点*加权系数+当前数据点*(1-加权系数)else:smoothed_points.append(point)  # append添加到列表中最后面return smoothed_pointsresults = smooth_curve(points)
print(results)
plt.plot(range(1, len(points) + 1), results)
plt.show()


回归问题总结:

1.损失函数与分类问题不同,回归问题常用均方误差(MSE)

2.评估指标与分类问题不同,回归问题常用平均绝对误差(MAE)

3.可用数据很少,可使用K折验证,减小网络模型

《python深度学习》笔记(八):回归问题相关推荐

  1. Kera之父Python深度学习笔记(一)什么是深度学习

    目录 机器学习 深度学习 其他机器学习(简史) 概率建模 早期神经网络 核方法 决策树.随机森林与梯度提升机 机器学习 学习:寻找数据自动搜索的优化过程 假设空间:一组实现定义好的操作(机器学习寻找变 ...

  2. Kera之父Python深度学习笔记(二)神经网络的数学基础

    目录 神经网络的数据表示 标量(0D张量) 向量(1D张量) 矩阵(2D张量) 3D张量以及更高维张量 关键属性 在Numpy中操作张量 数据批量的概念 现实世界中的数据张量 向量数据 时间序列数据和 ...

  3. Programming Computer Vision with Python (学习笔记八)

    图像去噪(Image Denoising)的过程就是将噪点从图像中去除的同时尽可能的保留原图像的细节和结构.这里讲的去噪跟前面笔记提过的去噪不一样,这里是指高级去噪技术,前面提过的高斯平滑也能去噪,但 ...

  4. 《Python 深度学习》刷书笔记 Chapter 3 预测房价:回归问题

    文章目录 波士顿房价数据集 3-24 加载波士顿房价数据 3-25 数据标准化 3-26 模型定义 3-27 K折验证 3-28 训练500轮,保存每折的验证结果 3-29 计算所有轮次茨种的K折验证 ...

  5. 深度学习笔记(3) 向量化逻辑回归

    深度学习笔记(3) 向量化逻辑回归 1. 向量化运算的优势 2. 向量化编程 3. 举例 1. 向量化运算的优势 python的向量化运算速度快,是非常基础的去除代码中for循环的艺术 可以看出相同的 ...

  6. 《Python深度学习》第一章笔记

    <Python深度学习>第一章笔记 1.1人工智能.机器学习.深度学习 人工智能 机器学习 深度学习 深度学习的工作原理 1.2深度学习之前:机器学习简史 概率建模 早期神经网络 核方法 ...

  7. AI Studio 飞桨 零基础入门深度学习笔记2-基于Python编写完成房价预测任务的神经网络模型

    AI Studio 飞桨 零基础入门深度学习笔记2-基于Python编写完成房价预测任务的神经网络模型 波士顿房价预测任务 线性回归模型 线性回归模型的神经网络结构 构建波士顿房价预测任务的神经网络模 ...

  8. 一文让你完全弄懂逻辑回归和分类问题实战《繁凡的深度学习笔记》第 3 章 分类问题与信息论基础(上)(DL笔记整理系列)

    好吧,只好拆分为上下两篇发布了>_< 终于肝出来了,今天就是除夕夜了,祝大家新快乐!^q^ <繁凡的深度学习笔记>第 3 章 分类问题与信息论基础 (上)(逻辑回归.Softm ...

  9. 一文让你完全弄懂回归问题、激活函数、梯度下降和神经元模型实战《繁凡的深度学习笔记》第 2 章 回归问题与神经元模型(DL笔记整理系列)

    <繁凡的深度学习笔记>第 2 章 回归问题与神经元模型(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net/ http ...

  10. python神经结构二层_《python深度学习》笔记---8.3、神经风格迁移

    <python深度学习>笔记---8.3.神经风格迁移 一.总结 一句话总结: 神经风格迁移是指将参考图像的风格应用于目标图像,同时保留目标图像的内容. 1."神经风格迁移是指将 ...

最新文章

  1. Shark0.9.1安装
  2. 开个定时器给echarts组件配置定时更新
  3. 交叉编译及linux简单程序设计,嵌入式实验6交叉编译及Linux简单程序设计实验
  4. 森林怎么训练野人_138年前抓的“野人女孩”,最终命运如何?死前心愿让人心酸...
  5. 掌握Spark机器学习库-06-基础统计部分
  6. 计算机导航医学应用,【2016年】计算机导航在全膝关节置换中的应用技术及进展【临床医学论文】.doc...
  7. 如何实现 迭代器 可迭代对象 (2.1)
  8. python中的__file__、os.path.realpath(__file__)、os.path.dirname(os.path.realpath(__file__))
  9. 树莓派十周年,回顾它的发展历程
  10. 房友系统服务器地址,房友系统的那些功能,你都知道吗?
  11. PHP视频教程源码书籍web前端ThinkPHP5/5.0商城实战开发html5秒杀
  12. matlab保存pdf图片太大,matlab中的图片保存方法精选.pdf
  13. 2015年全国大学生电子设计竞赛专题系列之综合测评-Multisim使用技巧
  14. connect holder is null问题记录
  15. iPhone 6年代iWatch热销缘由:可穿戴更有招引力
  16. 渗透利器BurpSuite简介
  17. 为了不手动命名驼峰变量名,我开发了一套油猴脚本...
  18. 【论文笔记_知识蒸馏_2022】Knowledge Distillation with the Reused Teacher Classifier
  19. css3动画添加间隔
  20. 使用layer弹出层组件绑定页面按钮

热门文章

  1. java 502错误_nginx 502 超时错误解决(java版本)
  2. 银河麒麟系统10服务器安装教程,麒麟系统下安装win10的详细教程
  3. 自签名证书和私有CA签名的证书的区别
  4. 谷歌应用商店开发者注册
  5. 【7】PR音频及结合AU去除噪音【8】PR字幕运用
  6. js清除网页广告代码
  7. cwRsync 文件备份
  8. Python的学习心得和知识总结(十二)|Python图形用户接口编程(Graphical User Interface编程 一)
  9. maven运行Error:(3, 14) java: 程序包不存在
  10. 科技解读:com域名价格为什么连年上涨?小微企业怎样应对?