对于一般的线性回归模型,由于该函数拟合出来的是一条直线,所以精度欠佳,我们可以考虑多项式回归来拟合更多的模型。多项式回归,原理和线性回归是一样的,无非现在是高次多项式而非一次多项式。

比如说我们想要拟合方程:

我们可以先设置参数方程:

代码实现:

导入相关包,torch用来创建模型,matplotlib用来可视化。POLY_DEGREE指多项式的最高次数。

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch import nnPOLY_DEGREE = 3

预处理数据,把数据变成矩阵形式。

用torch.cat函数来实现Tensor的拼接。

def make_features(x):x = x.unsqueeze(1)return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)

定义真实函数。

W_target = torch.FloatTensor([3,6,2]).unsqueeze(1)
b_target = torch.FloatTensor([8])def f(x):return x.mm(W_target) + b_target.item()

在训练的时候我们需要采样一些点,可以随机生成一批数据来得到训练集。下面的函数可以让我们每次取batch_size这么多个数据,然后将其转化为矩阵形式,再把这个值通过函数之后的结果也返回作为真实的输出值。

def get_batch(batch_size=64):random = torch.randn(batch_size)x = make_features(random)y = f(x)  # + torch.rand(1)return Variable(x), Variable(y)# # show an example of a batch
# x_axis = torch.randn(64)
# y_axis = f(make_features(x_axis)).squeeze()
# plt.title("Original Data Example")
# plt.scatter(x_axis.data.numpy(), y_axis.data.numpy())
# plt.show()

定义多项式模型,定义损失函数和优化器。

class poly_model(nn.Module):def __init__(self):super(poly_model,self).__init__()self.poly = nn.Linear(3,1)def forward(self, x):out = self.poly(x)return outmodel = poly_model()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

定义一个打印函数,方便在训练过程中和训练后打印拟合的多项式表达式。

def poly_desc(W, b):result = 'y = 'for i, w in enumerate(W):result += '{:+.2f} x^{} '.format(w, len(W) - i)result += '{:+.2f}'.format(b[0])return result

开始训练。

epoch = 0
while True:batch_x, batch_y = get_batch()output = model(batch_x)loss = criterion(output, batch_y)print_loss = loss.dataoptimizer.zero_grad()loss.backward()optimizer.step()# if epoch % 50 == 0:#     print('Loss: {:.6f} after {} batches'.format(loss, epoch))#     print('==> Learned function:t' + poly_desc(model.poly.weight.view(-1), model.poly.bias))epoch += 1if print_loss < 1e-3:print()print("==========End of Training==========")breakprint('Loss: {:.6f} after {} batches'.format(loss, epoch))
print('==> Learned function:t' + poly_desc(model.poly.weight.view(-1), model.poly.bias))
print('==> Actual function:t' + poly_desc(W_target.view(-1), b_target))

当64个数的均方误差小于1e-3,训练结束。

参考链接:

https://www.zhihu.com/pub/reader/119606991/chapter/1102563393726173184

多项式拟合怎么确定次数_PyTorch入门4 搭建多项式回归模型相关推荐

  1. 利用numpy对已知样本点进行多项式拟合

    0.导入相关包: import matplotlib.pyplot as plt import numpy as np 1.假设有如下样本点: #使用随机数产生样本点 x=[1,2,3,4,5,6,7 ...

  2. java 多项式拟合最多的项数_Matlab概率统计与曲线拟合

    一.二项分布 二项分布来源于伯努利试验 (事件发生概率 ) : 含义为独立重复N次试验后, 事件总共发生k次的概率 分布函数 二项分布记为 binopdf 获得事件共发生次的概率 binocdf 为事 ...

  3. Python之数据分析(numpy中的多项式拟合)

    1.多项式拟合的概念 用一个无穷级数表示一个可微函数,任何可微的函数,总可以用一个N次多项式来近似,而比N次幂更高阶的部分可以作为无穷小量而被忽略不计. f(x) = p0x^n + p1x^n-1 ...

  4. Matlab光滑曲线多项式拟合与样条曲线拟合的两个案例

    %多项式曲线拟合 figure(1) matrix2=[]; %新建空矩阵 h1=polyfit(matrix1(:,1),matrix1(:,2),3); %计算多项式拟合系数,3-拟合次数 mat ...

  5. matlab 拟合光滑曲线图,Matlab光滑曲线多项式拟合与样条曲线拟合的两个案例

    %多项式曲线拟合 figure(1) matrix2=[]; %新建空矩阵 h1=polyfit(matrix1(:,1),matrix1(:,2),3); %计算多项式拟合系数,3-拟合次数 mat ...

  6. 多项式拟合缺点_多项式拟合

    在网上看别人的心得 一 最小二乘法的基本原理 从整体上考虑近似函数同所给数据点(i=0,1,-,m)误差(i=0,1,-,m)的大小,常用的方法有以下三种:一是误差(i=0,1,-,m)绝对值的最 ...

  7. 数值计算之 拟合法,线性拟合,多项式拟合

    数值计算之 拟合法之线性拟合,多项式拟合 前言 最小二乘法 多项式拟合 线性拟合 后记 前言 拟合法是另一种由采样数据求取潜在函数的方法.插值要求函数必须经过每一个采样节点,而拟合则要求函数与全部节点 ...

  8. polyfit多项式拟合函数的用法

    polyfit函数是matlab中用于进行曲线拟合的一个函数.其数学基础是最小二乘法曲线拟合原理.曲线拟合:已知离散点上的数据集,即已知在点集上的函数值,构造一个解析函数(其图形为一曲线)使在原离散点 ...

  9. MATLAB 线性回归多项式拟合+预测区间、置信区间的绘制

    MATLAB 线性回归多项式拟合+预测区间.置信区间的绘制 一.前言 二.多项式拟合polyfit 1.语法 2.示例 三.区间绘制 四.整体源码 五.思考 六.参考博客 一.前言 现有一组数据:x. ...

最新文章

  1. 数字货币EOS半年时间暴跌90%多,还可追捧吗?
  2. CVPR新规严禁审稿期间公开宣传论文,可发arXiv,LeCun:疯了吧!
  3. Oracle 触发器调用存储过程|转||待研究|
  4. 智慧办公的AI博弈——看飞企互联如何接招!
  5. fibonacci climbing-stairs
  6. 【拔刀吧少年】之循环三兄弟for while until
  7. java第三章_【Java】第三章 变量
  8. 如何使用Java创建AWS Lambda函数
  9. [vue] 怎么在vue中使用插件?
  10. python os renames_Python3 os.renames() 方法
  11. jpa原生query_Spring Data JPA原生SQL查询
  12. SLAM GMapping(6)扫描匹配器
  13. java 屏幕键盘io
  14. 数据泵避免个别表数据的导出
  15. 获取滚动条所在页面位置。做一个类似TX的消息框
  16. 《强化学习》中的第15章:神经科学
  17. TF卡里删掉文件后内存没变大_电视装好kodi后打不开?播放原盘4K很卡?教你怎么解决...
  18. java:线程的六种状态
  19. C++通过生日判断星座
  20. 技术分析:苹果之后 HTML5将改变移动互联网

热门文章

  1. C# RangeHelper
  2. 轻松搞定ServerCore初始设置
  3. 基于×××环境下的远程视频监控传输
  4. inner join on, left join on, right join on讲解
  5. 使用Axis,在webservice的服务器端如何取到客户端的IP地址
  6. 利用sender的Parent获取GridView中的当前行
  7. python画折线图虚线_python绘制简单折线图代码示例
  8. 银行计算机系统(第3版),清华大学出版社-图书详情-《银行计算机系统》
  9. mysql workbench企业_甲骨文发布MySQL Workbench 6.0版本
  10. oracle vm win10,win10系统oraclevm卸载不了错误2503的解决方法