机器学习课程笔记【三】广义线性模型(2)-构建广义线性模型
本节为吴恩达教授机器学习课程笔记第三部分,广义线性模型(2)构建广义线性模型,包括最小均方算法、逻辑回归,着重介绍softmax的推导,并给出softmax的核心代码以及pytorch实现。
2. 构建GLMS广义线性模型
考虑一个分类或者回归问题,我们希望得到一个x的函数,来计算随机变量y的值。为了推导出一个广义线性模型,我们给出这样三个假设:
- 给定x和θ\thetaθ,我们假定y的分布服从参数为η\etaη的某个指数族分布
- 给定x,目标是预测T(y)T(y)T(y)的期望值,在绝大多数例子中,T(y)=yT(y)=yT(y)=y,因此,我们想要预测h(x)h(x)h(x)的值使得h(x)=E[y∣x]h(x)=E[y|x]h(x)=E[y∣x]
- 自然参数和x是线性相关的,即η=θTx\eta=\theta^Txη=θTx,对于向量来说,即ηi=θiTx\eta_i=\theta_i^Txηi=θiTx
2.1 典型的最小均方算法
典型的最小均方算法也是广义线性模型族中的一个,在之前高斯分布写为指数族形式中有μ=η\mu=\etaμ=η,于是有下面的式子:
2.2 逻辑回归
对于二分类问题,有:
2.3 Softmax回归
考虑一个多分类问题,即模型有多个可能的输出对应多个不同的类别,我们可以用多项式分布来对该问题进行建模。
首先,我们可以将多项式分布表达为指数族的形式。定义:
即:
这里,T(y)T(y)T(y)是一个k−1k-1k−1维向量,用(T(y))i(T(y))_i(T(y))i表示向量T(y)T(y)T(y)的第i个分量。
此外引入一种特殊的定义,即1{⋅\cdot⋅},如果它接受的参数为True,则输出1,参数为False,输出0。这样我们可将把y和新的T(y)T(y)T(y)的关系写为:
更进一步地,有:
接下来证明多项式分布确实属于指数族分布,我们有:
对应地:
于是有:
这表明:
从而得到:
这个将η′s\eta'sη′s映射为φ′s\varphi'sφ′s的函数成为softmax函数,根据之前给出的假设3,有:
这个多分类模型即softmax回归,是逻辑回归的推广,所以hθ(x)h_{\theta}(x)hθ(x)会输出:
换句话说,hθ(x)h_{\theta}(x)hθ(x)将输出属于每一类的概率。
最后我们来探讨参数拟合的问题,给出如下对数似然函数:
之后我们就可以用牛顿法或者梯度下降来最大化对数似然函数已得到参数值了。
softmax回归的核心代码
# softmax函数,将线性回归值转化为概率的激活函数。输入s要是行向量
def softmax(s):return np.exp(s) / np.sum(np.exp(s), axis=1)# 逻辑回归中使用梯度下降法求回归系数。逻辑回归和线性回归中原理相同,只不过逻辑回归使用sigmoid作为迭代进化函数。因为逻辑回归是为了分类而生。线性回归为了预测而生
def gradAscent(dataMat, labelPMat):alpha = 0.2 #移动步长,也就是学习速率,控制更新的幅度。maxCycles = 1000 #最大迭代次数weights = np.ones((dataMat.shape[1],labelPMat.shape[1])) #初始化权回归系数矩阵 系数矩阵的行数为特征矩阵的列数,系数矩阵的列数为分类数目for k in range(maxCycles):h = softmax(dataMat*weights) #梯度上升矢量化公式,计算预测值(行向量)。每一个样本产生一个概率行向量error = h-labelPMat #计算每一个样本预测值误差weights = weights - alpha * dataMat.T * error # 根据所有的样本产生的误差调整回归系数return weights # 将矩阵转换为数组,返回回归系数数组
pytorch实现
EPOCH = 5000
BATCH_SIZE = 100
LR = 0.01train_set = Data.TensorDataset(x_train,y_train)
train_loader = Data.DataLoader(dataset=train_set,batch_size=BATCH_SIZE,shuffle=True)net = nn.Sequential(nn.Linear(8,50),nn.ReLU(),nn.Linear(50,4))
LOSS_FUNC = nn.CrossEntropyLoss() # 损失函数torch.nn.CrossEntropyLoss()中已经包含了Softmax函数
OPTIMIZER = torch.optim.SGD(net.parameters(), lr=LR) for epoch in range(1,EPOCH+1):loss_sum = 0.0for step,(x,y) in enumerate(train_loader):y_pred = net(x)y = y.squeeze() #修正标签格式loss = LOSS_FUNC(y_pred,y)loss_sum += lossOPTIMIZER.zero_grad()loss.backward()OPTIMIZER.step()print("epoch: %d, loss: %f" %(epoch,loss_sum/BATCH_SIZE))
欢迎扫描二维码关注微信公众号 深度学习与数学 [每天获取免费的大数据、AI等相关的学习资源、经典和最新的深度学习相关的论文研读,算法和其他互联网技能的学习,概率论、线性代数等高等数学知识的回顾]
机器学习课程笔记【三】广义线性模型(2)-构建广义线性模型相关推荐
- 吴恩达机器学习课程笔记一
吴恩达机器学习课程笔记 前言 监督学习---`Supervised learning` 无监督学习---`Unsupervised learning` 聚类 异常检测 降维 增强学习---`Reinf ...
- 李宏毅2020机器学习课程笔记(二)
相关专题: 李宏毅2020机器学习资料汇总 李宏毅2020机器学习课程笔记(一) 文章目录 4. CNN Convolutional Neural Network(P17) 5. GNN Graph ...
- Github标星24300!吴恩达机器学习课程笔记.pdf
个人认为:吴恩达老师的机器学习课程,是初学者入门机器学习的最好的课程!我们整理了笔记(336页),复现的Python代码等资源,文末提供下载. 课程简介 课程地址:https://www.course ...
- 机器学习总结——机器学习课程笔记整理
机器学习笔记整理 说明 基础点整理 1. 基础数学知识 (1) 一些零七八碎的基础知识 (2) 最优化相关问题 (3) 概率论相关问题 (4) 矩阵相关问题 2. 回归(线性回归.Logistic回归 ...
- 干货|机器学习零基础?不要怕,吴恩达机器学习课程笔记2-多元线性回归
吴恩达Coursera机器学习课系列笔记 课程笔记|吴恩达Coursera机器学习 Week1 笔记-机器学习基础 1 Linear Regression with Multiple Variable ...
- 李弘毅机器学习课程笔记(一):机器/深度学习入门
文章目录 什么是ML ML分类 一个例子 Model(function) Loss function Error surface Optimization Conclusion 最近在Youtube上 ...
- 唐宇迪机器学习课程笔记:逻辑回归之信用卡检测任务
信用卡欺诈检测 基于信用卡交易记录数据建立分类模型来预测哪些交易记录是异常的哪些是正常的. 任务流程: 加载数据,观察问题 针对问题给出解决方案 数据集切分 评估方法对比 逻辑回归模型 建模结果分析 ...
- 吴恩达机器学习课程笔记(英文授课) Lv.1 新手村(回归)
目录 1-1机器学习的相关名词 1-2 什么是机器学习? 1.definition 定义 2.主要的机器学习算法的分类 1-3有监督学习及常用算法 1.定义 2.两种数据类型补充:categorica ...
- 李宏毅2020机器学习课程笔记(一)
文章目录 1. 课程简介 Course Introduction(P1) Rule of ML 2020(P2) 2. Regression Case Study (P3) Basic concept ...
- 吴恩达机器学习课程笔记(11-19章)
第十一章 11.1 确定执行的优先级 垃圾邮件分类器算法: 为了解决这样一个问题,我们首先要做的决定是如何选择并表达特征向量 x x x .我们可以选择一个由 100 100 100 个最常出现在垃圾 ...
最新文章
- FD.io/VPP — VPP Agent — Overview
- CodeSmith和PowerDesigner的使用安装和数据库创建(原创系列教程)
- 深度学习:卷积神经网络(convolution neural network)
- 随想录(由自定义打印函数想到的)
- 自闭症患者很难读懂他人情绪?情绪机器人来帮忙
- Halcon基础操作
- redis源码分析(2)——事件循环
- Bootstrap响应式图表设计
- 基于GEE使用Landsat 8和Landsat 5影像进行分类
- 细谈数据库表锁和行锁
- 爱学术,让论文写作不再难!
- Hadoop大数据入门
- ros手柄控制机器人小车(一)
- 现在好的测试缺陷管理工具都有哪些啊?
- 2008年8月26号,星期二,晴。欲穷千里目,更上一层楼。 —— 王之涣《登鹳雀楼》今天是我博士生涯的第51天,争吵,分歧,以自我为中心的考虑问题,那个关键问题
- 在Excel中创建彩色的Harvey球
- 年后创业,该如何选择适合年轻人的小成本创业项目?
- LTC6268-10 4GHz 超低偏置电流 FET 输入运算放大器
- java用浏览器下载文件_JAVA读取文件流,设置浏览器下载或直接预览操作
- 【ZJX-3A AC220V剪断销信号装置】