经过实验发现还是封装好的resnet50在鲁棒性以及运行速度和效率上比较好,但是在没有数据增强的情况下测试准确率只能维持在84%左右,因此采用了增强训练集和原始训练集结合的方式,在 50 epochs 中就达到了92%左右。但是由于train acc已经维持在99.9%了,overfitting还是没有很好地解决。后续会继续优化hyper parameter来提高test acc。(后续会更新)

其实最终的结果就看当 train acc -> 100% 的时候 test acc 收敛值

import torchvision as tv
import numpy as np
import torch
import time
import os
from torch import nn, optim
from torchvision.models import resnet50
from torchvision.transforms import transforms

# Restrict training to the first three GPUs.
# NOTE(review): the original spelled this "CUDA_VISIBLE_DEVICE" (no trailing
# "S"); CUDA ignores unknown variables, so the restriction silently did
# nothing. Fixed to the real variable name.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"


# Experiment: ResNet-50 on CIFAR-10.
class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Cutout augmentation (DeVries & Taylor, 2017,
    https://arxiv.org/abs/1708.04552).

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Side length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Apply cutout to a tensor image.

        Args:
            img (Tensor): Tensor image of size (C, H, W).

        Returns:
            Tensor: Image with n_holes patches of length x length cut out.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)
            # Clamp the patch to the image borders (patches near an edge
            # are partially cut, matching the reference implementation).
            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)
            mask[y1:y2, x1:x2] = 0.0

        mask = torch.from_numpy(mask).expand_as(img)
        return img * mask


def load_data_cifar10(batch_size=128, num_workers=2):
    """Build three CIFAR-10 loaders: augmented train, plain train, and test.

    Args:
        batch_size (int): Batch size for every loader.
        num_workers (int): DataLoader worker processes per loader.

    Returns:
        tuple: (augmented train loader, plain train loader, test loader).
    """
    # Per-channel CIFAR-10 statistics, shared by all pipelines: (mean, std).
    normalize = transforms.Normalize(
        (0.491339968, 0.48215827, 0.44653124),
        (0.24703233, 0.24348505, 0.26158768),
    )

    # Heavy augmentation for the first training pass.
    train_transform_1 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=(-80, 80)),
        transforms.ToTensor(),
        normalize,
        Cutout(1, 16),  # must come after ToTensor: it operates on tensors
    ])
    # Plain resize + normalize for the second training pass (and for
    # measuring train accuracy without augmentation noise).
    train_transform_2 = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Two views of the same training split, one augmented and one plain.
    trainset1 = tv.datasets.CIFAR10(
        root='data', train=True, download=False, transform=train_transform_1,
    )
    trainset2 = tv.datasets.CIFAR10(
        root='data', train=True, download=False, transform=train_transform_2,
    )
    testset = tv.datasets.CIFAR10(
        root='data', train=False, download=False, transform=test_transform,
    )

    pin = torch.cuda.is_available()  # pin host memory only when a GPU exists
    trainloader1 = torch.utils.data.DataLoader(
        trainset1, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin,
    )
    trainloader2 = torch.utils.data.DataLoader(
        trainset2, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin,
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin,
    )
    return trainloader1, trainloader2, testloader


def _train_pass(model, loader, loss_fn, optimizer):
    """Run one optimization pass over *loader*.

    Returns:
        tuple: (summed batch loss, number of batches).
    """
    total_loss = 0.0
    batches = 0
    for x, label in loader:
        x, label = x.cuda(), label.cuda()
        logits = model(x)
        l = loss_fn(logits, label)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        total_loss += float(l.item())
        batches += 1
    return total_loss, batches


def _accuracy(model, loader):
    """Return top-1 accuracy of *model* over *loader* (no gradients)."""
    correct = 0.0
    total = 0
    with torch.no_grad():
        for x, label in loader:
            # x: [b, 3, 224, 224] after the Resize transform; label: [b]
            x, label = x.cuda(), label.cuda()
            pred = model(x).argmax(dim=1)  # [b, num_classes] -> [b]
            correct += torch.eq(pred, label).float().sum().item()
            total += x.size(0)
    return correct / total


def main():
    """Train ResNet-50 on CIFAR-10 for 50 epochs and report accuracies."""
    start = time.time()
    batch_size = 128
    cifar_train1, cifar_train2, cifar_test = load_data_cifar10(batch_size=batch_size)

    model = resnet50().cuda()
    # model.load_state_dict(torch.load('_ResNet50.pth'))  # resume from a saved checkpoint
    # NOTE(review): resnet50() defaults to 1000 output classes; CIFAR-10 only
    # needs 10. resnet50(num_classes=10) would be tighter, but changing it now
    # would invalidate previously saved checkpoints — left as-is.
    model = nn.DataParallel(model, device_ids=[0, 1, 2])
    loss = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(50):
        model.train()  # required: enables dropout / batch-norm updates
        # Train on the augmented set, then on the plain set.
        loss_1, n_1 = _train_pass(model, cifar_train1, loss, optimizer)
        loss_2, n_2 = _train_pass(model, cifar_train2, loss, optimizer)

        model.eval()  # required before evaluation
        print("loss:", (loss_1 + loss_2) / (n_1 + n_2))
        # Train accuracy is measured on the un-augmented training loader.
        acc_1 = _accuracy(model, cifar_train2)
        acc_2 = _accuracy(model, cifar_test)
        print(epoch + 1, 'train acc', acc_1, '|', 'test acc:', acc_2)

        # Save only model.module so the checkpoint loads without DataParallel.
        torch.save(model.module.state_dict(), 'resnet50.pth')

    print("The interval is :", time.time() - start)


if __name__ == '__main__':
    main()

结果如下所示(我是用三块板跑的,效率差强人意):

loss: 1.4909098822899791
1 train acc 0.62204 | test acc: 0.6027
loss: 1.015482439866761
2 train acc 0.7517 | test acc: 0.7263
loss: 0.8184814366233318
3 train acc 0.82326 | test acc: 0.7869
loss: 0.6881685950185942
4 train acc 0.85712 | test acc: 0.8082
loss: 0.6008699957633872
5 train acc 0.8961 | test acc: 0.8348
loss: 0.5339095705496076
6 train acc 0.89792 | test acc: 0.8308
loss: 0.475647669447505
7 train acc 0.92746 | test acc: 0.849
loss: 0.42780504078435166
8 train acc 0.95982 | test acc: 0.8782
loss: 0.3862069582095003
9 train acc 0.9669 | test acc: 0.8786
loss: 0.3480941215101296
10 train acc 0.96436 | test acc: 0.8764
loss: 0.32816354770337225
11 train acc 0.97966 | test acc: 0.8849
loss: 0.30012254173988884
12 train acc 0.97432 | test acc: 0.8748
loss: 0.28094098355163777
13 train acc 0.97728 | test acc: 0.8802
loss: 0.2644972345444
14 train acc 0.9829 | test acc: 0.8851
loss: 0.25381022223207117
15 train acc 0.98758 | test acc: 0.8915
loss: 0.23610747831961726
16 train acc 0.98466 | test acc: 0.8881
loss: 0.2285361007898641
17 train acc 0.99266 | test acc: 0.898
loss: 0.2133797939302271
18 train acc 0.9927 | test acc: 0.9007
loss: 0.2053077191711687
19 train acc 0.99042 | test acc: 0.9038
loss: 0.1958171792817838
20 train acc 0.99618 | test acc: 0.9055
loss: 0.18497231054102498
21 train acc 0.99648 | test acc: 0.9098
loss: 0.18142041679822465
22 train acc 0.9949 | test acc: 0.9036
loss: 0.16904177170579116
23 train acc 0.99514 | test acc: 0.9084
loss: 0.16528357028876475
24 train acc 0.99848 | test acc: 0.9161
loss: 0.15970944137910328
25 train acc 0.99132 | test acc: 0.9062
loss: 0.1532783083411653
26 train acc 0.99694 | test acc: 0.9121
loss: 0.14528743405436353
27 train acc 0.9973 | test acc: 0.9155
loss: 0.14273732041051646
28 train acc 0.9989 | test acc: 0.92
loss: 0.13299551776372423
29 train acc 0.99862 | test acc: 0.9194
loss: 0.13098922320296202
30 train acc 0.99852 | test acc: 0.9172
loss: 0.130069030154146
31 train acc 0.99678 | test acc: 0.9152
loss: 0.12149569378277618
32 train acc 0.99404 | test acc: 0.9109
loss: 0.11743721823848764
33 train acc 0.9957 | test acc: 0.9107
loss: 0.11219285747237012
34 train acc 0.99908 | test acc: 0.9251
loss: 0.10887035607129408
35 train acc 0.99838 | test acc: 0.9165
loss: 0.10298154575131607
36 train acc 0.99818 | test acc: 0.918
loss: 0.10380000583286358
37 train acc 0.9991 | test acc: 0.9245
loss: 0.0969829042045687
38 train acc 0.99882 | test acc: 0.9215
loss: 0.0943278603326294
39 train acc 0.99842 | test acc: 0.9204
loss: 0.09098898973615739
40 train acc 0.9986 | test acc: 0.9213
loss: 0.08953166568509514
41 train acc 0.99802 | test acc: 0.9198
loss: 0.08444082959015961
42 train acc 0.99948 | test acc: 0.9248
loss: 0.08660348053458353
43 train acc 0.9993 | test acc: 0.9257
loss: 0.08018309040948728
44 train acc 0.9984 | test acc: 0.9215
loss: 0.07916377012241732
45 train acc 0.99948 | test acc: 0.9275
loss: 0.07556892135066137
46 train acc 0.99916 | test acc: 0.9276
loss: 0.07554686916106672
47 train acc 0.99938 | test acc: 0.9263
loss: 0.07314916058310844
48 train acc 0.9995 | test acc: 0.93
loss: 0.07100697338284725
49 train acc 0.99952 | test acc: 0.9281
loss: 0.06923596535431509
50 train acc 0.99946 | test acc: 0.9285
The interval is : 13397.9380569458

  • 这里借鉴了《Improved Regularization of Convolutional Neural Networks with Cutout》原论文链接如下:https://arxiv.org/abs/1708.04552

ResNet50 on cifar-10 test_acc--->92%(by data augmentation)相关推荐

  1. 10.1 Converting json to data classes

    10.1 Converting json to data classes 处理json数据,是常见的工作,解析和处理json技术含量低,考验的是细心和耐心,原始的办法的就是对着json字符串一个一个的 ...

  2. 深度学习入门——利用卷积神经网络训练CIFAR—10数据集

    CIFAR-10数据集简介 CIFAR-10是由Hinton的学生Alex Krizhevsky和Ilya Sutskever整理的一个用于普适物体的小型数据集.它一共包含10个类别的RGB彩色图片: ...

  3. print (“{0:<10}{1:>5}“.format(word, count))

    print ("{0:<10}{1:>5}".format(word, count)) 这个是format方法bai的格式控制.在duPython二级教程第三章< ...

  4. 基于SVM的思想做CIFAR 10图像分类

    #SVM 回顾一下之前的SVM,找到一个间隔最大的函数,使得正负样本离该函数是最远的,是否最远不是看哪个点离函数最远,而是找到一个离函数最近的点看他是不是和该分割函数离的最近的. 使用large ma ...

  5. Font shape `OMX/cmex/m/n‘ in size <10.53937> not available (Font) size <10.95> substituted.

    Latex在写公式时,报如下错误: Font shape `OMX/cmex/m/n' in size <10.53937> not available (Font) size <1 ...

  6. cifar 10 最高正确率

    http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html 这个网站里有MNIST数据集 ...

  7. 达梦DM8静默安装,No permission to initialize the database under >/u01/dmms/data/JMDM/control for dmdba!

    麻烦各位uu们看下~ [DM版本]:8 [操作系统]:oracle linux 7.4 [问题描述]*:使用图形化安装能安装完成.但使用静默安装,报错No permission to initiali ...

  8. 0day 第10章--10.5节:修改data中的cooki突破GS

    文章目录 实验原理: 实验环境: 实验要求: 源程序: 实验原理: 修改.data中保存的cookie,然后替换掉检查时的cookie,即可绕过对cookie的检查! 实验环境: winxp sp3 ...

  9. 【PyTorch】 99%程序员都不知道, 深度学习还能这样玩 (建议收藏)

    [PyTorch] 99%程序员都不知道, 深度学习还能这样玩 概述 迁移学习 入住 GitHub 项目详解 get_data.py (获取数据) get_model (获取模型) 参数详解 使用说明 ...

最新文章

  1. 五位工程师亲述:AI技术人才如何快速成长?
  2. 全球地区资料json 含中英文 经纬度_含乳饮料行业发展趋势及市场化程度分析(附报告目录)...
  3. Android之ndk编译出现这个错误error: unused variable ‘a‘ [-Werror=unused-variable]
  4. 栈和队列之用2个栈实现一个队列
  5. E:Modular Stability(组合数)
  6. java byte数组与String互转
  7. ELK+filebeat+kafka+zookeeper构建海量日志分析平台
  8. 华为mate20云备份恢复卡住了_注意了!包括华为、荣耀在内的14款老机型开启EMUI11公测了...
  9. Flink典型应用场景
  10. python的源代码下载_官方下载python源码,编译linux版本的python
  11. Matlab 心形函数
  12. python输入球的半径计算球的表面积和体积_球扇形(球心角体)体积,表面积计算公式与在线计算器_三贝计算网_23bei.com...
  13. 线性代数笔记29——正定矩阵和最小值
  14. 2,理论_滑杆_棘轮_间歇运行机构
  15. python计算球体体积_如何在Python中用MonteCarloMethod计算10维球体的体积?
  16. 客户体验决胜2022,低代码是快速取胜之道
  17. ubuntu20.04 nvidia 460显卡安装
  18. Uncaught DOMException: Failed to execute 'atob' on 'Window': The string to be decoded is not correctly encoded
  19. centos下Intel核显应用ffmpeg的qsv插件编解码
  20. C语言输入三个数字判断大小

热门文章

  1. 来自 PDE 的残差神经网络
  2. npm-whoami
  3. neo4j 不识别eneity中的属性
  4. 读书笔记-第五项修炼
  5. Cnetos7系统---文件压缩与解压命令详解。
  6. Kickstarter JSON数据解析,保存csv
  7. 【CSS】CSS 背景设置 ⑦ ( 背景简写 )
  8. 教授专栏59 | 刘佳:备受瞩目的瞬间营销,如何触发效用最大化?
  9. Base封装(一)--我的最简MVP架构
  10. STM32F0系列内部高速时钟的配置方法