import numpy as np
import torch
from torch.autograd import Variable
from torch import nn, optim
import matplotlib.pyplot as plt# 设置字体为中文
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 构造成次方矩阵
def make_fertures(x):x = x.unsqueeze(1)return torch.cat([x ** i for i in range(1, 4)], 1)# y = 0.9+0.5*x+3*x*x+2.4x*x*x
W_target = torch.FloatTensor([0.5, 3, 2.4]).unsqueeze(1)
b_target = torch.FloatTensor([0.9])# 计算x*w+b
def f(x):return x.mm(W_target) + b_target.item()def get_batch(batch_size=32):random = torch.randn(batch_size)random = np.sort(random)random = torch.Tensor(random)x = make_fertures(random)y = f(x)if (torch.cuda.is_available()):return Variable(x).cuda(), Variable(y).cuda()else:return Variable(x), Variable(y)# 多项式模型
class poly_model(nn.Module):def __init__(self):super(poly_model, self).__init__()self.poly = nn.Linear(3, 1)  # 输入时3维,输出是1维def forward(self, x):out = self.poly(x)return outif torch.cuda.is_available():model = poly_model().cuda()
else:model = poly_model()
# 均方误差,随机梯度下降
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)epoch = 0  # 统计训练次数
ctn = []
lo = []
while True:batch_x, batch_y = get_batch()output = model(batch_x)loss = criterion(output, batch_y)print_loss = loss.item()optimizer.zero_grad()loss.backward()optimizer.step()ctn.append(epoch)lo.append(print_loss)epoch += 1if (print_loss < 1e-3):breakprint("Loss: {:.6f}  after {} batches".format(loss.item(), epoch))
print("==> Learned function: y = {:.2f} + {:.2f}*x + {:.2f}*x^2 + {:.2f}*x^3".format(model.poly.bias[0], model.poly.weight[0][0],model.poly.weight[0][1],model.poly.weight[0][2]))
print("==> Actual function: y = {:.2f} + {:.2f}*x + {:.2f}*x^2 + {:.2f}*x^3".format(b_target[0], W_target[0][0],W_target[1][0], W_target[2][0]))
# 1.可视化真实数据
predict = model(batch_x)
x = batch_x.numpy()[:, 0]  # x~1 x~2 x~3
plt.plot(x, batch_y.numpy(), 'ro')
plt.title(label='可视化真实数据')
plt.show()
# 2.可视化拟合函数
predict = predict.data.numpy()
plt.plot(x, predict, 'b')
plt.plot(x, batch_y.numpy(), 'ro')
plt.title(label='可视化拟合函数')
plt.show()
# 3.可视化训练次数和损失
plt.plot(ctn,lo)
plt.xlabel('训练次数')
plt.ylabel('损失值')
plt.title(label='训练次数与损失关系')
plt.show()

实验结果:

注意:批量产生数据后,进行一个排序,否则可视化时,不是按照x轴从小到大绘制,出现很多折线。对应代码:

 random = np.sort(random)random = torch.Tensor(random)

pytorch:多项式回归相关推荐

  1. pytorch 多项式回归

    目录 1. 准备数据 2. 随机初始化参数 3. 训练数据 4. 可视化 5. code 1. 准备数据 首先,还是先定义我们的数据集,这里我们定义20个样本 然后真实值y近似等于 1 + 2x + ...

  2. pytorch实现多项式回归

    pytorch实现多项式回归 一元线性回归模型虽然能拟合出一条直线,但精度依然欠佳,拟合的直线并不能穿过每个点,对于复杂的拟合任务需要多项式回归拟合,提高精度.多项式回归拟合就是将特征的次数提高,线性 ...

  3. [PyTorch] 基于python和pytorch的多项式回归

    讲解 须导入和函数库 mport torch import numpy as np from torch.autograd import Variable import torch.nn as nn ...

  4. 【深度学习】基于Pytorch多层感知机的高级API实现和注意力机制(二)

    [深度学习]基于Pytorch多层感知机的高级API实现和注意力机制(二) 文章目录1 代码实现 2 训练误差和泛化误差 3 模型复杂性 4 多项式回归4.1 生成数据集4.2 对模型进行训练和测试4 ...

  5. Lesson 15.2 学习率调度在PyTorch中的实现方法

    Lesson 15.2 学习率调度在PyTorch中的实现方法   学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...

  6. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  7. Pytorch入门-1

    Pytorch的一些常见的用法 pytorch的一些操作和numpy比较类似,对于熟悉numpy的人比较友好,下面总结一些用法: Tensor (张量) ,有32 位浮点型torch.FloatTen ...

  8. 人工智能AI:TensorFlow Keras PyTorch MXNet PaddlePaddle 深度学习实战 part1

    日萌社 人工智能AI:TensorFlow Keras PyTorch MXNet PaddlePaddle 深度学习实战 part1 人工智能AI:TensorFlow Keras PyTorch ...

  9. 深度学习笔记其三:多层感知机和PYTORCH

    深度学习笔记其三:多层感知机和PYTORCH 1. 多层感知机 1.1 隐藏层 1.1.1 线性模型可能会出错 1.1.2 在网络中加入隐藏层 1.1.3 从线性到非线性 1.1.4 通用近似定理 1 ...

最新文章

  1. java通过对.class文件字节码加密,不被轻易反编译出源代码,分析及其实现。
  2. vim自带的练习教程(vimtutor)
  3. chromium关闭更新_你的Win10系统20H2了吗此乃Win10年度最靠谱的更新还有Win10优化大师助阵...
  4. c语言实验分支程序设计二,C语言程序实验报告分支结构的程序设计(0页).doc
  5. 好消息:Dubbo Spring Boot要来了
  6. 汇编语言的强制类型转换
  7. 【第157期】游戏策划:给@Archer的简历分析
  8. 锂电池电源管理系统设计与实现(单片机)
  9. Linux串口属性设置
  10. Python实现数据透视表
  11. 数据分析|基础概念/excel/tableau自学笔记
  12. 找工作,还是找户口?
  13. ipad查看电脑中的文件
  14. 可编辑div在光标位置插入指定内容
  15. CS5263设计原理图|CS5263设计DP转HDMI电路参考|CS5263中文说明
  16. 2022.6.2 质数(素数)与合数
  17. 基于物理的渲染—更精确的微表面分布函数GGX
  18. 01-【浏览器】chrome浏览器收藏夹(书签)的导出与导入
  19. CocosCreator之KUOKUO带你做自己的艺术数字字体
  20. 报表开发利器FastReport .NET v2022.1 - 支持.NET 6

热门文章

  1. XHTML 相对路径与绝对路径
  2. 每日求一录~20170704
  3. 当PDF页面总数不确定的时候导出PDF增加页码(i of n)
  4. eclipse 如何忽略js文件报错
  5. UOJ #588. 图图的旅行
  6. HDU 6030 Happy Necklace
  7. PHPSTORM下安装XDEBUG
  8. centos node跟npm 安装
  9. 安装完成后在命令行运行bash时报错0x80070057
  10. ![CDATA[ ]]