为什么80%的码农都做不了架构师?>>>   

#coding:utf-8
from mxnet import ndarray as nd
from mxnet import autograd
import random
import matplotlib.pyplot as pltnum_inputs = 2
num_examples = 1000true_w = [2,-3.4]
true_b = 4.2X = nd.random_normal(shape=(num_examples,num_inputs))
y = true_w[0] * X[:,0] + true_w[1] * X[:,1] + true_b# 添加随机噪声数据
y += .01 * nd.random_normal(shape=y.shape)# 数据读取
def data_iter():batch_size = 10idx = list(range(num_examples))random.shuffle(idx)for i in range(0,num_examples,batch_size):j = nd.array(idx[i:min(i+batch_size,num_examples)])yield nd.take(X,j),nd.take(y,j)# 初始化模型参数
w = nd.random_normal(shape=(num_inputs,1))
b = nd.zeros((1,))params = [w,b]for param in params:param.attach_grad()# 定义模型
def net(X):return nd.dot(X,w) + b# 定义损失函数
def square_loss(yhat,y):# 转换成相同形状,避免自动转换return(yhat - y.reshape(yhat.shape)) ** 2# 优化器
def SGD(params,lr):for param in params:param[:] = param - lr * param.grad# 训练# 模型函数
def real_fun(X):return true_w[0] * X[:,0] + true_w[1] * X[:,1] + true_bdef plot(losses,sample_size=100):xs = list(range(len(losses)))f,(fg1,fg2) = plt.subplots(1,2)fg1.set_title('Loss during training')fg1.plot(xs,losses,'-r')fg2.set_title('Estimate vs real funtion')fg2.plot(X[:sample_size,1].asnumpy(),net(X[:sample_size,:]).asnumpy(),'or',label='Estimated')fg2.plot(X[:sample_size,1].asnumpy(),real_fun(X[:sample_size,:]).asnumpy(),'*g',label='Real')fg2.legend()plt.show()epochs = 5
learning_rate = 0.01
niter = 0
losses = []
moving_loss = 0
smoothing_constant = 0.01for e in range(epochs):total_loss = 0for data,label in data_iter():with autograd.record():output = net(data)loss = square_loss(output,label)loss.backward()SGD(params,learning_rate)total_loss += nd.sum(loss).asscalar()# 记录每一次数据点后,损失的移动平均值变化niter += 1curr_loss = nd.mean(loss).asscalar()moving_loss = (1 - smoothing_constant) * moving_loss + \(smoothing_constant) * curr_lossest_loss = moving_loss / (1 - (1 - smoothing_constant) ** niter)if (niter + 1) % 100 == 0:losses.append(est_loss)print('Epochs %s,batch %s . Moving avg of loss: %s Average loss:%f' %(e,niter,est_loss,total_loss / num_examples))print(true_w,w)
print(true_b,b)

转载于:https://my.oschina.net/wujux/blog/1809135

MXNet动手学深度学习笔记:线性回归相关推荐

  1. MXNet动手学深度学习笔记:卷积计算

    为什么80%的码农都做不了架构师?>>>    #coding:utf-8 ''' 卷积计算 ''' import mxnet as mx from mxnet.gluon impo ...

  2. 动手学深度学习笔记3.4+3.5+3.6+3.7

    系列文章目录 动手学深度学习笔记系列: 动手学深度学习笔记3.1+3.2+3.3 文章目录 系列文章目录 前言 一.softmax回归 1.1 分类问题 1.2 网络架构 1.3 全连接层的参数开销 ...

  3. 动手学深度学习笔记(1)

    动手学深度学习 深度学习简介 深度学习简介 举一个小的例子,如何编写一个程序,让机器识别我输入的图片是否有一只猫?我们需要哪些值来帮助我们确定?事实上,要想解读图像中的内容,需要寻找仅仅在结合成千上万 ...

  4. [深度学习]动手学深度学习笔记-3

    Task-2 文本预处理:语言模型:循环神经网络基础 3.1 文本预处理 文本是一类序列数据,一篇文章可以看作是字符或单词的序列,本节将介绍文本数据的常见预处理步骤,预处理通常包括四个步骤: 读入文本 ...

  5. [深度学习]动手学深度学习笔记-5

    Task2--梯度消失.梯度爆炸 5.1 梯度消失与梯度爆炸的概念 深度神经网络训练的时候,采用的是反向传播方式,该方式使用链式求导,计算每层梯度的时候会涉及一些连乘操作,因此如果网络过深. 那么如果 ...

  6. DJL-Java开发者动手学深度学习之线性回归

    线性回归 回归是指一类为一个或多个自变量与因变量之间关系建模的方法.在自然科学和社会科学领域,回归通常用来表示输入和输出之间的关系. 在机器学习领域中的大多数任务通常都与预测有关. 当我们想预测一个数 ...

  7. [深度学习]动手学深度学习笔记-14

    Task9--目标检测基础 14.1 目标检测和边界框 在前面的一些章节中,我们介绍了诸多用于图像分类的模型.在图像分类任务里,我们假设图像里只有一个主体目标,并关注如何识别该目标的类别.然而,很多时 ...

  8. 动手学深度学习笔记4——微积分自动微分

    目录 1.微积分 1.1导数和微分 1.2偏导数 1.3梯度 1.4链式法则 1.5小结 1.6练习 2.自动微分 2.1一个简单的例子 2.2非标量变量的反向传播 2.3分离计算 2.4Python ...

  9. 动手学深度学习笔记一线性回归

    一,线性回归 线性回归的疑问记录 MSE:mean square error均方误差 epoch:迭代次数 optimizer:优化器 mini-batch:小批量 进行线性回归的思路 选择线性模型, ...

最新文章

  1. 深度学习Anchor Boxes原理与实战技术
  2. mlc tlc slc qlc_看了这么多固态硬盘科普,终于真正搞明白TLC闪存和SLC缓存
  3. 主要元素(超过一半元素)
  4. 微信小程序页面文字超出一行隐藏,文字超出两行隐藏。
  5. 关于发那科机器人的FSSB
  6. requests爬取免费代理2
  7. java连接SqlServer2000类,比较完整,比较强大
  8. mybatis写增删改时候的注意点
  9. Android修改ro.debugable开启全局debug模式
  10. Git ignore UserInterfaceState.xcuserstate
  11. 如何引入阿里矢量图标库彩色图标
  12. 苹果CMSv10新手入门安装必看教程
  13. linux codeblocks汉化
  14. KODI(原XBMC)二次开发完全解析(一)
  15. 达尔优EM915镜面板游戏鼠标拆机教程
  16. 华为mate50pro和小米12ultea对比
  17. puppy linux4,发行版:Puppy Linux 4.00发布
  18. python中括号和方括号的问题
  19. 牛客多校第一场——E-ABBA
  20. 前端程序员后来都去干嘛了?我找了几位聊了聊

热门文章

  1. 2020年,对薪资不满意的程序员要注意了...
  2. 如何避免操作系统中多线程资源竞争的互斥与同步?
  3. Reveal.js一个用来做WEB演示文稿的框架
  4. Google I/O 2019上提及的Javascript新特性
  5. Android 读取meta-data元素的数据
  6. NEO共识节点推荐搭建步骤
  7. Microsoft宣布发布GA版Azure Event Grid
  8. 国外10大IT网站和博客网站
  9. 【BZOJ】1726 [Usaco2006 Nov]Roadblocks第二短路
  10. Windows server 2012 搭建×××图文教程(二)配置路由和远程访问服务