文章目录

  • 源码
  • 知识点
    • 1. 实现参数(权重)矩阵初始化小值:
    • 2. 打乱数据
    • 3. 实现随机梯度下降
    • 4. 使用函数实现数据处理步骤
    • 5. 在保持初始化权重不变的情况下实现训练
    • 6. 返回自己
  • 结果

源码

此为适应机使用随机梯度下降的python实现,代码主要来源于Python Machine Learning 3rd(此书包含大量python实现算法的源码,宜啃读)自己把代码又弄了一遍,欢迎想看或者看过这本书的来交流讨论

class AdalineSGD(object):def __init__(self, eta=0.01, n_iter=10, shuffle=True, random_state=None):self.eta = etaself.n_iter = n_iterself.w_initialized = Falseself.shuffle = shuffleself.random_state = random_statedef fit(self, X, y):self._initialize_weights(X.shape[1])self.cost_ = []for i in range(self.n_iter):if self.shuffle:X, y = self._shuffle(X, y)cost = []for i in range(self.n_iter):if self.shuffle:X, y = self._shuffle(X, y)cost = []for xi, target in zip(X, y):cost.append(self._update_weights(xi, target))avg_cost = sum(cost) / len(y)self.cost_.append(avg_cost)return selfdef partial_fit(self, X, y):if not self.w_initialized:self._initialize_weights(X.shape[1])if y.ravel().shape[0] > 1:for xi, target in zip(X, y):self._update_weights(xi, target)else: self._update_weights(X, y)return selfdef _initialize_weights(self, m):self.rgen = np.random.RandomState(self.random_state)self.w_ = self.rgen.normal(loc=0.0, scale=0.01, size=1 + m )self.w_initialized = True# Shuffle the datadef _shuffle(self, X, y):r = self.rgen.permutation(len(y)) # use this method to get a randomly arrayreturn X[r], y[r] # return the randomly array, this kind of indexing exsiting only in npdef _update_weights(self, xi, target):output = self.activation(self.net_input(xi))error = target - outputself.w_[0] += self.eta * errorself.w_[1:] += self.eta * xi.dot(error)cost = 0.5 * error**2return costdef activation(self, X):return Xdef predict(self, X):return np.where(self.activation(self.net_input(X))>=0.0, 1, -1)def net_input(self, X):return np.dot(X, self.w_[1:]) + self.w_[0]

知识点

1. 实现参数(权重)矩阵初始化小值:

def _initialize_weights(self, m):self.rgen = np.random.RandomState(self.random_state)self.w_ = self.rgen.normal(loc=0.0, scale=0.01, size=1 + m )self.w_initialized = True
1-设定RandomState的随机种,不同的数值会产生不同的随机初始化数据的效果
2-使用numpy.random.normal可以设定呈正态分布的初始数组,loc设置均值,scale设置标准差,size设置numpy数组长度,+1是因为加入了偏置项

2. 打乱数据

def _shuffle(self, X, y):r = self.rgen.permutation(len(y)) # use this method to get a randomly arrayreturn X[r], y[r] # return the randomly array, this kind of indexing exsiting only in np
1-numpy.random.permutation方法(rgen是random实例)可以产生随机打乱序号的numpy array
2-numpy array可以通过双重数组索引进行重新组合,实现洗牌效果

3. 实现随机梯度下降

def _update_weights(self, xi, target):output = self.activation(self.net_input(xi))error = target - outputself.w_[0] += self.eta * errorself.w_[1:] += self.eta * xi.dot(error)cost = 0.5 * error**2return cost
1-相比批量梯度下降适应机,X(矩阵)改成了xi(矢量),errors(矢量)改成了error,然后单独计算单个样本预测误差,如此而已

4. 使用函数实现数据处理步骤

def activation(self, X):return Xdef predict(self, X):return np.where(self.activation(self.net_input(X))>=0.0, 1, -1)def net_input(self, X):return np.dot(X, self.w_[1:]) + self.w_[0]


不难发现三个函数分别完成了数据处理的三个节点的功能,predict就是量化器,输出1/-1。根据Adaline原理,量化器产生的最终结果不参与优化过程,所以其实并没有在类中使用,只在实例中进行最终的结果预测

5. 在保持初始化权重不变的情况下实现训练

def partial_fit(self, X, y):if not self.w_initialized:self._initialize_weights(X.shape[1])if y.ravel().shape[0] > 1:for xi, target in zip(X, y):self._update_weights(xi, target)else: self._update_weights(X, y)return self
判断一:如果未被初始化则进行初始化,否则继续训练,叠加训练结果
判断二:如果传递进来的是是一组数据集(>1),对里面的所有数据进行一轮更新
判断三:如果传递进来的是单组数据,对这组数据再进行一轮优化更新

6. 返回自己

return self
返回自己让你可以进行链式调用,返回指向类的实例的引用
不返回自己:AdalineSGD(...).fit().fit()报错
返回自己:AdalineSGD(...).fit().fit()不报错,不停地进行数据训练,训练结果是累加的
return self在算法库种是很常用的算法

结果


左图为决策边界,右图为误差随着迭代优化次数的趋势图,画图原理不作介绍了

python实现SGD(stochastic gradient descent)算法相关推荐

  1. 使用SGD(Stochastic Gradient Descent)进行大规模机器学习

    1 基于梯度下降的学习  对于一个简单的机器学习算法,每一个样例包含了一个(x,y)对,其中一个输入x和一个数值输出y.我们考虑损失函数 ,它描述了预测值 和实际值y之间的损失.预测值是我们选择从一函 ...

  2. 随机梯度下降算法SGD(Stochastic gradient descent)

    SGD是什么 SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一.SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数.它的基本 ...

  3. 线性回归之随机梯度下降法(Stochastic Gradient Descent,SGD)

    通俗易懂 一个经典的例子就是假设你现在在山上,为了以最快的速度下山,且视线良好,你可以看清自己的位置以及所处位置的坡度,那么沿着坡向下走,最终你会走到山底.但是如果你被蒙上双眼,那么你则只能凭借脚踩石 ...

  4. 三种梯度下降法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)

    梯度下降(GD)是最小化风险函数.损失函数的一种常用方法,随机梯度下降(stochastic gradient descent).批量梯度下降(Batch gradient descent)和mini ...

  5. 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比

     随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比.实现对比 标签: 梯度下降最优化迭代 2013 ...

  6. 手工实现:SVM with Stochastic Gradient Descent

    手工实现:SVM with Stochastic Gradient Descent 引入 实际问题 理论知识 SVM 直观认识 什么是分的好? 1.是不是只要都分对了就是分的好? 2.是不是只要训练集 ...

  7. 【文献阅读】Federated Accelerated Stochastic Gradient Descent

    文章目录 1 Federated Accelerated Stochastic Gradient Descent (FedAc) 2 challenge 3 how to do 4 baseline ...

  8. 【Byrd-SAGA】Federated Variance-Reduced Stochastic Gradient Descent with Robustness

    Federated Variance-Reduced Stochastic Gradient Descent with Robustness to Byzantine Attacks 对拜占庭攻击具有 ...

  9. 几种梯度下降方法简介(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)

    我们在训练神经网络模型时,最常用的就是梯度下降,这篇博客主要介绍下几种梯度下降的变种(mini-batch gradient descent和stochastic gradient descent), ...

  10. 几种梯度下降方法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)

    几种梯度下降方法对比(Batch gradient descent.Mini-batch gradient descent 和 stochastic gradient descent) 我们在训练神经 ...

最新文章

  1. 2019 训练比赛 记录
  2. C#操作项目配置文件
  3. Qt中QtTableWidget的使用
  4. win10计算机磁盘图标,Win10 21H1怎么更换电脑磁盘的图标标识
  5. 中文短文本的实体识别实体链接,第一名解决方案
  6. 力扣两数之和jAVA_力扣----1.两数之和(JavaScript, Java实现)
  7. 大学生自学网python_大学生免费自学网官网
  8. 用极大似然法估计因子载荷矩阵_关于因子分析|stata
  9. 计算机视觉—车道线检测
  10. 简聊聊天软件的表设计
  11. 最低销售量计算机公式,最低、最高、安全库存量的计算公式
  12. Meta拟裁撤Instagram伦敦员工 其余人将调往美国---转自百度新闻|财联社
  13. SolidWorks宏工具介绍——初识宏工具
  14. 数据结构PTA习题:进阶实验5-3.2 新浪微博热门话题 (30分)
  15. 【GDOI2018模拟7.14】小奇的糖果
  16. 通过internet连接到股票信息服务器,一种股票机的制作方法
  17. 【干货】Kaggle数据挖掘比赛经验分享,陈成龙博士整理!
  18. DataFrame计算corr()函数计算相关系数时,出现返回值为空或NaN的情况
  19. Win10怎么设置窗口护眼色
  20. 颜色类中英文词汇大全(5)

热门文章

  1. 信号硬件入门--振幅调制信号发生器(正弦波发生器方案、AM调制方案)--First理论部分
  2. 故障诊断仪采集发动机EMS故障的报文与故障码记录
  3. ​SQL注入非常详细总结
  4. Redis 官方推出可视化工具,颜值爆表,功能真心强大!这是不给其他工具活路啊!...
  5. USB协议详解第17讲(USB事务总结)
  6. 矢量地图自定义切片样式
  7. 使用mathematica求解最优化模型
  8. NAT(地址转换技术)详解
  9. 64位linux下安装libpng出错,安装libpng-1.6.10时make出现错误,请帮忙
  10. 如何清除项目 git 版本控制信息