文章目录

  • 1.前言
  • 2.数据准备
  • 3.搭建神经网络
  • 4.训练搭建的神经网络
  • 5.可视化操作

1.前言

我会这次会来见证神经网络是如何通过简单的形式将一群数据用一条线条来表示. 或者说, 是如何在数据当中找到他们的关系, 然后用神经网络模型来建立一个可以代表他们关系的线条.

2.数据准备

我们创建一些假数据来模拟真实的情况. 比如一个一元二次函数: y = a * x^2 + b, 我们给 y 数据加上一点噪声来更加真实的展示它.

import torch
import matplotlib.pyplot as plt#制造一些数据
x = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1)   #torch.Size([100, 1]) #把[a,b,c]变成[[a,b,c]]
#print(x)
y = 2*(x.pow(2)) + 0.5*torch.rand(x.size())  #torch.rand为均匀分布,返回一个张量,包含了从区间[0, 1)的均匀分布中抽取的一组随机数。张量的形状由参数sizes定义
#print(y)
#画图
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

3.搭建神经网络

建立一个神经网络我们可以直接运用 torch 中的体系. 先定义所有的层属性(init()), 然后再一层层搭建(forward(x))层于层的关系链接. 建立关系的时候, 我们会用到激励函数

from torch import nn
import torch.nn.functional as Fclass NetWork(nn.Module):def __init__(self,n_input,n_hidden,n_output):super(NetWork,self).__init__()self.hidden = nn.Linear(n_input,n_hidden)self.output_for_predict = nn.Linear(n_hidden,n_output)def forward(self,x):x = F.relu(self.hidden(x))   #对x进入隐层后的输出应用激活函数(相当于一个筛选的过程)output = self.output_for_predict(x)    #做线性变换,将维度为1return outputnetwork = NetWork(n_input = 1,n_hidden = 8, n_output = 1)
print(network)  #打印模型的层次结构

4.训练搭建的神经网络

训练的步骤很简单, 如下:

from torch import nn
import torch.nn.functional as Fclass NetWork(nn.Module):def __init__(self,n_input,n_hidden,n_output):super(NetWork,self).__init__()self.hidden = nn.Linear(n_input,n_hidden)self.output_for_predict = nn.Linear(n_hidden,n_output)def forward(self,x):x = F.relu(self.hidden(x))   #对x进入隐层后的输出应用激活函数(相当于一个筛选的过程)output = self.output_for_predict(x)    #做线性变换,将维度为1return outputnetwork = NetWork(n_input = 1,n_hidden = 8, n_output = 1)
print(network)   #打印模型的层次结构optimizer = torch.optim.SGD(network.parameters(),lr = 0.2)
criterion = torch.nn.MSELoss()   #均方误差,用于计算预测值与真实值之间的误差for i in range(500):   #训练步数(相当于迭代次数)predication = network(x)loss = criterion(predication, y)    #predication为预测的值,y为真实值optimizer.zero_grad()loss.backward()      #反向传播,更新参数optimizer.step()     #将更新的参数值放进network的parameters

5.可视化操作

x = torch.unsqueeze(torch.linspace(-1,1,100),dim = 1)   #torch.Size([100, 1]) #把[a,b,c]变成[[a,b,c]]
#print(x)
y = 2*(x.pow(2)) + 0.5*torch.rand(x.size())  #torch.rand为均匀分布,返回一个张量,包含了从区间[0, 1)的均匀分布中抽取的一组随机数。张量的形状由参数sizes定义
#print(y)
#画图
# plt.scatter(x.data.numpy(),y.data.numpy())
# plt.show()from torch import nn
import torch.nn.functional as Fclass NetWork(nn.Module):def __init__(self,n_input,n_hidden,n_output):super(NetWork,self).__init__()self.hidden = nn.Linear(n_input,n_hidden)self.output_for_predict = nn.Linear(n_hidden,n_output)def forward(self,x):x = F.relu(self.hidden(x))   #对x进入隐层后的输出应用激活函数(相当于一个筛选的过程)output = self.output_for_predict(x)    #做线性变换,将维度为1return outputnetwork = NetWork(n_input = 1,n_hidden = 8, n_output = 1)
print(network)   #打印模型的层次结构plt.ion()   # 打开交互模式
plt.show()optimizer = torch.optim.SGD(network.parameters(),lr = 0.2)
criterion = torch.nn.MSELoss()   #均方误差,用于计算预测值与真实值之间的误差for i in range(500):   #训练步数(相当于迭代次数)predication = network(x)loss = criterion(predication, y)    #predication为预测的值,y为真实值optimizer.zero_grad()loss.backward()      #反向传播,更新参数optimizer.step()     #将更新的参数值放进network的parametersif i % 10 == 0:plt.cla()   # 清坐标轴plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),predication.data.numpy(),'ro', lw=5)   #画预测曲线,用红色o作为标记plt.text(0.5,0,'Loss = %.4f' % loss.data.numpy(), fontdict = {'size': 20, 'color':  'red'})plt.pause(0.1)

Pytorch——回归问题相关推荐

  1. 【深度学习】基于Torch的Python开源机器学习库PyTorch回归

    [深度学习]基于Torch的Python开源机器学习库PyTorch回归 文章目录1 torch.autograd 2 torch.nn.functional 3 详细的回归DEMO3.1 DATAS ...

  2. pytorch回归_PyTorch:用岭回归检查泰坦尼克号下沉

    pytorch回归 In this notebook, we shall use this dataset containing data about passengers from the Tita ...

  3. pytorch 回归问题实战

    pytorch 回归问题实战 深度学习中最重要的思想就是梯度下降,深度学习可以看出一个黑盒模型,其内部实质采用了梯度下降法.什么是梯度下降法呢?下面从梯度下降开始介绍,以及如何使用Pytorch实现回 ...

  4. 错误录入 算法_如何使用验证错误率确定算法输出之间的关系

    错误录入 算法 Monument (www.monument.ai) enables you to quickly apply algorithms to data in a no-code inte ...

  5. 自助分析_为什么自助服务分析真的不是一回事

    自助分析 That title probably got your attention and now you think I have some explaining to do! The key ...

  6. 开源软件 安全风险_3开源安全风险及其解决方法

    开源软件 安全风险 Open source software is very popular and makes up a significant portion of business applic ...

  7. 网络传播动力学_通过简单的规则传播动力

    网络传播动力学 When a single drop of paint is dropped on a surface the amount of space that the drop will c ...

  8. 存款惊人_如何使您的图快速美丽惊人

    存款惊人 So, you just finished retrieving, processing, and analyzing your data. You grab your data and y ...

  9. 异常检测时间序列_时间序列的无监督异常检测

    异常检测时间序列 To understand the normal behaviour of any flow on time axis and detect anomaly situations i ...

最新文章

  1. 通过数据挖掘组织营销潜力的三个重要途径
  2. shell 数组排序
  3. 16年寒假随笔(4)
  4. 使用tomcat自带的连接池,报错
  5. Android 实现Activity后台运行
  6. HDU4911 Inversion 解题报告
  7. 易语言皮肤模块200个_王者荣耀:第一个200战令玩家,连天美都赞他的升级方法最科学...
  8. 工具变量两阶段最小二乘
  9. word文件做一半未响应_Word经常出现未响应怎么办?
  10. 重启服务器后docker wordpress “Error establishing a database connection”解决办法
  11. java程序设计俄罗斯方块_俄罗斯方块单人游戏JAVA程序设计
  12. 【练习】获取新浪搜索中的热搜榜的标题
  13. php微信支付mch_id参数格式错误,在.net core上,Web网站调用微信支付-统一下单接口(xml传参)一直返回错误:mch_id参数格式错误...
  14. 如何利用华硕Mesh系统路由器在780平方公尺大的场域架设可靠的WiFi系统?
  15. DDD的模式与实践案例
  16. python flask ajax_Python flask+css+js+ajax 综合复习
  17. 计算机word表格基础,Word表格的作-计算机基础.doc
  18. Win7升为Win10以及win7系统的重装
  19. 2022年云南最新建筑八大员(市政)模拟考试题库及答案
  20. 《祝你一路顺风》-吴奇隆(吉他谱)

热门文章

  1. 勒索病毒WannaCry(永恒之蓝)
  2. 《Visual Studio Code权威指南》读后总结
  3. Unity中Web.Config文件的配置与调用
  4. 【汇编语言与计算机系统结构笔记19】虚存概念初步,MIPS内存管理
  5. 【数据结构笔记11】二叉搜索树,动态查找,删除操作
  6. Oracle RAC 11R2配置归档、删除策略,闪回配置完整版
  7. 用GVIM/VIM写Verilog——VIM配置分享
  8. 通过自定义Module实现URl重写和登陆验证
  9. Android中文API(142) —— Gravity
  10. 各种编程技术中的$符的使用