SGD 被广泛运用到机器学习(machine learning)中最优化等问题中,学术界一直热衷于提升SGD在优化问题中的收敛速率,并行计算也是热点研究的方向(包括Hogwild! [1], Delay-tolerant Algorithm for SGD [2], Elastic Average SGD [3]),本篇实现了现在比较火的Downpour SGD [4]的样例代码 (选择这个的原因引用量最大)。

理论思想

核心思想,将数据随机划分成数个子数据集sub-data, 将模型变量划分数个机器/进程/线程,基于子集合数据更新各个机器/进程/线程内的变量n次,然后到master 节点更新模型变量,各个机器/进程/线程训练独立互不干扰, 而且各个机器/进程/线程内的模型变量在训练中也互不干扰。引用原文原话 [4]:

We divide the training data into a number of subsets and run a copy of the model on
each of these subsets. The models communicate updates through a centralized parameter server,
which keeps the current state of all parameters for the model, sharded across many machines (e.g.,
if we have 10 parameter server shards, each shard is responsible for storing and applying updates
to 1/10th of the model parameters) (Figure 2). This approach is asynchronous in two distinct aspects:
the model replicas run independently of each other, and the parameter server shards also run
independently of one another

代码部分

# -*- encoding: utf-8 -*-
import re
import sys
import numpy as np
import copy
import time
import threadingdef timeConsumption(func):def func_wrapper(x):start_time = time.time()func(x)end_time = time.time()print "[Function-Time-Consumption] ", end_time - start_timereturn func_wrapperdef initialize(length=100):""""""global Xglobal YX = []Y = []mu, sigma = 0, 0.1V = 100# here we assume x is two-dimesion matrix for i in np.random.random(length):a = i * Vfor j in np.random.random(length): b = j * VX.append([a**2, b**2, a, b, 1, 1])# white noisenoise = np.random.normal(mu, sigma, size=length * length)# a * x**2 + b * x + cfunction = lambda x: np.dot([103.0, -22.5, 35.5, -28, 43, 19.0], x)Y = [function(X[i]) + noise[i] for i in range(length * length)] X = np.array(X)Y = np.array(Y)return X, Yclass HogwildThreading(threading.Thread):""""""def __init__(self, threadID, name, dIndex, pIndex, n, param, gama, function, eplison):""""""threading.Thread.__init__(self)self.threadID = threadIDself.name = nameself.dIndex = dIndexself.pIndex = pIndexself.n = nself.param = copy.copy(param)self.gama = gamaself.function = functionself.eplison = eplisondef run(self):"""In each threading, update relevant parameters with each sub-dataset for each n step"""global Xglobal Yglobal derativewhile self.n > 0:local_y = [self.function(self.param, x) for x in X[self.dIndex]]diff = np.mean(np.subtract(local_y, Y[self.dIndex]))if abs(diff) < self.eplison: break# print self.name + "-" + self.threadID + " : " + str(diff)for i in self.pIndex:self.param[i] -= self.gama * derative[i] * diffself.n -= 1class HogwildSGD(object):""""""def __init__(self, X, Y, eplison=0.000001, gama=0.01, iter_num=1000, thread_num=10):"""ref: https://arxiv.org/abs/1106.5730"""_d = X.shape[-1]# set up number of threadsif _d < thread_num: self.thread_num = _delse: self.thread_num = thread_num# parameter initailizationself.a = np.random.normal(0, 1, size=_d)print self.aself.eplison = eplisonself.gama = (1.0 / max(Y)) * 1.0 / len(X)self.iter_num = iter_numdef function(self, a, x):"""Do we have prior knowledge about the function?- quadratic ?- exponential ?- linear ?"""return np.dot(a, x)def chunkIt(self, seq, num):avg = len(seq) / float(num)out = []last = 0.0while last < len(seq):out.append(seq[int(last):int(last + avg)])last += avgreturn out@timeConsumptiondef run(self):""""""global Xglobal Yglobal derativederative = []for i in range(len(self.a)):derative.append(np.mean([x[i] for x in X]))dIndex = range(len(X))np.random.shuffle(dIndex)dIndex = self.chunkIt(dIndex, self.thread_num)diff = 1while(abs(diff) > self.eplison and self.iter_num > 0):pIndex = range(len(self.a))np.random.shuffle(pIndex)pIndex = self.chunkIt(pIndex, self.thread_num)threads = []for i in range(self.thread_num):instance = HogwildThreading(str(i), "HogwildThreading", dIndex[i], pIndex[i], 10, self.a, self.gama, self.function, self.eplison)instance.start()threads.append(instance)for i in range(self.thread_num):threads[i].join()# update parameter from each data-shardsfor i in range(self.thread_num):index = threads[i].pIndexfor j in index:self.a[j] = threads[i].param[j]local_y = [self.function(self.a, x) for x in X]diff = np.mean(np.subtract(local_y, Y)) print diffself.iter_num -= 1print self.adef main():""""""X, Y = initialize()instance = HogwildSGD(X, Y)instance.run()if __name__ == "__main__":reload(sys)sys.setdefaultencoding("utf-8")main()

SGD平行算法 - Downpour SGD (单机python多线程版)相关推荐

  1. 工具分享(1):FTP暴力破解工具 [Python多线程版]

    工具分享(1):FTP暴力破解工具 [Python多线程版] 参考:https://www.waitalone.cn/python-ftp-mult.html 在他的基础上加了这么一个代码:如果用户输 ...

  2. DistBelief 框架下的并行随机梯度下降法 - Downpour SGD

    本文是读完 Jeffrey Dean, Greg S. Corrado 等人的文章 Large Scale Distributed Deep Networks (2012) 后的一则读书笔记,重点介绍 ...

  3. Python 多线程抓取网页 牛人 use raw socket implement http request great

    Python 多线程抓取网页 - 糖拌咸鱼 - 博客园 Python 多线程抓取网页 最近,一直在做网络爬虫相关的东西. 看了一下开源C++写的larbin爬虫,仔细阅读了里面的设计思想和一些关键技术 ...

  4. python多线程爬虫实例-python多线程爬虫实例讲解

    Python作为一门强大的脚本语言,我们经常使用python来写爬虫程序,简单的爬虫会写,可是用python写多线程网页爬虫,应该如何写呢?一般来说,使用线程有两种模式,一种是创建线程要执行的函数,把 ...

  5. Python 多线程抓取网页

    Python 多线程抓取网页 - 糖拌咸鱼 - 博客园 Python 多线程抓取网页 最近,一直在做网络爬虫相关的东西. 看了一下开源C++写的larbin爬虫,仔细阅读了里面的设计思想和一些关键技术 ...

  6. python多线程读取数据库数据_Python基于多线程操作数据库相关知识点详解

    Python基于多线程操作数据库相关问题分析 本文实例分析了Python多线程操作数据库相关问题.分享给大家供大家参考,具体如下: python多线程并发操作数据库,会存在链接数据库超时.数据库连接丢 ...

  7. python多线程详解 Python 垃圾回收机制

    文章目录 python多线程详解 一.线程介绍 什么是线程 为什么要使用多线程 总结起来,使用多线程编程具有如下几个优点: 二.线程实现 自定义线程 守护线程 主线程等待子线程结束 多线程共享全局变量 ...

  8. python多线程爬取ts文件并合成mp4视频

    python多线程爬取ts文件并合成mp4视频 声明:仅供技术交流,请勿用于非法用途,如有其它非法用途造成损失,和本博客无关 目录 python多线程爬取ts文件并合成mp4视频 前言 一.分析页面 ...

  9. 用通俗易懂的方式讲解:主成分分析(PCA)算法及案例(Python 代码)

    文章目录 知识汇总 加入方式 一.引入问题 二.数据降维 三.PCA基本数学原理 3.1 内积与投影 3.2 基 3.3 基变换的矩阵表示 3.4 协方差矩阵及优化目标 3.5 方差 3.6 协方差 ...

最新文章

  1. 计算机会计课程试题及答案,计算机会计第2次作业_报表_附答案
  2. 刷题两个月,从入门到字节跳动offer,这是我的模板 | GitHub 1.2k星
  3. try to navigate from button to line item page
  4. 升级bios_华硕400系主板升级BIOS:静待11代酷睿CPU
  5. 分布排序(distribution sorts)算法大串讲
  6. movcms能安装PHP吗,LzCMS-博客版 手动安装方法
  7. iPhone手机硬件拆解介绍
  8. 《信息安全技术—个人信息安全影响评估指南》pdf下载
  9. BZOJ3533: [Sdoi2014]向量集
  10. 程序员面试阿里、腾讯、京东等大公司,这些套路你知道吗?
  11. thinkphp5.1 || 给图片添加文字,图片水印
  12. 如何将截图中公式转换成为可用的mathtype公式
  13. DeleteObject()的使用
  14. 正益工作能担起PaaS+SaaS的未来探索吗?
  15. Euler Finance 完成 3200 万美元融资 Haun Ventures 领投
  16. 又双叒叕是Linux笔记
  17. StringTokenizer字符串分割
  18. 中国各省市DNS服务器列表
  19. XenDesktop 之powershell 使用
  20. 中考计算机考试不合格会怎么样,中考不及格会给毕业证吗 有什么后果吗

热门文章

  1. javaweb常识类英语
  2. 使用NtCreateThreadEx将Dll注入目标进程
  3. 用uniapp实现微信小程序的电子签名效果
  4. 货郎担问题java算法_经典算法(1)---货郎担问题
  5. GUI(Graphical User Interface)
  6. 全面了解风控指标体系
  7. 做淘宝店铺为什么一定要定位?
  8. 数独的生成以及解答--回溯算法c++附详细代码
  9. iOS自动化笔记(一)WebDriverAgent安装与使用
  10. linux集显驱动程序,Ubuntu14.04安装intel集显驱动