接触机器学习和深度学习已经有一段时间了,一直想做个记录,方便自己以后的查阅。

一开始我搭建神经网络时所使用的框架是Tensorflow,虽然功能强大但是不同版本代码的兼容性有些差强人意。

以下的内容所创建的环境是Anaconda中的虚拟环境,采用的python版本是3.8,cuda和cudnn都是对应的版本。

搭建和训练神经网络分为以下几个步骤:1.数据集的准备 2.神经网络的搭建 3.反向传播调整参数降低loss值

只是一个记录。

一. 数据集的准备

搭建完一个神经网络之后必不可少的步骤是对其进行训练,不断降低其loss值,修改参数。在此之前介绍一下与数据集息息相关的两个函数DatasetsDataLoader.

首先来说说Datasets,Datasets是torch下工具包的一个primitive,中文翻译为基元

Datasets

import torchvision

Dataset提供了很多已经收集好标记好的数据集,如Image Datasets,Text Datasets,Audio Datasets.我们可以通过以下方式来下载数据集。

dataset=torchvision.datasets.CIFAR10(root="./dataset",train=True,tansform=torchvision.transforms.ToTensor(),download=True)

此时我们下载的是CIFAR10数据集,该数据集提供了10个class的图片,分为planes,cars,truck等,所以该数据集适合多分类网络的入门。

root代表我们下载数据集所存储的位置,一般是直接放在项目列表下;train指定训练或测试数据集,如果设置为True则设置为训练集,如果为False则为测试集;download设置为True的话是如果我们的根目录里没有检测到该数据集,则从网上进行下载。

我们都知道神经网络需要的input是tensor数据类型,也就是张量,所以我们在加载图片之前需要将图片数据转换为tensor数据类型。torchvision提供了ToTensor的方法来进行转变,我们在下载数据集时可以直接进行转换。

DataLoader

  当下载完数据集后,我们需要装载数据集,就像打牌时发完牌我们需要用手去抓取,一次抓四张还是一次抓五张由我们自己决定。

dataloader=DataLoader(dataset=dataset,batch_size=32,shuffle=True)

dataset:要使用的数据集

batch_size:每个batch有多少个样本,就和我们一次抓多少张牌

shuff:在每个epoch开始的时候对数据进行重新排序

二.神经网络的搭建

这里我们主要使用的是torch.nn,它为我们封装好了现成的函数

class Net(nn.Module):#继承nn.Moduledef __init__(self):super(Net,self).__init__()#初始化self.module=nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,input):output=self.module(input)return output

这里搭建module使用的是nn.Sequential,如果一步一步来搭建网络,在forward时我们还得重新写一遍。

以上仅仅使用了卷积层、最大池化层和线性层,激活层并未使用。

三.读取数据并进行反向传播

接下来我们要把数据加载进神经网络并进行反向传播。

net=Net()#实例化网络loss=nn.CrossEntryLoss()optim=torch.optim.SGD(net.parameters(),lr=0.01)for epoch in range(100):for data in dataloader:imgs,targets=dataoutputs=net(imgs)loss_result=loss(outputs,targets)optim.zero_grad()#梯度清零loss.backward()#进行反向传播计算梯度optim.step()#进行参数优化print(loss_result)

这里我们设置了50个epoch。在第一个epoch时loss值为2.3左右

在第40个epoch时loss值已经下降到0.4左右

在第100个epoch时已经下降到0.001左右了

下面我把全部的代码贴出来,做个备份。

# -*- codeing = utf-8 -*-
# @Time : 2021/7/21 14:45
# @Author : ZY
# @File : caculate_loss.py
# Software : PyCharm
# Goal:Caculate the loss of CIFAR10 dataset using CrossEntrylossimport torchvision
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("pytorch代码/data",train=False,transform=torchvision.transforms.ToTensor(),download=False)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)#build the net
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xnet=Net()
loss=nn.CrossEntropyLoss()
optim=torch.optim.SGD(net.parameters(),lr=0.01)
for epoch in range(100):print(epoch)for data in dataloader:imgs, targrts=dataoutputs=net(imgs)# print(outputs)result_loss=loss(outputs,targrts)optim.zero_grad()result_loss.backward()optim.step()print(result_loss)

Pytorch(1) 学习笔记-多分类网络的搭建相关推荐

  1. pyTorch——基础学习笔记

    pytorch基础学习笔记博文,在整理的时候借鉴的大量的网上资料,存在和一部分图片定义的直接复制黏贴,在本博文的最后将会表明所有的参考链接.由于参考的内容众多,所以博文的更新是一个长久的过程,如果大佬 ...

  2. Pytorch Document学习笔记

    Pytorch Document学习笔记 Pytorch Document学习笔记 1. 网络层 1.1 torch.nn.Conv2d 1.2 torch.nn.MaxPool2d / torch. ...

  3. HALCON 20.11:深度学习笔记(10)---分类

    HALCON 20.11:深度学习笔记(10)---分类 HALCON 20.11.0.0中,实现了深度学习方法. 本章解释了如何在训练和推理阶段使用基于深度学习的分类. 基于深度学习的分类是一种对一 ...

  4. R语言与机器学习学习笔记(分类算法)

    转载自:http://www.itongji.cn/article/0P534092014.html 人工神经网络(ANN),简称神经网络,是一种模仿生物神经网络的结构和功能的数学模型或计算模型.神经 ...

  5. Linux+javaEE学习笔记之Linux网络环境配置

    Linux+javaEE学习笔记之Linux网络环境配置 网络知识简单介绍: Ip地址是:IP地址是IP协议提供的一种统一的地址格式,它为互联网上的每一个网络和每一台主机分配一个逻辑地址,以此来屏蔽物 ...

  6. 数通学习笔记1 - 数据通信网络基础

    数通学习笔记1 - 数据通信网络基础 数据通信网络基础 数通学习笔记1 - 数据通信网络基础 前言 一.通信与网络 1. 什么是通信.什么是网络通信? 2. 信息传递过程 3. 数据通信网络 二.网络 ...

  7. Neutron学习笔记2-- Neutron的网络实现模型

    Neutron学习笔记2-- Neutron的网络实现模型 Neutron的三类节点 计算节点 网络节点 控制节点 Neutron将在这三类节点中进行部署,Neutron在各个计算节点,网络节点中运行 ...

  8. nginx学习笔记-01nginx入门,环境搭建,常见命令

    nginx学习笔记-01nginx入门,环境搭建,常见命令 文章目录 nginx学习笔记-01nginx入门,环境搭建,常见命令 1.nginx的基本概念 2.nginx的安装,常用命令和配置文件 3 ...

  9. 华芯微特SWM181学习笔记--GPIO应用与环境搭建

    华芯微特SWM181 系列 32 位 MCU(以下简称 SWM181)内嵌 ARM® CortexTM-M0 内核, SWM181 支持片上包含精度为 1%以内的 24MHz.48MHz 时钟,并提供 ...

最新文章

  1. OpenJudge/Poj 2001 Shortest Prefixes
  2. 安卓桌面精灵_小米MIUI安卓Q来啦,超多黑科技!凭啥红米先尝鲜?内附预览图...
  3. 字符串的方法、注释及示例1.
  4. eclipse无法创建java虚拟机_2020年哪些IDE是最适合Java开发人员的?
  5. 蚂蚁金服王旭:开源的意义是把社区往前推进一步
  6. python属于私有属性_Python私有属性和受保护的属性原理解析
  7. 【连载】如何掌握openGauss数据库核心技术?秘诀五:拿捏数据库安全(4)
  8. python出现Unresolved import:库名,已解决
  9. 一道实用linux运维问题的9种shell解答方法!
  10. qchart 怎么点击一下 出一条线_疏通身上一条线,很多难缠病,不知不觉消失了!...
  11. PCB传输线阻抗计算工具Polar Si9000的安装方法
  12. 服务器2012怎么安装无线网卡驱动,如何安装usb无线网卡驱动,教您如何安装电脑usb驱动...
  13. 德芙网络营销策略ppt_德芙网络营销案例ppt采集
  14. ITIL服务管理知识体系的介绍
  15. Linux怎么复制文件到其他文件夹
  16. iOS开发中的Web应用概述
  17. 破解微信数据库 并查询数据上传服务器
  18. js添加到桌面快捷方式(实现功能)
  19. 中国半导体芯片产业布局图(2022版)-爱普搜汽车
  20. RPC实践(二)JsonRPC实践

热门文章

  1. 微服务设计第 1 章 微服务
  2. 文件夹有把锁怎么去掉Linux,文件夹有锁图标怎么去掉?
  3. 树莓派最轻便的图形界面连接(无网线、无路由)
  4. [EXP]CVE-2019-9621 Zimbra小于8.8.11 远程代码执行漏洞 XXE GetShell Exploit
  5. android 音乐app 保活,aggregationProject聚合项目
  6. 九阳真经名句欣赏:清风拂山冈
  7. 全栈工程师到底有什么用
  8. Samsung手机验证
  9. 日产佳奔_Nissan的历史
  10. Scrapy框架-redis分布式(从Scrapy框架创建项目到redis分布式)