pytorch学习1:pytorch 定义网络的方式
以一个简单的全连接神经网络为例,介绍网络的定义过程
import torch
import torch.nn as nn
from torch.optim import SGD
import torch.utils.data as Data
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
以波士顿放假数据集为例,load_boston载入数据,StandardScaler对数据进行标准化,Data模块对数据进行预处理
#读取数据
boston_X,boston_Y = load_boston(return_X_y=True)
print("boston_X.shape:",boston_X.shape)
# plt.figure()
# plt.hist(boston_Y,bins=20)
# plt.show()#数据标准化处理
ss = StandardScaler(with_mean=True,with_std=True)
boston_Xs = ss.fit_transform(boston_X)
#将数据预处理为可以使用pytorch进行批量处理形式#训练数据集x转化为张量
train_xt = torch.from_numpy(boston_Xs.astype(np.float32))
#训练集y转化为张量
train_yt = torch.from_numpy(boston_Y.astype(np.float32))
#将训练集转化为张量后,使用tensordataset将x和y整理到一起
train_data = Data.TensorDataset(train_xt,train_yt)##定义一个数据加载其,将训练数据集进行批量处理
train_loader = Data.DataLoader(dataset=train_data,#使用的数据集batch_size = 128,#批处理样本大小shuffle=True,#每次迭代前打乱数据#num_workers=1#使用两个进程,电脑不行改成0才能运行num_workers=0)#使用继承Module的方式定义全连接神经网络
class MLPmodel(nn.Module):def __init__(self):super(MLPmodel,self).__init__()#定义第一个隐含层self.hidden1 = nn.Linear(in_features=13,#第一个隐含层的输入,数据的特征数out_features=10,#第一个隐藏层的输出,神经元的数量bias = True ,#,默认会有偏执项)self.active1 = nn.ReLU()#定义第二个隐含层self.hidden2 = nn.Linear(10,10)self.active2 = nn.ReLU()#定义预测回归层self.regression = nn.Linear(10,1)#定义网络的前向传播路径def forward(self,x):x = self.hidden1(x)x = self.active1(x)x = self.hidden2(x)x = self.active2(x)output = self.regression(x)#输出为outputreturn outputmlpl = MLPmodel()
print(mlpl)#对回归模型mlpl进行训练并输出损失函数的变化情况,定义优化器和损失函数
optimizer = SGD(mlpl.parameters(),lr = 0.001)
loss_func = nn.MSELoss()
train_loss_all = []#进行训练,并输出每次迭代的损失函数
for epoch in range(100):for step,(b_w,b_y) in enumerate(train_loader):output = mlpl(b_w).flatten()train_loss = loss_func(output,b_y)optimizer.zero_grad()train_loss.backward()optimizer.step()train_loss_all.append(train_loss.item())plt.figure()
plt.plot(train_loss_all,"r-")
plt.title("Train loss per iterration")
plt.show()
损失函数值的变化情况
2 使用nn.sequential()函数进行定义网络
class MLPmodel2(nn.Module):def __init__(self):super(MLPmodel2,self).__init__()#定义隐含层self.hidden= nn.Sequential(nn.Linear(13,10),nn.ReLU(),nn.Linear(10,10),nn.ReLU(),)#预测回归层self.regression = nn.Linear(10,1)#定义网络的前向传播路径def forward(self,x):x = self.hidden(x)output = self.regression(x)return outputmlp2 = MLPmodel2()
print(mlp2)#对回归模型mlpl进行训练并输出损失函数的变化情况,定义优化器和损失函数
optimizer = SGD(mlp2.parameters(),lr = 0.001)
loss_func = nn.MSELoss()
train_loss_all = []#进行训练,并输出每次迭代的损失函数
for epoch in range(100):for step,(b_w,b_y) in enumerate(train_loader):output = mlp2(b_w).flatten()train_loss = loss_func(output,b_y)optimizer.zero_grad()train_loss.backward()optimizer.step()train_loss_all.append(train_loss.item())plt.figure()
plt.plot(train_loss_all,"r-")
plt.title("Train loss per iterration")
plt.show()
损失函数值的变化情况
pytorch学习1:pytorch 定义网络的方式相关推荐
- PyTorch学习记录——PyTorch进阶训练技巧
PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...
- Pytorch学习 - Task5 PyTorch卷积层原理和使用
Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...
- Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用
Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...
- PyTorch学习记录——PyTorch生态
Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...
- pytorch学习笔记-----对抗生成网络GAN
生成器,判别器 G:生成网络生成的都为假的 D:判别网络判别真实数据与来自生成网络的假数据 判别网络其实就是进行一个图像二分类 生成网络需要fc层输出个数为h*w*c(c=1or3 即为一张图片的形式 ...
- 1.pytorch学习:安装pytorch
目录 安装pytorch 检查pytorch安装是否成功 总结 安装pytorch 官方网址: Start Locally | PyTorchhttps://pytorch.org/get-start ...
- 正则学习:组的定义及引用方式
一个正则表达式匹配结果可以分成多个部分,这就是组(Group). 把一次Match结果用(?<name>)的方式分成组,例子: public static void Main() ...
- PyTorch学习笔记——pytorch图像处理(transforms)
原始图像 2.图像处理.转不同格式显示 import torch import torchvision import torchvision.transforms as transforms impo ...
- 指南 | Pytorch定义网络的几种方法
点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者 | ppgod @知乎(已授权) 来源 | https://zhuanlan.zhihu.com/p/8 ...
最新文章
- glibc-2.23学习笔记(一)—— malloc部分源码分析
- android 图片查看动画,Android 共享动画实现点击列表图片跳转查看大图页面
- 动手学无人驾驶(2):车辆检测
- PS把一张白色背景的图片设为透明
- lte核心网由哪些设备组成_投影地面互动的实现由哪些设备组成?「振邦视界」...
- html(+css)/01/html语言基础,标记,标记语法,html文档结构
- GO_00:Mac之Item2的配置安装
- 点击某些按钮不要触发验证控件
- b站主页面视频推荐油猴脚本(更新)
- Java面试之项目介绍
- crt计算机图形系统是什么东西,计算机图形系统功能.PPT
- Protus 8.6 及以上如何找到library文件夹
- oracle10非正常删除卸载干净,win10系统下把Oracle卸载干净
- qq邮箱发送邮件服务器类型,设置QQ邮箱为发送邮件服务器的详细带图步骤
- 怎样查询服务器中标信息,太极中标云服务器
- 没有ftp信息服务器,电脑没有ftp服务器配置
- 生命旅程中何生命个体
- Unity Shader 学习记录(3) —— CG语言和Shader文件
- python 图像清晰度_图像清晰度评价指标(Python)
- 《甄嬛传》解读--后宫女人的心酸血泪史之腹黑学
热门文章
- 速看!上班后如何做好防护?这9点一定要知道
- IDEA 自动生成类注释和方法注释
- 【插件发布】JAVA微服务框架,Jeecg-P3-Biz-OA 1.0.0 插件开源发布
- JEECG - 基于代码生成器的J2EE智能开发框架 续一:开发环境搭建步骤
- 作为一个新晋测试经理,在软件测试计划之前你必须知道的10件事
- Fibonacci数列使用迭代器实现
- LEADTOOLS Multimedia SDK更新:改进RTSP和H.265/H.264的硬件加速
- 利用mycat实现mysql数据库读写分离
- 3. PowerShell --基本操作,Alias,输出
- 切割日志 python版