本章介绍了数个训练模型(非常多的数学公式对咱十分不友好)

1.线性回归

随机生成线性数据集:

import numpy as npX = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)plt.plot(X, y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.axis([0, 2, 0, 15])
save_fig("generated_data_plot")
plt.show()

使用sklearn执行线性回归

from sklearn.linear_model import LinearRegressionlin_reg = LinearRegression()
lin_reg.fit(X, y)
lin_reg.intercept_, lin_reg.coef_

很接近原始参数。

2.梯度下降

批量梯度下降(学习率0.1):

eta = 0.1  # learning rate
n_iterations = 1000
m = 100theta = np.random.randn(2,1)  # random initializationfor iteration in range(n_iterations):gradients = 2/m * X_b.T.dot(X_b.dot(theta) - y)theta = theta - eta * gradientstheta

随机梯度下降:

theta_path_sgd = []
m = len(X_b)
np.random.seed(42)
n_epochs = 50
t0, t1 = 5, 50  # learning schedule hyperparametersdef learning_schedule(t):return t0 / (t + t1)theta = np.random.randn(2,1)  # random initializationfor epoch in range(n_epochs):for i in range(m):random_index = np.random.randint(m)xi = X_b[random_index:random_index+1]yi = y[random_index:random_index+1]gradients = 2 * xi.T.dot(xi.dot(theta) - yi)eta = learning_schedule(epoch * m + i)theta = theta - eta * gradientstheta

小批量梯度下降

theta_path_mgd = []n_iterations = 50
minibatch_size = 20np.random.seed(42)
theta = np.random.randn(2,1)  # random initializationt0, t1 = 200, 1000
def learning_schedule(t):return t0 / (t + t1)t = 0
for epoch in range(n_iterations):shuffled_indices = np.random.permutation(m)X_b_shuffled = X_b[shuffled_indices]y_shuffled = y[shuffled_indices]for i in range(0, m, minibatch_size):t += 1xi = X_b_shuffled[i:i+minibatch_size]yi = y_shuffled[i:i+minibatch_size]gradients = 2/minibatch_size * xi.T.dot(xi.dot(theta) - yi)eta = learning_schedule(t)theta = theta - eta * gradientstheta_path_mgd.append(theta)theta

3.多项式回归

生成一个二次多项式:

import numpy as np
import numpy.random as rndnp.random.seed(42)
m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 0.5 * X**2 + X + 2 + np.random.randn(m, 1)
plt.plot(X, y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.axis([-3, 3, 0, 10])
save_fig("quadratic_data_plot")
plt.show()

使用模型预测参数:

from sklearn.preprocessing import PolynomialFeatures
poly_features = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly_features.fit_transform(X)
X[0]lin_reg = LinearRegression()
lin_reg.fit(X_poly, y)
lin_reg.intercept_, lin_reg.coef_

模型估算y=0.56x^2+0.93x+1.78,很接近。

4.学习曲线

通过观察学习曲线来判断模型是过于简单还是过于复杂:

1.

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_splitdef plot_learning_curves(model, X, y):X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=10)train_errors, val_errors = [], []for m in range(1, len(X_train)):model.fit(X_train[:m], y_train[:m])y_train_predict = model.predict(X_train[:m])y_val_predict = model.predict(X_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend(loc="upper right", fontsize=14)   plt.xlabel("Training set size", fontsize=14) plt.ylabel("RMSE", fontsize=14)
lin_reg = LinearRegression()
plot_learning_curves(lin_reg, X, y)
plt.axis([0, 80, 0, 3])
save_fig("underfitting_learning_curves_plot")
plt.show()                                      

这说明模型欠拟合

2.

from sklearn.pipeline import Pipelinepolynomial_regression = Pipeline([("poly_features", PolynomialFeatures(degree=10, include_bias=False)),("lin_reg", LinearRegression()),])plot_learning_curves(polynomial_regression, X, y)
plt.axis([0, 80, 0, 3])
save_fig("learning_curves_plot")
plt.show()          

这说明模型过拟合。

后续书中还介绍了正则化线性模型和逻辑回归,遗憾有些不知所云,后续加强学习。

《机器学习实战》第四章相关推荐

  1. 机器学习实战——第四章(分类):朴素贝叶斯

    前言 首先感谢博主:Jack-Cui 主页:http://blog.csdn.net/c406495762 朴素贝叶斯博文地址: https://blog.csdn.net/c406495762/ar ...

  2. 尚学堂java实战第四章课后习题

    尚学堂java实战第四章课后习题 文章中的题目答案仅供参考 选择题答案: 1.B 解析:一个java类必然存在构造器,即使没有定义构造器,也会存在一个默认的无参构造器. 2.D 3.AC 解析: A( ...

  3. 李弘毅机器学习:第四章—梯度下降法

    李弘毅机器学习:第四章-梯度下降法 什么是梯度下降法? Review: 梯度下降法 Tip1:调整学习速率 小心翼翼地调整学习率 自适应学习率 Adagrad 算法 Adagrad 是什么? Adag ...

  4. 零基础学Python课后实战第四章

    零基础学Python课后实战第四章 实战一:输出王者荣耀的游戏角色 实战二:模拟火车订票系统 实战三:电视剧的收视率排行榜 tips 实战一:输出王者荣耀的游戏角色 列表的创建.遍历列表 代码 pri ...

  5. 《机器学习实战》第二章学习笔记:K-近邻算法(代码详解)

    <机器学习实战>数据资料以及总代码可以去GitHub中下载: GitHub代码地址:https://github.com/yangshangqi/Machine-Learning-in-A ...

  6. 机器学习实战:第一章

    根据方教授的建议和要求,在暑假里简单自学<机器学习实战>,记录学习过程和代码. 记 第一章是对机器学习的一些概念介绍,定义了若干专业术语.列举了很多机器学习的各类实例.给出了一个" ...

  7. 机器学习实战第15章pegasos算法原理剖析以及伪代码和算法的对应关系

    Pegasos原文是: http://ttic.uchicago.edu/~nati/Publications/PegasosMPB.pdf 还是挺长的,论文结构是: 第1~6页:主要原理 第7~15 ...

  8. 机器学习实战(四)逻辑回归LR(Logistic Regression)

    目录 0. 前言 1. Sigmoid 函数 2. 梯度上升与梯度下降 3. 梯度下降法(Gradient descent) 4. 梯度上升法(Gradient ascent) 5. 梯度下降/上升法 ...

  9. android movie 资源释放,Android 资讯类App项目实战 第四章 电影模块

    前言: 正在做一个资讯类app,打算一边做一边整理,供自己学习与巩固.用到的知识复杂度不高,仅适于新手.经验不多,如果写出来的代码有不好的地方欢迎讨论. 以往的内容 第四章 电影模块 本章内容最终效果 ...

  10. 吴恩达机器学习(第四章)——多变量线性回归

    第四章-多变量线性回归 文章目录 第四章-多变量线性回归 多功能 多元梯度下降法 梯度下降算法 特征缩放 学习率 特征与多项式回归 正规方程 正规方程的概念 公式的推导 梯度下降法 VS 正规方程 奇 ...

最新文章

  1. bootstarp js设置列隐藏_隐藏工作表的行、列(第一种简单,第二种很坑,第三种最坑)...
  2. python实现记事本的查找功能_Python + PyQt4 实现记事本功能
  3. Java File类基本操作
  4. pat1049. Counting Ones (30)
  5. 杭哥试用过的精品软件推荐:PDF转power point 格式-----PDFtoPowerPointPortable 已注册版本...
  6. 使用海康威视设备在Web端显示实时视频
  7. 苹果cmsV10高仿草民电影网在线影视网站模板 带手机版
  8. Linux下的文件I/O编程
  9. CISP 考试教材《第 3 章 知识域:信息安全管理》知识整理
  10. 基于Spatial-Temporal Transformer的城市交通流预测
  11. 怎么用电脑把mp4格式转换成mp3格式
  12. mysql访问错误:1682
  13. 2018-12支付宝红包赚钱薅羊毛全攻略
  14. [DAX] IF函数
  15. 2022爱分析·人工智能厂商全景报告
  16. 工控领域为什么需要OPC,OPC是什么?
  17. 制作U-Boot的SD启动卡
  18. ZZY‘s_wsl_guide
  19. VLC android 3.0解码器使用及移植TV项目调研
  20. php mcrypt_encrypt,PHP 将 mcrypt_encrypt 迁移至 openssl_encrypt 的方法

热门文章

  1. 《千万别学英语》精粹
  2. 最新流浪猫流浪狗H5完整运营源码下载/可封装APP
  3. ai如何旋转画布_Illustrator让一个图形沿着某一点或顶点旋转复制教程
  4. Multisim基础 DIP开关 添加元件的位置
  5. 天猫用户重复购买预测——特征工程
  6. [4G/5G/6G专题基础-154]: 5G无线准入控制RAC(Radio Admission Control)+ 其他准入控制
  7. 新手玩转unwallet攻略
  8. java junit 覆盖率_java单元测试篇:使用clover为junit单元测试做覆盖率分析(二)...
  9. 实战电商后端系统(三)—— 以vue-element-admin为基础的前端项目对接后端接口
  10. 华为便携机修改服务器密码,华为随身WiFi如何修改WiFi密码 华为随身WiFi修改WiFi密码方法【介绍】...