python实现SGD(stochastic gradient descent)算法
文章目录
- 源码
- 知识点
- 1. 实现参数(权重)矩阵初始化小值:
- 2. 打乱数据
- 3. 实现随机梯度下降
- 4. 使用函数实现数据处理步骤
- 5. 在保持初始化权重不变的情况下实现训练
- 6. 返回自己
- 结果
源码
此为适应机使用随机梯度下降的python实现,代码主要来源于Python Machine Learning 3rd(此书包含大量python实现算法的源码,宜啃读)自己把代码又弄了一遍,欢迎想看或者看过这本书的来交流讨论
class AdalineSGD(object):def __init__(self, eta=0.01, n_iter=10, shuffle=True, random_state=None):self.eta = etaself.n_iter = n_iterself.w_initialized = Falseself.shuffle = shuffleself.random_state = random_statedef fit(self, X, y):self._initialize_weights(X.shape[1])self.cost_ = []for i in range(self.n_iter):if self.shuffle:X, y = self._shuffle(X, y)cost = []for i in range(self.n_iter):if self.shuffle:X, y = self._shuffle(X, y)cost = []for xi, target in zip(X, y):cost.append(self._update_weights(xi, target))avg_cost = sum(cost) / len(y)self.cost_.append(avg_cost)return selfdef partial_fit(self, X, y):if not self.w_initialized:self._initialize_weights(X.shape[1])if y.ravel().shape[0] > 1:for xi, target in zip(X, y):self._update_weights(xi, target)else: self._update_weights(X, y)return selfdef _initialize_weights(self, m):self.rgen = np.random.RandomState(self.random_state)self.w_ = self.rgen.normal(loc=0.0, scale=0.01, size=1 + m )self.w_initialized = True# Shuffle the datadef _shuffle(self, X, y):r = self.rgen.permutation(len(y)) # use this method to get a randomly arrayreturn X[r], y[r] # return the randomly array, this kind of indexing exsiting only in npdef _update_weights(self, xi, target):output = self.activation(self.net_input(xi))error = target - outputself.w_[0] += self.eta * errorself.w_[1:] += self.eta * xi.dot(error)cost = 0.5 * error**2return costdef activation(self, X):return Xdef predict(self, X):return np.where(self.activation(self.net_input(X))>=0.0, 1, -1)def net_input(self, X):return np.dot(X, self.w_[1:]) + self.w_[0]
知识点
1. 实现参数(权重)矩阵初始化小值:
def _initialize_weights(self, m):self.rgen = np.random.RandomState(self.random_state)self.w_ = self.rgen.normal(loc=0.0, scale=0.01, size=1 + m )self.w_initialized = True
1-设定RandomState的随机种,不同的数值会产生不同的随机初始化数据的效果
2-使用numpy.random.normal可以设定呈正态分布的初始数组,loc设置均值,scale设置标准差,size设置numpy数组长度,+1是因为加入了偏置项
2. 打乱数据
def _shuffle(self, X, y):r = self.rgen.permutation(len(y)) # use this method to get a randomly arrayreturn X[r], y[r] # return the randomly array, this kind of indexing exsiting only in np
1-numpy.random.permutation方法(rgen是random实例)可以产生随机打乱序号的numpy array
2-numpy array可以通过双重数组索引进行重新组合,实现洗牌效果
3. 实现随机梯度下降
def _update_weights(self, xi, target):output = self.activation(self.net_input(xi))error = target - outputself.w_[0] += self.eta * errorself.w_[1:] += self.eta * xi.dot(error)cost = 0.5 * error**2return cost
1-相比批量梯度下降适应机,X(矩阵)改成了xi(矢量),errors(矢量)改成了error,然后单独计算单个样本预测误差,如此而已
4. 使用函数实现数据处理步骤
def activation(self, X):return Xdef predict(self, X):return np.where(self.activation(self.net_input(X))>=0.0, 1, -1)def net_input(self, X):return np.dot(X, self.w_[1:]) + self.w_[0]
不难发现三个函数分别完成了数据处理的三个节点的功能,predict就是量化器,输出1/-1。根据Adaline原理,量化器产生的最终结果不参与优化过程,所以其实并没有在类中使用,只在实例中进行最终的结果预测
5. 在保持初始化权重不变的情况下实现训练
def partial_fit(self, X, y):if not self.w_initialized:self._initialize_weights(X.shape[1])if y.ravel().shape[0] > 1:for xi, target in zip(X, y):self._update_weights(xi, target)else: self._update_weights(X, y)return self
判断一:如果未被初始化则进行初始化,否则继续训练,叠加训练结果
判断二:如果传递进来的是是一组数据集(>1),对里面的所有数据进行一轮更新
判断三:如果传递进来的是单组数据,对这组数据再进行一轮优化更新
6. 返回自己
return self
返回自己让你可以进行链式调用,返回指向类的实例的引用
不返回自己:AdalineSGD(...).fit().fit()报错
返回自己:AdalineSGD(...).fit().fit()不报错,不停地进行数据训练,训练结果是累加的
return self在算法库种是很常用的算法
结果
左图为决策边界,右图为误差随着迭代优化次数的趋势图,画图原理不作介绍了
python实现SGD(stochastic gradient descent)算法相关推荐
- 使用SGD(Stochastic Gradient Descent)进行大规模机器学习
1 基于梯度下降的学习 对于一个简单的机器学习算法,每一个样例包含了一个(x,y)对,其中一个输入x和一个数值输出y.我们考虑损失函数 ,它描述了预测值 和实际值y之间的损失.预测值是我们选择从一函 ...
- 随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么 SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一.SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数.它的基本 ...
- 线性回归之随机梯度下降法(Stochastic Gradient Descent,SGD)
通俗易懂 一个经典的例子就是假设你现在在山上,为了以最快的速度下山,且视线良好,你可以看清自己的位置以及所处位置的坡度,那么沿着坡向下走,最终你会走到山底.但是如果你被蒙上双眼,那么你则只能凭借脚踩石 ...
- 三种梯度下降法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)
梯度下降(GD)是最小化风险函数.损失函数的一种常用方法,随机梯度下降(stochastic gradient descent).批量梯度下降(Batch gradient descent)和mini ...
- 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比
随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比.实现对比 标签: 梯度下降最优化迭代 2013 ...
- 手工实现:SVM with Stochastic Gradient Descent
手工实现:SVM with Stochastic Gradient Descent 引入 实际问题 理论知识 SVM 直观认识 什么是分的好? 1.是不是只要都分对了就是分的好? 2.是不是只要训练集 ...
- 【文献阅读】Federated Accelerated Stochastic Gradient Descent
文章目录 1 Federated Accelerated Stochastic Gradient Descent (FedAc) 2 challenge 3 how to do 4 baseline ...
- 【Byrd-SAGA】Federated Variance-Reduced Stochastic Gradient Descent with Robustness
Federated Variance-Reduced Stochastic Gradient Descent with Robustness to Byzantine Attacks 对拜占庭攻击具有 ...
- 几种梯度下降方法简介(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)
我们在训练神经网络模型时,最常用的就是梯度下降,这篇博客主要介绍下几种梯度下降的变种(mini-batch gradient descent和stochastic gradient descent), ...
- 几种梯度下降方法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)
几种梯度下降方法对比(Batch gradient descent.Mini-batch gradient descent 和 stochastic gradient descent) 我们在训练神经 ...
最新文章
- 2019 训练比赛 记录
- C#操作项目配置文件
- Qt中QtTableWidget的使用
- win10计算机磁盘图标,Win10 21H1怎么更换电脑磁盘的图标标识
- 中文短文本的实体识别实体链接,第一名解决方案
- 力扣两数之和jAVA_力扣----1.两数之和(JavaScript, Java实现)
- 大学生自学网python_大学生免费自学网官网
- 用极大似然法估计因子载荷矩阵_关于因子分析|stata
- 计算机视觉—车道线检测
- 简聊聊天软件的表设计
- 最低销售量计算机公式,最低、最高、安全库存量的计算公式
- Meta拟裁撤Instagram伦敦员工 其余人将调往美国---转自百度新闻|财联社
- SolidWorks宏工具介绍——初识宏工具
- 数据结构PTA习题:进阶实验5-3.2 新浪微博热门话题 (30分)
- 【GDOI2018模拟7.14】小奇的糖果
- 通过internet连接到股票信息服务器,一种股票机的制作方法
- 【干货】Kaggle数据挖掘比赛经验分享,陈成龙博士整理!
- DataFrame计算corr()函数计算相关系数时,出现返回值为空或NaN的情况
- Win10怎么设置窗口护眼色
- 颜色类中英文词汇大全(5)
热门文章
- 信号硬件入门--振幅调制信号发生器(正弦波发生器方案、AM调制方案)--First理论部分
- 故障诊断仪采集发动机EMS故障的报文与故障码记录
- ​SQL注入非常详细总结
- Redis 官方推出可视化工具,颜值爆表,功能真心强大!这是不给其他工具活路啊!...
- USB协议详解第17讲(USB事务总结)
- 矢量地图自定义切片样式
- 使用mathematica求解最优化模型
- NAT(地址转换技术)详解
- 64位linux下安装libpng出错,安装libpng-1.6.10时make出现错误,请帮忙
- 如何清除项目 git 版本控制信息