本文与《20天吃透Pytorch》有所不同,《20天吃透Pytorch》中是继承之前的模型进行拟合,本文是单独建立网络进行拟合。

代码实现:

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset"""
1.准备数据
"""
n=800   #样本数量#生成测试用的数据集
X = 10*torch.rand([n,2])-5.0    #torch.rand是均匀分布
w0 = torch.tensor([[2.0],[-3.0]])
b0 = torch.tensor([10.0])
Y = X@w0 + b0 + torch.normal(0.0,2.0,size=[n,1])    ## @表示矩阵乘法,增加正态扰动#数据可视化
plt.figure(figsize= (12,5))
ax1 = plt.subplot(121)
ax1.scatter(X[:,0],Y[:,0],c = 'b',label = 'samples')
ax1.legend()    #图例
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)
ax2 = plt.subplot(122)
ax2.scatter(X[:,1],Y[:,0],c = 'g',label = 'samples')
ax2.legend()
plt.xlabel('x2')
plt.ylabel('y',rotation = 0)
plt.show()"""
构建通道
"""ds = TensorDataset(X,Y)
ds_train,ds_valid = torch.utils.data.random_split(ds,[int (n*0.7),n-int(n*0.7)])  #选取总样本的70%为训练数据
dl_train = DataLoader(ds_train,batch_size=10,shuffle=True)
dl_valid = DataLoader(ds_valid,batch_size=10,shuffle=True)"""
2.定义模型
"""class LinearRegression(torch.nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.fc = nn.Linear(2,1)def forward(self,x):x = self.fc(x)return xnet = LinearRegression()
"""
3.训练模型
"""
loss_func = torch.nn.MSELoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)eporchs = 10
log_step_freq = 20for eporch in range(1,eporchs+1):net.train()loss_sum = 0.0metric_sum = 0.0step = 1for step,(features,labels) in enumerate(dl_train,1):predictions = net(features)loss = loss_func(predictions,labels)optimizer.zero_grad()loss.backward()optimizer.step()w = net.state_dict()["fc.weight"]b = net.state_dict()["fc.bias"]print("step =", step, "loss = ", loss)print("w =", w)print("b =", b)loss_sum += loss.item()"""
结果可视化
"""
w,b = net.state_dict()["fc.weight"],net.state_dict()["fc.bias"]plt.figure(figsize = (12,5))
ax1 = plt.subplot(121)
ax1.scatter(X[:,0],Y[:,0], c = "b",label = "samples")
ax1.plot(X[:,0],w[0,0]*X[:,0]+b[0],"-r",linewidth = 5.0,label = "model")
ax1.legend()
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)ax2 = plt.subplot(122)
ax2.scatter(X[:,1],Y[:,0], c = "g",label = "samples")
ax2.plot(X[:,1],w[0,1]*X[:,1]+b[0],"-r",linewidth = 5.0,label = "model")
ax2.legend()
plt.xlabel("x2")
plt.ylabel("y",rotation = 0)plt.show()

结果展示:

数据部分:

线性回归结果:

Pytorch高阶API示范——线性回归模型相关推荐

  1. Pytorch高阶API示范——DNN二分类模型

    代码部分: import numpy as np import pandas as pd from matplotlib import pyplot as plt import torch from ...

  2. 【进阶篇】全流程学习《20天掌握Pytorch实战》纪实 | Day10 | 高阶API示范

  3. 速成pytorch学习——4天中阶API示范

    使用Pytorch的中阶API实现线性回归模型和和DNN二分类模型. Pytorch的中阶API主要包括各种模型层,损失函数,优化器,数据管道等等. 一,线性回归模型 1,准备数据 import nu ...

  4. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  5. 华为昇思高阶API套件迎来全新升级!解决无人驾驶疑难杂症真得靠它!

    点击蓝字 MindSpore 关注我们 对于程序员来说,拥有一款低门槛.易操作的深度学习开发工具包,可以说赢在了起跑线!来自华为的全场景AI框架昇思MindSpore在历经短短一年多时间的迭代,为专业 ...

  6. 【进阶篇】全流程学习《20天掌握Pytorch实战》纪实 | Day09 | 中阶API示范

  7. 【进阶篇】全流程学习《20天掌握Pytorch实战》纪实 | Day08 | 低阶API示范

  8. tensorflow使用高阶api导致训练不收敛问题

    摘要 本文将低级api实现的tensorflow网络移植到高级api上遇到的loss值不变和训练结果不收敛问题 引言 tensorflow版本更新很快,猛一回头发现已经推出更高级的api了 主题 te ...

  9. WebDriver高阶API(8)

    17.测试HTML5语言实现的视频播放器 #encoding=utf-8 import unittest import time from selenium import webdriverclass ...

最新文章

  1. c语言double变量后面几个0,C语言double型变量的初始化到底是是0还是0.0?
  2. linux文件属性 -rwxr-xrw,Linux文件属性
  3. 瑞数动态安全:做一个牵着黑客鼻子走的移动靶心
  4. 怎样使用两行代码实现博客园打赏功能
  5. 再见 Postman!Apifox 才是 YYDS!
  6. js正则表达式验证密码
  7. python时间模块time
  8. 设置windows服务依赖项
  9. 错误 Cannot load driver class: com.mysql.jdbc.Driver
  10. windows的又一个问题
  11. matlabadftest_adf检验matlab程序
  12. 关于长论文word转PDF,出现图等错误解决办法
  13. Dapper系列之三:Dapper的事务修改与删除
  14. uniapp打开pdf文件
  15. OMG,史上最全的37个APP推广渠道来啦!
  16. AP作为WLAN用户接入认证点的PEAP用户接入流程
  17. 8 个你应该了解的环保开源项目
  18. HITS算法--从原理到实现
  19. 移动安全之Android安全检测工具大全
  20. 华为HyperSnap特性应用场景演练

热门文章

  1. leetcode(一)刷题两数之和
  2. 【工作经验分享】这些新技术你们都知道吗
  3. datatable转化泛型
  4. merge intervals(合并间隔)
  5. JDBC01 利用JDBC连接数据库【不使用数据库连接池】
  6. [阅读笔记]Zhang Y. 3D Information Extraction Based on GPU.2010.
  7. [archlinux][hardware] 查看SSD的使用寿命
  8. (6)css盒子模型(基础下)
  9. poj 3436 (最大流)
  10. Nagios:企业级系统监控方案