文章目录

  • 数据
    • 首先导入需要用的一些包
    • 随机生成一组数据
  • 开始搭建神经网络
  • 构建优化目标及损失函数
  • 动态显示学习过程
  • Pytorch是一个开源的Python机器学习库,基于Torch。

  • 神经网络主要分为两种类型,分类和回归,下面就自己学习用Pytorch搭建简易回归网络进行分享。

数据

首先导入需要用的一些包

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable

随机生成一组数据

并加上随机噪声增加数据的复杂性.

x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(3)+0.1*torch.randn(x.size())

将数据转化成Variable的类型用于输入神经网络

为了更好的看出生成的数据类型,我们采用将生成的数据plot出来

x , y =(Variable(x),Variable(y))plt.scatter(x.data,y.data)
# 或者采用如下的方式也可以输出x,y
# plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()


这里由于x,y都是Variable的类型,需要调用data将其输出出来,直接输出也可以.

开始搭建神经网络

以下作为搭建网络的模板,定义类,然后继承nn.Module,再继承自己的超类。

class Net(nn.Module):def __init__(self):super(self).__init__()passdef forward(self):pass

不多说直接搭建网络

为了增加网络的复杂性,网络设置为由两个全连接层组成的隐藏层.

class Net(nn.Module):def __init__(self,n_input,n_hidden,n_output):super(Net,self).__init__()self.hidden1 = nn.Linear(n_input,n_hidden)self.hidden2 = nn.Linear(n_hidden,n_hidden)self.predict = nn.Linear(n_hidden,n_output)def forward(self,input):out = self.hidden1(input)out = F.relu(out)out = self.hidden2(out)out = F.sigmoid(out)out =self.predict(out)return out

为了方便理解,我来画出这个网络的结构

net = Net(1,20,1)
print(net)

简单的网络就搭建好了,通过调用和print可以输出网络的结构.

构建优化目标及损失函数

torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性。为了使用torch.optim,你需要构建一个optimizer对象。这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。

为了构建一个Optimizer,你需要给它一个包含了需要优化的参数(必须都是Variable对象)的iterable。然后,你可以设置optimizer的参 数选项,比如学习率,权重衰减,等等。

optimizer = torch.optim.SGD(net.parameters(),lr = 0.1)
loss_func = torch.nn.MSELoss()

采用随机梯度下降进行训练,损失函数采用常用的均方损失函数,设置学习率为0.1,可以根据需要进行设置,原则上越小学习越慢,但是精度也越高,然后进行迭代训练(这里设置为5000次).

for t in range(5000):prediction = net(x)loss = loss_func(prediction,y)optimizer.zero_grad()loss.backward()optimizer.step()

optimizer.zero_grad()意思是把梯度置零,也就是把loss关于weight的导数变成0,即将梯度初始化为零(因为一个batch的loss关于weight的导数是所有sample的loss关于weight的导数的累加和);loss.backward() 对loss进行反向传播, optimizer.step()再对梯度进行优化,更新所有参数。

动态显示学习过程

if t%5 ==0:plt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.text(0.5, 0, 'Loss = %.4f' % loss.data, fontdict={'size': 20, 'color': 'red'})plt.pause(0.05)

附训练开始及结果图


我将网络中的一个激活函数从sigmod激活改成relu激活,将学习率改成了0.01后效果改善了很多。

Pytorch神经网络极简入门(回归)相关推荐

  1. 【2022·深度强化学习课程】深度强化学习极简入门与Pytorch实战

    课程名称:深度强化学习极简入门与Pytorch实战 课程内容:强化学习基础理论,Python和深度学习编程基础.深度强化学习理论与编程实战 课程地址:https://edu.csdn.net/cour ...

  2. RL极简入门:从MDP、DP MC TC到Q函数、策略学习、PPO

    前言 22年底/23年初ChatGPT大火,在写ChatGPT通俗笔记的过程中,发现ChatGPT背后技术涉及到了RL/RLHF,于是又深入研究RL,研究RL的过程中又发现里面的数学公式相比ML/DL ...

  3. 机器学习极简入门课程

    开篇词 | 入门机器学习,已迫在眉睫 大家好,我是李烨.现就职于微软(Microsoft),曾在易安信(EMC)和太阳微系统(Sun Microsystems)任软件工程师.先后参与过聊天机器人.大数 ...

  4. 为 AI 初学者打造的《机器学习极简入门》面世了!

    随着人工智能技术的发展,机器学习已成为软件 / 互联网行业的常用技能,并开始向更多行业渗透.对越来越多的 IT 技术人员及数据分析从业者而言,机器学习正在成为必备技能之一. 今天我们就来聊聊机器学习的 ...

  5. tensorflow平台极简方式_TensorFlow极简入门教程

    原标题:TensorFlow极简入门教程 随着 TensorFlow 在研究及产品中的应用日益广泛,很多开发者及研究者都希望能深入学习这一深度学习框架.本文介绍了TensorFlow 基础,包括静态计 ...

  6. Docker极简入门

    原 Docker极简入门 2018年05月22日 20:25:12 阅读数:44 一.Docker概述 Docker通过一个包括应用程序运行时所需的一切的可执行镜像启动容器,包括配置有代码.运行时.库 ...

  7. .Net Core in Docker极简入门(下篇)

    点击上方蓝字"小黑在哪里"关注我吧 Docker-Compose 代码修改 yml file up & down 镜像仓库 前言 上一篇[.Net Core in Dock ...

  8. Nginx 极简入门教程

    Nginx 极简入门教程 基本介绍 Nginx 是一个高性能的 HTTP 和反向代理 web 服务器,同时也提供了 IMAP/POP3/SMTP服务. Nginx 是由伊戈尔·赛索耶夫为俄罗斯访问量第 ...

  9. Python极简入门教程

    前言 为了方便各位小白能轻松入门Python,同时加深自己对Python的理解,所以创造了"Python极简入门教程",希望能帮到大家,若有错误请多指正,谢谢.极简入门教程代表着不 ...

最新文章

  1. python生成简单的FTP弱口令扫描
  2. Py之uiautomator2:uiautomator2的简介、安装、使用方法之详细攻略
  3. ubuntu安装python-mysqldb
  4. python csv性能_Python 使用和高性能技巧总结
  5. 如何卸载MySQL8.0.11_win10安装mysql8.0.11卸载5.7
  6. 钱准备好!苹果官方账号泄密:iPhone 12明晚发布有戏
  7. MySQL版本升级到5.7.21
  8. Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated column 'userinfo.
  9. STC学习:红外测试
  10. 利用Jwing窗口写程序-----简单计算器(JAVA实用教程2-第五版 第九章 编程题 三(2)小题)
  11. Egg中使用DiyUpload实现图片批量上传
  12. 亿图图示上线小程序,MindMaster移动端迎来大更新,亿图软件八周年再出发
  13. COMSOL报错调试总结(不定期更新)
  14. 【虚拟仿真】Unity3D中实现UI跟随3D模型旋转移动、UI一直面朝屏幕
  15. 微信小程序多音频场景处理 - 背景音频
  16. Debian 下的五笔输入法 Rime
  17. 如何在jupyter中执行带参数的py文件
  18. stm32检测串口空闲的原理
  19. linux添加javahome
  20. C#实现串口通信的上位机开发

热门文章

  1. 致远oa漏洞修复 V5低版本V5.6~V 8.2版本Flash替换为H5化流程图的补丁包
  2. 谷歌浏览器Chrome无法自动同步的解决办法
  3. 《大数据技术原理与应用(第3版)》期末复习——前两章练习题
  4. AppScan(4)安装证书和绕过登录深入扫描
  5. 【根据模板导出多sheet表格数据】
  6. (三十一)国债期货的定价和套期保值
  7. android中PreferenceScreen类的用法
  8. ERP不规范,同事哭晕在厕所
  9. eks安装kubectl
  10. 计算机控制系统的五种类型,计算机控制系统习题(1-5章