MXNet动手学深度学习笔记:线性回归
为什么80%的码农都做不了架构师?>>>
#coding:utf-8
from mxnet import ndarray as nd
from mxnet import autograd
import random
import matplotlib.pyplot as pltnum_inputs = 2
num_examples = 1000true_w = [2,-3.4]
true_b = 4.2X = nd.random_normal(shape=(num_examples,num_inputs))
y = true_w[0] * X[:,0] + true_w[1] * X[:,1] + true_b# 添加随机噪声数据
y += .01 * nd.random_normal(shape=y.shape)# 数据读取
def data_iter():batch_size = 10idx = list(range(num_examples))random.shuffle(idx)for i in range(0,num_examples,batch_size):j = nd.array(idx[i:min(i+batch_size,num_examples)])yield nd.take(X,j),nd.take(y,j)# 初始化模型参数
w = nd.random_normal(shape=(num_inputs,1))
b = nd.zeros((1,))params = [w,b]for param in params:param.attach_grad()# 定义模型
def net(X):return nd.dot(X,w) + b# 定义损失函数
def square_loss(yhat,y):# 转换成相同形状,避免自动转换return(yhat - y.reshape(yhat.shape)) ** 2# 优化器
def SGD(params,lr):for param in params:param[:] = param - lr * param.grad# 训练# 模型函数
def real_fun(X):return true_w[0] * X[:,0] + true_w[1] * X[:,1] + true_bdef plot(losses,sample_size=100):xs = list(range(len(losses)))f,(fg1,fg2) = plt.subplots(1,2)fg1.set_title('Loss during training')fg1.plot(xs,losses,'-r')fg2.set_title('Estimate vs real funtion')fg2.plot(X[:sample_size,1].asnumpy(),net(X[:sample_size,:]).asnumpy(),'or',label='Estimated')fg2.plot(X[:sample_size,1].asnumpy(),real_fun(X[:sample_size,:]).asnumpy(),'*g',label='Real')fg2.legend()plt.show()epochs = 5
learning_rate = 0.01
niter = 0
losses = []
moving_loss = 0
smoothing_constant = 0.01for e in range(epochs):total_loss = 0for data,label in data_iter():with autograd.record():output = net(data)loss = square_loss(output,label)loss.backward()SGD(params,learning_rate)total_loss += nd.sum(loss).asscalar()# 记录每一次数据点后,损失的移动平均值变化niter += 1curr_loss = nd.mean(loss).asscalar()moving_loss = (1 - smoothing_constant) * moving_loss + \(smoothing_constant) * curr_lossest_loss = moving_loss / (1 - (1 - smoothing_constant) ** niter)if (niter + 1) % 100 == 0:losses.append(est_loss)print('Epochs %s,batch %s . Moving avg of loss: %s Average loss:%f' %(e,niter,est_loss,total_loss / num_examples))print(true_w,w)
print(true_b,b)
转载于:https://my.oschina.net/wujux/blog/1809135
MXNet动手学深度学习笔记:线性回归相关推荐
- MXNet动手学深度学习笔记:卷积计算
为什么80%的码农都做不了架构师?>>> #coding:utf-8 ''' 卷积计算 ''' import mxnet as mx from mxnet.gluon impo ...
- 动手学深度学习笔记3.4+3.5+3.6+3.7
系列文章目录 动手学深度学习笔记系列: 动手学深度学习笔记3.1+3.2+3.3 文章目录 系列文章目录 前言 一.softmax回归 1.1 分类问题 1.2 网络架构 1.3 全连接层的参数开销 ...
- 动手学深度学习笔记(1)
动手学深度学习 深度学习简介 深度学习简介 举一个小的例子,如何编写一个程序,让机器识别我输入的图片是否有一只猫?我们需要哪些值来帮助我们确定?事实上,要想解读图像中的内容,需要寻找仅仅在结合成千上万 ...
- [深度学习]动手学深度学习笔记-3
Task-2 文本预处理:语言模型:循环神经网络基础 3.1 文本预处理 文本是一类序列数据,一篇文章可以看作是字符或单词的序列,本节将介绍文本数据的常见预处理步骤,预处理通常包括四个步骤: 读入文本 ...
- [深度学习]动手学深度学习笔记-5
Task2--梯度消失.梯度爆炸 5.1 梯度消失与梯度爆炸的概念 深度神经网络训练的时候,采用的是反向传播方式,该方式使用链式求导,计算每层梯度的时候会涉及一些连乘操作,因此如果网络过深. 那么如果 ...
- DJL-Java开发者动手学深度学习之线性回归
线性回归 回归是指一类为一个或多个自变量与因变量之间关系建模的方法.在自然科学和社会科学领域,回归通常用来表示输入和输出之间的关系. 在机器学习领域中的大多数任务通常都与预测有关. 当我们想预测一个数 ...
- [深度学习]动手学深度学习笔记-14
Task9--目标检测基础 14.1 目标检测和边界框 在前面的一些章节中,我们介绍了诸多用于图像分类的模型.在图像分类任务里,我们假设图像里只有一个主体目标,并关注如何识别该目标的类别.然而,很多时 ...
- 动手学深度学习笔记4——微积分自动微分
目录 1.微积分 1.1导数和微分 1.2偏导数 1.3梯度 1.4链式法则 1.5小结 1.6练习 2.自动微分 2.1一个简单的例子 2.2非标量变量的反向传播 2.3分离计算 2.4Python ...
- 动手学深度学习笔记一线性回归
一,线性回归 线性回归的疑问记录 MSE:mean square error均方误差 epoch:迭代次数 optimizer:优化器 mini-batch:小批量 进行线性回归的思路 选择线性模型, ...
最新文章
- 深度学习Anchor Boxes原理与实战技术
- mlc tlc slc qlc_看了这么多固态硬盘科普,终于真正搞明白TLC闪存和SLC缓存
- 主要元素(超过一半元素)
- 微信小程序页面文字超出一行隐藏,文字超出两行隐藏。
- 关于发那科机器人的FSSB
- requests爬取免费代理2
- java连接SqlServer2000类,比较完整,比较强大
- mybatis写增删改时候的注意点
- Android修改ro.debugable开启全局debug模式
- Git ignore UserInterfaceState.xcuserstate
- 如何引入阿里矢量图标库彩色图标
- 苹果CMSv10新手入门安装必看教程
- linux codeblocks汉化
- KODI(原XBMC)二次开发完全解析(一)
- 达尔优EM915镜面板游戏鼠标拆机教程
- 华为mate50pro和小米12ultea对比
- puppy linux4,发行版:Puppy Linux 4.00发布
- python中括号和方括号的问题
- 牛客多校第一场——E-ABBA
- 前端程序员后来都去干嘛了?我找了几位聊了聊
热门文章
- 2020年,对薪资不满意的程序员要注意了...
- 如何避免操作系统中多线程资源竞争的互斥与同步?
- Reveal.js一个用来做WEB演示文稿的框架
- Google I/O 2019上提及的Javascript新特性
- Android 读取meta-data元素的数据
- NEO共识节点推荐搭建步骤
- Microsoft宣布发布GA版Azure Event Grid
- 国外10大IT网站和博客网站
- 【BZOJ】1726 [Usaco2006 Nov]Roadblocks第二短路
- Windows server 2012 搭建×××图文教程(二)配置路由和远程访问服务