Cifar-10训练记录
目录
- cifar-10介绍
- 构建模型
- 优化
cifar-10介绍
任务是对10个类别的对象进行分类,使用cifar-10数据集。cifar-10数据集共有60000张彩色图像,大小为32 * 32 * 3,一共有10个类别,每个类别6000张。其中50000张用于训练,10000用于测试。
构建模型
首先导入必要的包
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import collections
# 使用gpu训练和测试
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
数据读入和加载
下载并使用PyTorch提供的内置数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 将数据转换为tensor格式并归一化到(-1, 1)区间
train_data = datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./', train=False, download=True, transform=transform)
# 读取数据集
train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=4, shuffle=False, num_workers=4)
# 定义DataLoader加载数据,每批次读入数据为batch_size
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 对应类别labels
可视化操作
def imshow(img):img = img / 2 + 0.5npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s'%classes[labels[j]] for j in range(4)))
模型设计
由于任务较为简单,我们搭建一个CNN,模型构建完成后,将模型放在GPU上用于训练。
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.relu = nn.ReLU()def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu((self.conv2(x))))x = x.view(-1, 16*5*5)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
net = net.cuda()
设定损失函数和优化器
- 使用torch.nn模块自带的交叉熵损失函数。
- 使用Adam优化器,先设置学习率为1e-3,惩罚系数1e-8。
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-7)
训练
def train(epoch):# 设置模型状态为trainnet.train()train_loss = 0num_pred = 0for image, label in train_loader:image, label = image.cuda(), label.cuda()optimizer.zero_grad()out_put = net(image)loss = criterion(out_put, label)loss.backward()optimizer.step()train_loss += loss * len(image)preds = torch.argmax(out_put, 1)num_pred += np.sum(preds.cpu().numpy()==label.cpu().numpy())train_loss /= len(train_loader.dataset)accuracy = num_pred / len(train_loader.dataset)print('Epoch:{}\tTraining Loss:{:.6f}\tTraining Accuracy:{:.6f}'.format(epoch, train_loss, accuracy))
测试
def val(epoch):print('crrent learning rate: ', optimizer.state_dict()["param_groups"][0]["lr"])net.eval()class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))val_loss = 0gt_labels = []pred_labels = []with torch.no_grad():for data, label in test_loader:data, label = data.cuda(), label.cuda()output = net(data)preds = torch.argmax(output, 1)c = (preds == label).squeeze()for i in range(4):label1 = label[i]class_correct[label1] += c[i].item()class_total[label1] += 1gt_labels.append(label.cpu().data.numpy())pred_labels.append(preds.cpu().data.numpy())loss = criterion(output, label)val_loss += loss.item()*data.size(0)val_loss = val_loss / len(test_loader.dataset)gt_labels, pred_labels = np.concatenate(gt_labels), np.concatenate(pred_labels)acc = np.sum(gt_labels == pred_labels) / len(pred_labels)print(F'Epoch:{epoch} \tValidation Loss: {val_loss:6f} , Accuracy: {acc:6f}')for i in range(10):print("Accuracy of %5s : %2d %%" %(classes[i], 100*class_correct[i] / class_total[i]))
模型搭建完毕,开始训练和测试,先进行10个轮次。
for epoch in range(10):train(epoch+1)val(epoch+1)
结果
Epoch:10 Training Loss:0.855442 Accuracy:0.694800
crrent learning rate: 0.001
Epoch:10 Validation Loss:1.158405 Accuracy:0.618900
发现训练和测试精确率都不满意。
优化
加深模型
由于测试精确率较低,我们考虑采用加深模型来提高精确率,将原本的三层卷积加深为四层。
加宽模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3)self.conv2 = nn.Conv2d(16, 80, 4)self.conv3 = nn.Conv2d(80, 400, 3)self.conv4 = nn.Conv2d(400, 800, 2, padding=2)self.fc1 = nn.Linear(3200, 400)self.fc2 = nn.Linear(400, 120)self.fc3 = nn.Linear(120, 84)self.fc4 = nn.Linear(84, 10)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = self.conv3(x)x = self.relu(x)x = self.pool(x)x = self.conv4(x)x = self.relu(x)x = self.pool(x)x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)x = self.relu(x)x = self.fc4(x)return xdef num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
net = net.cuda()
经过加深和加宽模型后结果为:
Epoch:10 Training Loss:0.447619 Accuracy:0.849860
crrent learning rate: 0.001
Epoch:10 Validation Loss:1.045875 Accuracy:0.713100
逐层归一化
如果不进行归一化,那么由于特征向量中不同特征的取值相差较大,会导致目标函数变“扁”。这样在进行梯度下降的时候,梯度的方向就会偏离最小值的方向,走很多弯路,即训练时间过长。
过拟合问题
对比训练精确度和测试精确率发现,训练精确率远大于测试精确率,即考虑出现过拟合问题。
增加dropout
为了解决过拟合问题,增加dropout层,但会增加训练时间。
正则化力度
将正则化惩罚系数1e-8改为1e-5。
调整学习率
def adjust_learning_rate(optimizer, epoch):lr = 1e-3 * (0.1) ** (epoch // 4)for param_group in optimizer.param_groups:param_group['lr'] = lr
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 30, 3)self.conv2 = nn.Conv2d(30, 300, 4)self.conv3 = nn.Conv2d(300, 600, 3)self.conv4 = nn.Conv2d(600, 1200, 2, padding=2)self.fc1 = nn.Linear(4800, 10)self.fc2 = nn.Linear(400, 10)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2)self.dropout = nn.Dropout(0.1)self.batchnorm1 = nn.BatchNorm2d(30)self.batchnorm2 = nn.BatchNorm2d(300)self.batchnorm3 = nn.BatchNorm2d(600)self.batchnorm4 = nn.BatchNorm2d(1200)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm1(x)x = self.dropout(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm2(x)x = self.dropout(x)x = self.conv3(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm3(x)x = self.dropout(x)x = self.conv4(x)x = self.relu(x)x = self.pool(x)x = self.batchnorm4(x)x = self.dropout(x)x = x.view(x.size()[0], -1)x = self.fc1(x)# x = self.relu(x)# x = self.fc2(x)return xdef num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
net = net.cuda()
最终结果
经过以上诸多处理手段,得到最终结果为:
Epoch:20 Training Loss:0.172814 Accuracy:0.941460
crrent learning rate: 1.0000000000000002e-07
Epoch:20 Validation Loss: 0.641045 Accuracy: 0.801200
Cifar-10训练记录相关推荐
- ACM组队训练记录(Grooming)
本文主要记录了本菜鸡(Chen.Jr)所在的队伍的2018年训练记录以及部分题解,以此来鼓励本蒟蒻奋发图强. Name Solved A B C D E F G H I J K L M 2018-20 ...
- 训练记录番外篇(2):2022 ICPC Gran Premio de Mexico 2da Fecha
2022 ICPC Gran Premio de Mexico 2da Fecha 2022.10.3 之前训得ak场,个人认为很edu. (顺便一提,可能这个训练记录番外系列的比赛都非常edu,十分 ...
- ubuntu 从刷机到yolov5环境搭建训练记录
ubuntu 从刷机到yolov5环境搭建训练记录 这两天需要一个模型检测一些摄像头内容,使用yolov5训练了一个模型,记录一下. 1. 刷机 具体步骤不描述,网上很多. 刷机时遇到一个问题,原有系 ...
- 2019年7月训练记录(更新ing)
前言 本月上半月训练记录可详见:2019年暑假绍兴集训. \(Jul\ 15th\) 早上到机房先做了一道一直想做的板子题:[洛谷4781][模板]拉格朗日插值,发现拉格朗日插值也并没有想象中那么难. ...
- yoloV5-6.2分类训练记录
加我微信拉你进群交流:wu331376411 一 背景介绍 yoloV6,V7相继跟新,没有想到用的最熟悉的V5又双叒叕更新了,今天我就来给大家准备分享一下yoloV5-6.2的分类训练. 二 模型下 ...
- 第五章:Tensorflow 2.0 利用十三层卷积神经网络实现cifar 100训练(理论+实战)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...
- 台州学院maximum cow训练记录
前队名太过晦气,故启用最大牛 我们的组队大概就是18年初,组队阵容是17级生詹志龙.陶源和16级的黄睿博. 三人大学前均无接触过此类竞赛,队伍十分年轻.我可能是我们队最菜的,我只是知道的内容最多,靠我 ...
- 各个数据库取前10行记录
SQL查询前10条的方法为: 1.select top X * from table_name --查询前X条记录,可以改成需要的数字,比如前10条. 2.select top X * from ...
- mysql查询前10条记录
select * from no_primary_key order by id limit 10; # 显示从id=1到id=10的前10条记录: select * from no_primary_ ...
- sql 取表的前10条记录,任意中间几行的记录
取表的前10条记录 with a as(select *,row_number()over(order by department)rn from _SucceedStaff ) select * f ...
最新文章
- go支持对函数返回值命名,可以解决函数返回值的顺序书写问题
- 判断字符串是否为回文(C语言 顺序栈)
- [Vue.js] 路由 -- 基于vue-router的案例--后台管理
- Vue中虚拟DOM的理解
- 计算机在矿山企业中的应用,计算机在矿山工业中的应用与发展
- raster | 多图层栅格对象的一些处理方法
- Linux 服务器为什么被黑
- beforeunload中阻止提示关闭_React 系统中,在离开编辑页面前做提示
- AAC AMR WAV MP3 采样率
- python代码混淆
- QQ空间自动删除说说的js脚本(亲测有效)
- mtk协议与qc协议_通用充电器快充协议QC2.0,QC3.0,MTK PE,PE+,充电识别
- MySQL: GTID简介,gtid_executed和gtid_purged概念
- 发现一款好用的 java web报表工具
- php气泡效果,ps简单制作漂亮的人物气泡效果
- android rom打包解包工具,Android ROM包定制(解包,增删模块,打包)
- 卸载的软件电脑重启后又出现了,怎么办?
- IE低版本提示下载新的浏览器js--IEOutTips.zip
- 用Windows自带工具给U盘4k对齐
- spark实战问题(一):is running beyond physical memory limits. Current usage: xx GB of xx GB physical memory
热门文章
- eclipse java access数据库连接_eclipse如何连接access数据库实现代码
- 安装金山WPS2013造成的HTML5 file.type值异常
- 向字典中相应的键增加值
- 职场人士开发新技能 我靠财智金解决费用
- tf.dynamic_stitch
- 星锐恒通电钢琴电子琴教学控制管理系统概述
- postgressql企业级数据库edb学习(一)
- 解除百度云浏览器端对下载大文件的限制
- GDUT2020新生赛——解题报告
- [Golang]力扣Leetcode - 374. 猜数字大小(二分查找)