目录

  • cifar-10介绍
  • 构建模型
  • 优化

cifar-10介绍

任务是对10个类别的对象进行分类,使用cifar-10数据集。cifar-10数据集共有60000张彩色图像,大小为32 * 32 * 3,一共有10个类别,每个类别6000张。其中50000张用于训练,10000用于测试。

构建模型

首先导入必要的包

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import collections
# 使用gpu训练和测试
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'

数据读入和加载
下载并使用PyTorch提供的内置数据集

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 将数据转换为tensor格式并归一化到(-1, 1)区间
train_data = datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./', train=False, download=True, transform=transform)
# 读取数据集
train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=4, shuffle=False, num_workers=4)
# 定义DataLoader加载数据,每批次读入数据为batch_size
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 对应类别labels

可视化操作

def imshow(img):img = img / 2 + 0.5npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s'%classes[labels[j]] for j in range(4)))

模型设计
由于任务较为简单,我们搭建一个CNN,模型构建完成后,将模型放在GPU上用于训练。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.relu = nn.ReLU()def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu((self.conv2(x))))x = x.view(-1, 16*5*5)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
net = net.cuda()

设定损失函数和优化器

  • 使用torch.nn模块自带的交叉熵损失函数。
  • 使用Adam优化器,先设置学习率为1e-3,惩罚系数1e-8。
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-7)

训练

def train(epoch):# 设置模型状态为trainnet.train()train_loss = 0num_pred = 0for image, label in train_loader:image, label = image.cuda(), label.cuda()optimizer.zero_grad()out_put = net(image)loss = criterion(out_put, label)loss.backward()optimizer.step()train_loss += loss * len(image)preds = torch.argmax(out_put, 1)num_pred += np.sum(preds.cpu().numpy()==label.cpu().numpy())train_loss /= len(train_loader.dataset)accuracy = num_pred / len(train_loader.dataset)print('Epoch:{}\tTraining Loss:{:.6f}\tTraining Accuracy:{:.6f}'.format(epoch, train_loss, accuracy))

测试

def val(epoch):print('crrent learning rate: ', optimizer.state_dict()["param_groups"][0]["lr"])net.eval()class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))val_loss = 0gt_labels = []pred_labels = []with torch.no_grad():for data, label in test_loader:data, label = data.cuda(), label.cuda()output = net(data)preds = torch.argmax(output, 1)c = (preds == label).squeeze()for i in range(4):label1 = label[i]class_correct[label1] += c[i].item()class_total[label1] += 1gt_labels.append(label.cpu().data.numpy())pred_labels.append(preds.cpu().data.numpy())loss = criterion(output, label)val_loss += loss.item()*data.size(0)val_loss = val_loss / len(test_loader.dataset)gt_labels, pred_labels = np.concatenate(gt_labels), np.concatenate(pred_labels)acc = np.sum(gt_labels == pred_labels) / len(pred_labels)print(F'Epoch:{epoch} \tValidation Loss: {val_loss:6f} , Accuracy: {acc:6f}')for i in range(10):print("Accuracy of %5s : %2d %%" %(classes[i], 100*class_correct[i] / class_total[i]))

模型搭建完毕,开始训练和测试,先进行10个轮次。

for epoch in range(10):train(epoch+1)val(epoch+1)

结果
Epoch:10 Training Loss:0.855442 Accuracy:0.694800
crrent learning rate: 0.001
Epoch:10 Validation Loss:1.158405 Accuracy:0.618900
发现训练和测试精确率都不满意。

优化

加深模型
由于测试精确率较低,我们考虑采用加深模型来提高精确率,将原本的三层卷积加深为四层。
加宽模型

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.conv2 = nn.Conv2d(16, 80, 4)self.conv3 = nn.Conv2d(80, 400, 3)self.conv4 = nn.Conv2d(400, 800, 2, padding=2)self.fc1 = nn.Linear(3200, 400)self.fc2 = nn.Linear(400, 120)self.fc3 = nn.Linear(120, 84)self.fc4 = nn.Linear(84, 10)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = self.conv3(x)x = self.relu(x)x = self.pool(x)x = self.conv4(x)x = self.relu(x)x = self.pool(x)x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)x = self.relu(x)x = self.fc4(x)return xdef num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
net = net.cuda()

经过加深和加宽模型后结果为:
Epoch:10 Training Loss:0.447619 Accuracy:0.849860
crrent learning rate: 0.001
Epoch:10 Validation Loss:1.045875 Accuracy:0.713100
逐层归一化
如果不进行归一化,那么由于特征向量中不同特征的取值相差较大,会导致目标函数变“扁”。这样在进行梯度下降的时候,梯度的方向就会偏离最小值的方向,走很多弯路,即训练时间过长。
过拟合问题
对比训练精确度和测试精确率发现,训练精确率远大于测试精确率,即考虑出现过拟合问题。
增加dropout
为了解决过拟合问题,增加dropout层,但会增加训练时间。
正则化力度
将正则化惩罚系数1e-8改为1e-5。

调整学习率

def adjust_learning_rate(optimizer, epoch):lr = 1e-3 * (0.1) ** (epoch // 4)for param_group in optimizer.param_groups:param_group['lr'] = lr
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 30, 3)self.conv2 = nn.Conv2d(30, 300, 4)self.conv3 = nn.Conv2d(300, 600, 3)self.conv4 = nn.Conv2d(600, 1200, 2, padding=2)self.fc1 = nn.Linear(4800, 10)self.fc2 = nn.Linear(400, 10)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2)self.dropout = nn.Dropout(0.1)self.batchnorm1 = nn.BatchNorm2d(30)self.batchnorm2 = nn.BatchNorm2d(300)self.batchnorm3 = nn.BatchNorm2d(600)self.batchnorm4 = nn.BatchNorm2d(1200)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm1(x)x = self.dropout(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm2(x)x = self.dropout(x)x = self.conv3(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm3(x)x = self.dropout(x)x = self.conv4(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm4(x)x = self.dropout(x)x = x.view(x.size()[0], -1)x = self.fc1(x)# x = self.relu(x)# x = self.fc2(x)return xdef num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
net = net.cuda()

最终结果
经过以上诸多处理手段,得到最终结果为:
Epoch:20 Training Loss:0.172814 Accuracy:0.941460
crrent learning rate: 1.0000000000000002e-07
Epoch:20 Validation Loss: 0.641045 Accuracy: 0.801200

Cifar-10训练记录相关推荐

  1. ACM组队训练记录(Grooming)

    本文主要记录了本菜鸡(Chen.Jr)所在的队伍的2018年训练记录以及部分题解,以此来鼓励本蒟蒻奋发图强. Name Solved A B C D E F G H I J K L M 2018-20 ...

  2. 训练记录番外篇(2):2022 ICPC Gran Premio de Mexico 2da Fecha

    2022 ICPC Gran Premio de Mexico 2da Fecha 2022.10.3 之前训得ak场,个人认为很edu. (顺便一提,可能这个训练记录番外系列的比赛都非常edu,十分 ...

  3. ubuntu 从刷机到yolov5环境搭建训练记录

    ubuntu 从刷机到yolov5环境搭建训练记录 这两天需要一个模型检测一些摄像头内容,使用yolov5训练了一个模型,记录一下. 1. 刷机 具体步骤不描述,网上很多. 刷机时遇到一个问题,原有系 ...

  4. 2019年7月训练记录(更新ing)

    前言 本月上半月训练记录可详见:2019年暑假绍兴集训. \(Jul\ 15th\) 早上到机房先做了一道一直想做的板子题:[洛谷4781][模板]拉格朗日插值,发现拉格朗日插值也并没有想象中那么难. ...

  5. yoloV5-6.2分类训练记录

    加我微信拉你进群交流:wu331376411 一 背景介绍 yoloV6,V7相继跟新,没有想到用的最熟悉的V5又双叒叕更新了,今天我就来给大家准备分享一下yoloV5-6.2的分类训练. 二 模型下 ...

  6. 第五章:Tensorflow 2.0 利用十三层卷积神经网络实现cifar 100训练(理论+实战)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...

  7. 台州学院maximum cow训练记录

    前队名太过晦气,故启用最大牛 我们的组队大概就是18年初,组队阵容是17级生詹志龙.陶源和16级的黄睿博. 三人大学前均无接触过此类竞赛,队伍十分年轻.我可能是我们队最菜的,我只是知道的内容最多,靠我 ...

  8. 各个数据库取前10行记录

    SQL查询前10条的方法为: 1.select top X *  from table_name --查询前X条记录,可以改成需要的数字,比如前10条. 2.select top X *  from  ...

  9. mysql查询前10条记录

    select * from no_primary_key order by id limit 10; # 显示从id=1到id=10的前10条记录: select * from no_primary_ ...

  10. sql 取表的前10条记录,任意中间几行的记录

    取表的前10条记录 with a as(select *,row_number()over(order by department)rn from _SucceedStaff ) select * f ...

最新文章

  1. go支持对函数返回值命名,可以解决函数返回值的顺序书写问题
  2. 判断字符串是否为回文(C语言 顺序栈)
  3. [Vue.js] 路由 -- 基于vue-router的案例--后台管理
  4. Vue中虚拟DOM的理解
  5. 计算机在矿山企业中的应用,计算机在矿山工业中的应用与发展
  6. raster | 多图层栅格对象的一些处理方法
  7. Linux 服务器为什么被黑
  8. beforeunload中阻止提示关闭_React 系统中,在离开编辑页面前做提示
  9. AAC AMR WAV MP3 采样率
  10. python代码混淆
  11. QQ空间自动删除说说的js脚本(亲测有效)
  12. mtk协议与qc协议_通用充电器快充协议QC2.0,QC3.0,MTK PE,PE+,充电识别
  13. MySQL: GTID简介,gtid_executed和gtid_purged概念
  14. 发现一款好用的 java web报表工具
  15. php气泡效果,ps简单制作漂亮的人物气泡效果
  16. android rom打包解包工具,Android ROM包定制(解包,增删模块,打包)
  17. 卸载的软件电脑重启后又出现了,怎么办?
  18. IE低版本提示下载新的浏览器js--IEOutTips.zip
  19. 用Windows自带工具给U盘4k对齐
  20. spark实战问题(一):is running beyond physical memory limits. Current usage: xx GB of xx GB physical memory

热门文章

  1. eclipse java access数据库连接_eclipse如何连接access数据库实现代码
  2. 安装金山WPS2013造成的HTML5 file.type值异常
  3. 向字典中相应的键增加值
  4. 职场人士开发新技能 我靠财智金解决费用
  5. tf.dynamic_stitch
  6. 星锐恒通电钢琴电子琴教学控制管理系统概述
  7. postgressql企业级数据库edb学习(一)
  8. 解除百度云浏览器端对下载大文件的限制
  9. GDUT2020新生赛——解题报告
  10. [Golang]力扣Leetcode - 374. 猜数字大小(二分查找)