以一个简单的全连接神经网络为例,介绍网络的定义过程

import torch
import torch.nn as nn
from torch.optim import SGD
import torch.utils.data as Data
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

以波士顿放假数据集为例,load_boston载入数据,StandardScaler对数据进行标准化,Data模块对数据进行预处理

#读取数据
boston_X,boston_Y = load_boston(return_X_y=True)
print("boston_X.shape:",boston_X.shape)
# plt.figure()
# plt.hist(boston_Y,bins=20)
# plt.show()#数据标准化处理
ss = StandardScaler(with_mean=True,with_std=True)
boston_Xs = ss.fit_transform(boston_X)
#将数据预处理为可以使用pytorch进行批量处理形式#训练数据集x转化为张量
train_xt = torch.from_numpy(boston_Xs.astype(np.float32))
#训练集y转化为张量
train_yt = torch.from_numpy(boston_Y.astype(np.float32))
#将训练集转化为张量后,使用tensordataset将x和y整理到一起
train_data = Data.TensorDataset(train_xt,train_yt)##定义一个数据加载其,将训练数据集进行批量处理
train_loader = Data.DataLoader(dataset=train_data,#使用的数据集batch_size = 128,#批处理样本大小shuffle=True,#每次迭代前打乱数据#num_workers=1#使用两个进程,电脑不行改成0才能运行num_workers=0)#使用继承Module的方式定义全连接神经网络
class MLPmodel(nn.Module):def __init__(self):super(MLPmodel,self).__init__()#定义第一个隐含层self.hidden1 = nn.Linear(in_features=13,#第一个隐含层的输入,数据的特征数out_features=10,#第一个隐藏层的输出,神经元的数量bias = True ,#,默认会有偏执项)self.active1 = nn.ReLU()#定义第二个隐含层self.hidden2 = nn.Linear(10,10)self.active2 = nn.ReLU()#定义预测回归层self.regression = nn.Linear(10,1)#定义网络的前向传播路径def forward(self,x):x = self.hidden1(x)x = self.active1(x)x = self.hidden2(x)x = self.active2(x)output = self.regression(x)#输出为outputreturn outputmlpl = MLPmodel()
print(mlpl)#对回归模型mlpl进行训练并输出损失函数的变化情况,定义优化器和损失函数
optimizer = SGD(mlpl.parameters(),lr = 0.001)
loss_func = nn.MSELoss()
train_loss_all = []#进行训练,并输出每次迭代的损失函数
for epoch in range(100):for step,(b_w,b_y) in enumerate(train_loader):output = mlpl(b_w).flatten()train_loss = loss_func(output,b_y)optimizer.zero_grad()train_loss.backward()optimizer.step()train_loss_all.append(train_loss.item())plt.figure()
plt.plot(train_loss_all,"r-")
plt.title("Train loss per iterration")
plt.show()

损失函数值的变化情况

2 使用nn.sequential()函数进行定义网络

class MLPmodel2(nn.Module):def __init__(self):super(MLPmodel2,self).__init__()#定义隐含层self.hidden= nn.Sequential(nn.Linear(13,10),nn.ReLU(),nn.Linear(10,10),nn.ReLU(),)#预测回归层self.regression = nn.Linear(10,1)#定义网络的前向传播路径def forward(self,x):x = self.hidden(x)output = self.regression(x)return outputmlp2 = MLPmodel2()
print(mlp2)#对回归模型mlpl进行训练并输出损失函数的变化情况,定义优化器和损失函数
optimizer = SGD(mlp2.parameters(),lr = 0.001)
loss_func = nn.MSELoss()
train_loss_all = []#进行训练,并输出每次迭代的损失函数
for epoch in range(100):for step,(b_w,b_y) in enumerate(train_loader):output = mlp2(b_w).flatten()train_loss = loss_func(output,b_y)optimizer.zero_grad()train_loss.backward()optimizer.step()train_loss_all.append(train_loss.item())plt.figure()
plt.plot(train_loss_all,"r-")
plt.title("Train loss per iterration")
plt.show()

损失函数值的变化情况

pytorch学习1:pytorch 定义网络的方式相关推荐

  1. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  2. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  3. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  4. PyTorch学习记录——PyTorch生态

    Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...

  5. pytorch学习笔记-----对抗生成网络GAN

    生成器,判别器 G:生成网络生成的都为假的 D:判别网络判别真实数据与来自生成网络的假数据 判别网络其实就是进行一个图像二分类 生成网络需要fc层输出个数为h*w*c(c=1or3 即为一张图片的形式 ...

  6. 1.pytorch学习:安装pytorch

    目录 安装pytorch 检查pytorch安装是否成功 总结 安装pytorch 官方网址: Start Locally | PyTorchhttps://pytorch.org/get-start ...

  7. 正则学习:组的定义及引用方式

    一个正则表达式匹配结果可以分成多个部分,这就是组(Group). 把一次Match结果用(?<name>)的方式分成组,例子: public static void Main()      ...

  8. PyTorch学习笔记——pytorch图像处理(transforms)

    原始图像 2.图像处理.转不同格式显示 import torch import torchvision import torchvision.transforms as transforms impo ...

  9. 指南 | Pytorch定义网络的几种方法

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者 | ppgod @知乎(已授权) 来源 | https://zhuanlan.zhihu.com/p/8 ...

最新文章

  1. glibc-2.23学习笔记(一)—— malloc部分源码分析
  2. android 图片查看动画,Android 共享动画实现点击列表图片跳转查看大图页面
  3. 动手学无人驾驶(2):车辆检测
  4. PS把一张白色背景的图片设为透明
  5. lte核心网由哪些设备组成_投影地面互动的实现由哪些设备组成?「振邦视界」...
  6. html(+css)/01/html语言基础,标记,标记语法,html文档结构
  7. GO_00:Mac之Item2的配置安装
  8. 点击某些按钮不要触发验证控件
  9. b站主页面视频推荐油猴脚本(更新)
  10. Java面试之项目介绍
  11. crt计算机图形系统是什么东西,计算机图形系统功能.PPT
  12. Protus 8.6 及以上如何找到library文件夹
  13. oracle10非正常删除卸载干净,win10系统下把Oracle卸载干净
  14. qq邮箱发送邮件服务器类型,设置QQ邮箱为发送邮件服务器的详细带图步骤
  15. 怎样查询服务器中标信息,太极中标云服务器
  16. 没有ftp信息服务器,电脑没有ftp服务器配置
  17. 生命旅程中何生命个体
  18. Unity Shader 学习记录(3) —— CG语言和Shader文件
  19. python 图像清晰度_图像清晰度评价指标(Python)
  20. 《甄嬛传》解读--后宫女人的心酸血泪史之腹黑学

热门文章

  1. 速看!上班后如何做好防护?这9点一定要知道
  2. IDEA 自动生成类注释和方法注释
  3. 【插件发布】JAVA微服务框架,Jeecg-P3-Biz-OA 1.0.0 插件开源发布
  4. JEECG - 基于代码生成器的J2EE智能开发框架 续一:开发环境搭建步骤
  5. 作为一个新晋测试经理,在软件测试计划之前你必须知道的10件事
  6. Fibonacci数列使用迭代器实现
  7. LEADTOOLS Multimedia SDK更新:改进RTSP和H.265/H.264的硬件加速
  8. 利用mycat实现mysql数据库读写分离
  9. 3. PowerShell --基本操作,Alias,输出
  10. 切割日志 python版