Pytorch安装测试训练建自己的数据集

  • 前言
  • 一、PyTorch是什么?
  • 二、PyTorch环境搭建
    • 1.设备要求
    • 2.安装Pytorch
    • 3.验证PyTorch
  • 二、CIFAR10测试
    • 1、关于CIFAR10
    • 2、训练图像分类器 Training an image classifier
      • (1)加载并归一化CIFAR10 Loading and normalizing CIFAR10
      • (2)定义卷积神经网络 Define a Convolutional Neural Network
      • (3)定义损失函数和优化器 Define a Loss function and optimizer
      • (4)训练网络 Train the network
      • (5)在测试数据上测试网络 Test the network on the test data
  • 三、制作自己的数据集
    • 1、数据收集(通过爬虫)
    • 2、划分数据集
    • 3、训练模型+精度验证
  • 总结

前言

由于计算机视觉这门课程的作业需要,本着萌新探索的态度,对Pytorch进行学习。


提示:以下是本篇文章正文内容,大家可以跟着我一起来学,里面都加了一些注解,有不妥之处敬请指正。

一、PyTorch是什么?

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。
2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络。
目前我们课程的要求是自制数据集实现分类。


二、PyTorch环境搭建

1.设备要求

首先你需要有一台PC,如果有NVIDA的显卡那最好不过。

2.安装Pytorch

前往官网来进行安装代码的获取:


注意:此处的CUDA我不选,一是因为不想加入GPU,二是不加CUDA下载起来真的很快(因为我只是做个小test)

conda install pytorch torchvision torchaudio cpuonly -c pytorch

此处为安装的命令

3.验证PyTorch

在终端处输入命令:

python
import torch
torch.Tensor()

没有问题的话,就可以正常继续下一步了!


二、CIFAR10测试

首先跟着官网的教程,来一个60分钟闪电战,利用CIFAR10数据集做一次训练分类器简单的尝试

1、关于CIFAR10

在本教程中,我们将使用CIFAR10数据集。它具有以下类别:“飞机”,“汽车”,“鸟”,“猫”,“鹿”,“狗”,“青蛙”,“马”,“船”,“卡车”。CIFAR-10中的图像尺寸为3x32x32,即尺寸为32x32像素的3通道彩色图像。

2、训练图像分类器 Training an image classifier

我们将按顺序执行以下步骤:

1、使用以下命令加载和标准化CIFAR10训练和测试数据集 torchvision
2、定义卷积神经网络
3、定义损失函数
4、根据训练数据训练网络
5、在测试数据上测试网络

(1)加载并归一化CIFAR10 Loading and normalizing CIFAR10

使用torchvision,加载CIFAR10。
加载CIFAR10的数据集和训练集

import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#torchvision数据集的输出是[0,1]范围的PILImage图像。我们将它们转换为归一化范围[-1,1]的张量trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
#如果报错[Errno 32] Broken pipe,把两个num_workers=0
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

运行以上代码,就会自己开始加载CIFAR10的数据集。(不过是真的很慢,我都佛了。。)

我们可以看看训练图像有什么东西:
把该代码加入之前的代码中:

import matplotlib.pyplot as plt
import numpy as npdef imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

就可以看到训练图像是什么了

(2)定义卷积神经网络 Define a Convolutional Neural Network

我们定义一个神经网络,其通道数为三

import torch.nn as nn
import torch.nn.functional as F#我们输入的CIFAR10为3*32*32,即3通道32*32像素,四维张量[N, C, H, W] c=3 H=32 W=32(数量,通道,高、宽)class Net(nn.Module):def __init__(self):super(Net, self).__init__()   self.conv1 = nn.Conv2d(3, 6, 5)  self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):#relu是激活函数ReLU,防止梯度消失x = self.pool(F.relu(self.conv1(x)))#conv1 通道由3变为6通道 用了5*5的卷积核,输出图像应为6*28*28#pool 最大值池化,卷积核2*2,步长2,输出图像为6*14*14x = self.pool(F.relu(self.conv2(x)))#conv2 卷积核为5*5,输出图像为16*10*10#pool 最大值池化,卷积核2*2,步长2,输出图像为16*5*5x = x.view(-1, 16 * 5 * 5)#.view()会将Tensor转成特定维数空间,-1:全部元素,分配为16*5*5,将四维张量转换为二维张量之后,才能作为全连接层的输入,此时x.shape应为#矢量化,将二维特征图转化成一维的一个向量x = F.relu(self.fc1(x))#120全连接 把(n,400)变为(n,120)x = F.relu(self.fc2(x))#84全连接,把(n,120)变为(n,84)x = self.fc3(x)#10全连接return xnet = Net()

(3)定义损失函数和优化器 Define a Loss function and optimizer

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

(4)训练网络 Train the network

for epoch in range(2):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

完成后的效果如图

可以通过下列代码来保存训练模型:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
print('Finished Saving')

(5)在测试数据上测试网络 Test the network on the test data

#第一步。让我们显示测试集中的图像
dataiter = iter(testloader)
images, labels = dataiter.next()# 输出图像
imshow(torchvision.utils.make_grid(images))   #此处我出现了错误,就把他注释掉了继续运行
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))net = Net()
net.load_state_dict(torch.load(PATH)) #加载保存的模型# 好的,现在让我们看看神经网络对以上这些示例的看法:
outputs = net(images)#输出是10类的能量。一个类别的能量越高,网络就认为该图像属于特定类别。因此,让我们获得最高能量的指数:
_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]for j in range(4)))#看一下网络在整个数据集上的表现
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))#哪些类的表现良好,哪些类的表现不佳:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

此处为最后的输出结果图:
可以看出训练集正确率有52%(如果不训练,按照10组中1组应该是10%,网络还是学到了)

下列为官方给出的结果,可以看出无论是哪种,ship的正确率都是最高的。

PS:以上就是根据官方教程做的一次小test,接下来要尝试自己制作数据集训练集,在此基础上进行二分类的识别以及精度验证。


三、制作自己的数据集

1、数据收集(通过爬虫)

做数据集需要自己搜集数据,我得想法是通过爬虫来爬取图片数据,由于是二分类的任务,我选择爬取了玫瑰和牡丹的图片,以此作为数据集和训练集。
实现方法参考的是百度图片爬虫
在此不做过多赘述,大家可以去该文章查看。
分别爬取了玫瑰300张和牡丹300张照片,并对其中的数据进行筛选。

2、划分数据集

针对我们所收集到的图片,我们需要把数据划分为数据集和训练集
参考了Pytorch将数据集划分为训练集、验证集和测试集这篇文章,
首先建立了两个文件夹src_data和target_data两个文件夹,将存放两个类别图片的文件夹放入src,然后运行代码,对数据进行划分。
分别得到了训练集、测试集。(我不需要验证集)

3、训练模型+精度验证

对PyTorch进行了修改之后即可得到所需。
其中要注意的是,对于数据的引用请使用ImageFolder

trainset = datasets.ImageFolder(root=r'target_data\train', transform=transform)

最后的结果精度如下

总结

非常感谢各位的阅读,后续数据集的篇章写的简略了一些,有问题的可以提出来,谢谢观看!

PyTorch安装测试训练建自己的数据集相关推荐

  1. mmdetection的安装并训练自己的VOC数据集

    mmdetection的安装并训练自己的VOC数据集 mmdetection的安装与VOC数据集的训练 一. mmdetection的安装 1.使用conda创建虚拟环境 2.安装Cython 3.安 ...

  2. WIN10 +pytorch版yolov3训练自己数据集

    pytorch版yolov3训练自己数据集 目录 1. 环境搭建 2. 数据集构建 3. 训练模型 4. 测试模型 5. 评估模型 6. 可视化 7. 高级进阶-网络结构更改 1. 环境搭建 将git ...

  3. 扫盲:mmdetection安装以及训练自己的数据集

    一.安装 # 创建环境名为mmdet conda create -n mmdet python=3.7 # 激活环境mmdet conda activate mmdet # 安装pytorch1.6 ...

  4. AI 图片截取、ffmpeg使用及安装, anaconda环境,图片标注(labelme),模型训练(yolov5),CUDA+Pytorch安装及版本相关问题

    AI 图片截取(ffmpeg), anaconda环境,图片标注(labelme),模型训练(yolov5),CUDA+Pytorch安装及版本相关问题 一.截取有效图片 录制RTSP视频脚本 #!/ ...

  5. Pytorch移植Deeplabv3训练CityScapes数据集详细步骤

    源代码链接: https://github.com/fregu856/deeplabv3#paperspace 这个源代码相对简单,可以用来仔细看下,学习思想. 环境配置 源代码使用的是pytorch ...

  6. 车道线检测laneatt算法实战CULane Datasets、Tusimple数据集——安装运行训练步骤

    简单记录一下. 1.配置.训练步骤 可以新建一个虚拟环境,专门跑laneatt算法,方便管理. 新建之后切换到该虚拟环境(博主的是叫laneatt)执行: conda instasll pytorch ...

  7. SSD 安装、训练、测试(ubuntu14.04+cuda7.5+openvc2.4.9)

    安装步骤 1.安装git,下载SSD源码包 sudo apt-get install git git clone https://github.com/weiliu89/caffe.git cd ca ...

  8. Pytorch 版YOLOV5训练自己的数据集

    1.环境搭建 https://github.com/ultralytics/yolov5 2.安装需要的软件 pip install -U -r requirements.txt 3.准备数据 在da ...

  9. 使用pytorch版faster-rcnn训练自己数据集

    使用pytorch版faster-rcnn训练自己数据集 引言 faster-rcnn pytorch代码下载 训练自己数据集 接下来工作 参考文献 引言 最近在复现目标检测代码(师兄强烈推荐FPN, ...

最新文章

  1. 自己写的一个测试函数执行效率的单元(test on Delphi 7)
  2. ajax对服务器路径请求
  3. python文件编译_我算是白学Python了,现在才知道原来Python是可以编译的
  4. 零件库管理信息系统设计--part03:管理员登录部分设计
  5. python中argparse模块
  6. 【MIPS汇编】ADDI,ADDIU,ADD,ADDU的区别、有符号无符号的谬误
  7. Step by Step 使用AET 创建Product extension fields
  8. jQuery 常用的方法
  9. mysql关系数据库_关系型数据库MySql简介
  10. Axure RP 9
  11. 【利用FLASH制作交互式课件】
  12. 肖文吉mysql_疯狂软件教育中心肖文吉老师_MYSQL视频教程
  13. 汽车洒水器的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告
  14. 【知乎答案】2018校招,笔试应该怎么准备?|牛客网回答
  15. 新基建安全怎么做?看看这场院士领衔的高峰对话
  16. 3.2.2 nodeMCU固件烧录
  17. C++桌面小精灵:实现像Office助手一样的帮助精灵
  18. 基于Web SCADA平台构建实时数字化产线 - 初篇
  19. mysql中dist_在SQL语句中dist是什么意思
  20. Linux学习笔记:联想拯救者Y7000进BIOS

热门文章

  1. 用html做成的音频播放器,HTML5制作酷炫音频播放器插件图文教程
  2. “知识共享”实例:鲁宾逊微积分轮番投放全国高校三月有余
  3. Microsoft SUS Deployment
  4. 苹果手机上网很慢_别再重启iPhone了!信号不好,这样设置让手机网速如飞
  5. 【java】微信文章抓取
  6. lingo标准模型与编程(附习题、代码)
  7. 还记得这门古老的编程语言么,送你一份perl书单!
  8. 【obs】转载:OBS直播严重延迟和卡顿怎么办?
  9. 二阶常系数非齐次线性微分方程特解的设定规则
  10. 什么是域名劫持?遇到域名劫持要怎么处理