Pytorch(1) 学习笔记-多分类网络的搭建
接触机器学习和深度学习已经有一段时间了,一直想做个记录,方便自己以后的查阅。
一开始我搭建神经网络时所使用的框架是Tensorflow,虽然功能强大但是不同版本代码的兼容性有些差强人意。
以下的内容所创建的环境是Anaconda中的虚拟环境,采用的python版本是3.8,cuda和cudnn都是对应的版本。
搭建和训练神经网络分为以下几个步骤:1.数据集的准备 2.神经网络的搭建 3.反向传播调整参数降低loss值
只是一个记录。
一. 数据集的准备
搭建完一个神经网络之后必不可少的步骤是对其进行训练,不断降低其loss值,修改参数。在此之前介绍一下与数据集息息相关的两个函数Datasets和DataLoader.
首先来说说Datasets,Datasets是torch下工具包的一个primitive,中文翻译为基元
Datasets
import torchvision
Dataset提供了很多已经收集好标记好的数据集,如Image Datasets,Text Datasets,Audio Datasets.我们可以通过以下方式来下载数据集。
dataset=torchvision.datasets.CIFAR10(root="./dataset",train=True,tansform=torchvision.transforms.ToTensor(),download=True)
此时我们下载的是CIFAR10数据集,该数据集提供了10个class的图片,分为planes,cars,truck等,所以该数据集适合多分类网络的入门。
root代表我们下载数据集所存储的位置,一般是直接放在项目列表下;train指定训练或测试数据集,如果设置为True则设置为训练集,如果为False则为测试集;download设置为True的话是如果我们的根目录里没有检测到该数据集,则从网上进行下载。
我们都知道神经网络需要的input是tensor数据类型,也就是张量,所以我们在加载图片之前需要将图片数据转换为tensor数据类型。torchvision提供了ToTensor的方法来进行转变,我们在下载数据集时可以直接进行转换。
DataLoader
当下载完数据集后,我们需要装载数据集,就像打牌时发完牌我们需要用手去抓取,一次抓四张还是一次抓五张由我们自己决定。
dataloader=DataLoader(dataset=dataset,batch_size=32,shuffle=True)
dataset:要使用的数据集
batch_size:每个batch有多少个样本,就和我们一次抓多少张牌
shuff:在每个epoch开始的时候对数据进行重新排序
二.神经网络的搭建
这里我们主要使用的是torch.nn,它为我们封装好了现成的函数
class Net(nn.Module):#继承nn.Moduledef __init__(self):super(Net,self).__init__()#初始化self.module=nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,input):output=self.module(input)return output
这里搭建module使用的是nn.Sequential,如果一步一步来搭建网络,在forward时我们还得重新写一遍。
以上仅仅使用了卷积层、最大池化层和线性层,激活层并未使用。
三.读取数据并进行反向传播
接下来我们要把数据加载进神经网络并进行反向传播。
net=Net()#实例化网络loss=nn.CrossEntryLoss()optim=torch.optim.SGD(net.parameters(),lr=0.01)for epoch in range(100):for data in dataloader:imgs,targets=dataoutputs=net(imgs)loss_result=loss(outputs,targets)optim.zero_grad()#梯度清零loss.backward()#进行反向传播计算梯度optim.step()#进行参数优化print(loss_result)
这里我们设置了50个epoch。在第一个epoch时loss值为2.3左右
在第40个epoch时loss值已经下降到0.4左右
在第100个epoch时已经下降到0.001左右了
下面我把全部的代码贴出来,做个备份。
# -*- codeing = utf-8 -*-
# @Time : 2021/7/21 14:45
# @Author : ZY
# @File : caculate_loss.py
# Software : PyCharm
# Goal:Caculate the loss of CIFAR10 dataset using CrossEntrylossimport torchvision
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("pytorch代码/data",train=False,transform=torchvision.transforms.ToTensor(),download=False)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)#build the net
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xnet=Net()
loss=nn.CrossEntropyLoss()
optim=torch.optim.SGD(net.parameters(),lr=0.01)
for epoch in range(100):print(epoch)for data in dataloader:imgs, targrts=dataoutputs=net(imgs)# print(outputs)result_loss=loss(outputs,targrts)optim.zero_grad()result_loss.backward()optim.step()print(result_loss)
Pytorch(1) 学习笔记-多分类网络的搭建相关推荐
- pyTorch——基础学习笔记
pytorch基础学习笔记博文,在整理的时候借鉴的大量的网上资料,存在和一部分图片定义的直接复制黏贴,在本博文的最后将会表明所有的参考链接.由于参考的内容众多,所以博文的更新是一个长久的过程,如果大佬 ...
- Pytorch Document学习笔记
Pytorch Document学习笔记 Pytorch Document学习笔记 1. 网络层 1.1 torch.nn.Conv2d 1.2 torch.nn.MaxPool2d / torch. ...
- HALCON 20.11:深度学习笔记(10)---分类
HALCON 20.11:深度学习笔记(10)---分类 HALCON 20.11.0.0中,实现了深度学习方法. 本章解释了如何在训练和推理阶段使用基于深度学习的分类. 基于深度学习的分类是一种对一 ...
- R语言与机器学习学习笔记(分类算法)
转载自:http://www.itongji.cn/article/0P534092014.html 人工神经网络(ANN),简称神经网络,是一种模仿生物神经网络的结构和功能的数学模型或计算模型.神经 ...
- Linux+javaEE学习笔记之Linux网络环境配置
Linux+javaEE学习笔记之Linux网络环境配置 网络知识简单介绍: Ip地址是:IP地址是IP协议提供的一种统一的地址格式,它为互联网上的每一个网络和每一台主机分配一个逻辑地址,以此来屏蔽物 ...
- 数通学习笔记1 - 数据通信网络基础
数通学习笔记1 - 数据通信网络基础 数据通信网络基础 数通学习笔记1 - 数据通信网络基础 前言 一.通信与网络 1. 什么是通信.什么是网络通信? 2. 信息传递过程 3. 数据通信网络 二.网络 ...
- Neutron学习笔记2-- Neutron的网络实现模型
Neutron学习笔记2-- Neutron的网络实现模型 Neutron的三类节点 计算节点 网络节点 控制节点 Neutron将在这三类节点中进行部署,Neutron在各个计算节点,网络节点中运行 ...
- nginx学习笔记-01nginx入门,环境搭建,常见命令
nginx学习笔记-01nginx入门,环境搭建,常见命令 文章目录 nginx学习笔记-01nginx入门,环境搭建,常见命令 1.nginx的基本概念 2.nginx的安装,常用命令和配置文件 3 ...
- 华芯微特SWM181学习笔记--GPIO应用与环境搭建
华芯微特SWM181 系列 32 位 MCU(以下简称 SWM181)内嵌 ARM® CortexTM-M0 内核, SWM181 支持片上包含精度为 1%以内的 24MHz.48MHz 时钟,并提供 ...
最新文章
- OpenJudge/Poj 2001 Shortest Prefixes
- 安卓桌面精灵_小米MIUI安卓Q来啦,超多黑科技!凭啥红米先尝鲜?内附预览图...
- 字符串的方法、注释及示例1.
- eclipse无法创建java虚拟机_2020年哪些IDE是最适合Java开发人员的?
- 蚂蚁金服王旭:开源的意义是把社区往前推进一步
- python属于私有属性_Python私有属性和受保护的属性原理解析
- 【连载】如何掌握openGauss数据库核心技术?秘诀五:拿捏数据库安全(4)
- python出现Unresolved import:库名,已解决
- 一道实用linux运维问题的9种shell解答方法!
- qchart 怎么点击一下 出一条线_疏通身上一条线,很多难缠病,不知不觉消失了!...
- PCB传输线阻抗计算工具Polar Si9000的安装方法
- 服务器2012怎么安装无线网卡驱动,如何安装usb无线网卡驱动,教您如何安装电脑usb驱动...
- 德芙网络营销策略ppt_德芙网络营销案例ppt采集
- ITIL服务管理知识体系的介绍
- Linux怎么复制文件到其他文件夹
- iOS开发中的Web应用概述
- 破解微信数据库 并查询数据上传服务器
- js添加到桌面快捷方式(实现功能)
- 中国半导体芯片产业布局图(2022版)-爱普搜汽车
- RPC实践(二)JsonRPC实践
热门文章
- 微服务设计第 1 章 微服务
- 文件夹有把锁怎么去掉Linux,文件夹有锁图标怎么去掉?
- 树莓派最轻便的图形界面连接(无网线、无路由)
- [EXP]CVE-2019-9621 Zimbra小于8.8.11 远程代码执行漏洞 XXE GetShell Exploit
- android 音乐app 保活,aggregationProject聚合项目
- 九阳真经名句欣赏:清风拂山冈
- 全栈工程师到底有什么用
- Samsung手机验证
- 日产佳奔_Nissan的历史
- Scrapy框架-redis分布式(从Scrapy框架创建项目到redis分布式)