测试ResNet预训练模型在ImageNet验证集上的准确率结果

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as Fimport os
import shutil
import argparse
import numpy as npimport torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as modelsfrom bisect import bisect_right
import time
import mathparser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', default='/dev/shm/ImageNet/', type=str, help='trainset directory')
parser.add_argument('--dataset', default='ImageNet', type=str, help='Dataset name')
parser.add_argument('--arch', default='resnet50', type=str, help='network architecture')
parser.add_argument('--batch-size', type=int, default=256, help='batch size')
parser.add_argument('--num-workers', type=int, default=8, help='number workers')
parser.add_argument('--pretrained', default='./resnet50-19c8e357.pth', type=str, help='pretrained weights')
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--manual_seed', type=int, default=0)# global hyperparameter set
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpunp.random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
torch.cuda.manual_seed_all(args.manual_seed)num_classes = 1000test_set = datasets.ImageFolder(os.path.join(args.data, 'val'),transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
]))testloader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False,num_workers=args.num_workers, pin_memory=True)
# --------------------------------------------------------------------------------------------# Model
print('==> Building model..')net = models.resnet18(pretrained=True).cuda()
cudnn.benchmark = Truedef correct_num(output, target, topk=(1,)):"""Computes the precision@k for the specified values of k"""maxk = max(topk)batch_size = target.size(0)_, pred = output.topk(maxk, 1, True, True)correct = pred.eq(target.view(-1, 1).expand_as(pred))res = []for k in topk:correct_k = correct[:, :k].float().sum()res.append(correct_k.item())return resnet.eval()
correct1 = 0
correct5 = 0
total = 0
sum_time = time.time()
with torch.no_grad():batch_start_time = time.time()for batch_idx, (inputs, target) in enumerate(testloader):inputs, target = inputs.cuda(), target.cuda()logits = net(inputs)print('batch_idx:{}/{}, Duration:{:.2f}'.format(batch_idx, len(testloader), time.time()-batch_start_time))batch_start_time = time.time()prec1, prec5 = correct_num(logits, target, topk=(1, 5))correct1 += prec1correct5 += prec5total += target.size(0)acc1 = round(correct1/total, 4)acc5 = round(correct5/total, 4)print('Test accuracy_1:{:.4f}\n''Test accuracy_5:{:.4f}\n'.format(acc1, acc5))print('avg times:', (time.time()-sum_time)/50000)

测试ResNet在ImageNet验证集上的准确率相关推荐

  1. paddle静态图训练,训练集和测试集效果都有很好,但验证集上效果很差

    在paddle静态图训练中,训练集和测试集效果都有很好,但验证集上效果很差 在paddle的训练中,如果使用这样的方式进行训练 main_program = fluid.default_main_pr ...

  2. 为什么神经网络模型在测试集上的准确率高于训练集上的准确率?

    为什么神经网络模型在测试集上的准确率高于训练集上的准确率? 种花家的奋斗兔 2020-03-21 17:28:37  5847  已收藏 11 分类专栏: Deep Learning 文章标签: dr ...

  3. 深度神经网络训练过程中为什么验证集上波动很大_图神经网络的新基准

    作者 | 李光明 编辑 | 贾 伟 编者注:本文解读论文与我们曾发文章<Bengio 团队力作:GNN 对比基准横空出世,图神经网络的「ImageNet」来了>所解读论文,为同一篇,不同作 ...

  4. SourceChangeWarning:验证集上准确率很高,但是测试集上很低

    https://yan624.github.io/posts/cb7d01da.html

  5. ImageNet验证集6%的标签都是错的,MIT:十大常用数据集没那么靠谱

    作者|张倩.小舟 来源|机器之心 把老虎标成猴子,把青蛙标成猫,把码头标成纸巾--MIT.Amazon 的一项研究表明,ImageNet 等十个主流机器学习数据集的测试集平均错误率高达 3.4%. 我 ...

  6. yolov4实现口罩佩戴检测,在验证集上做到了0.954的mAP

    向AI转型的程序员都关注了这个号

  7. 【深度学习】训练集、验证集、测试集

    训练集:使用训练集来对某个网络模型进行训练,使用梯度下降法来更新普通参数,如权重和偏置. 验证集:使用验证集来对训练集训练的模型调节他的超参数(如:网络层数.网络节点数.迭代次数.学习率.正则化参数) ...

  8. [深度学习-TF2实践]应用Tensorflow2.x训练ResNet,SeNet和Inception模型在cifar10,测试集上准确率88.6%

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  9. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

最新文章

  1. java取字符串中不相同的字母_java 判断两个字符串是否为相同字母异序词 --- 记录...
  2. c语言javapython哪个好-C#、C++、Java、Python 选择哪个好?
  3. android自定义抽奖,Android自定义view制作抽奖转盘
  4. wxWidgets:wxControl类用法
  5. Win11开始菜单没反应怎么办 Win11开始菜单点了没反应解决方法
  6. Ext.grid.Panel一定要有renderTo或autoRender属性,不然页面为空
  7. 【ACM】括号配对问题 - 栈
  8. 用编译安装搭建自己的http服务器
  9. 设置php语言,PHP语言之php-fpm 基本设置与启动
  10. Java基础笔记-String类
  11. ARM64移动处理器解惑
  12. xposed模块编写教程_太极xposed模块使用教程
  13. [安全论文翻译] Analysis of Location Data Leakage in the Internet Traffic of Android-based Mobile
  14. 海明码java编程,海明码校验程序设计
  15. Android usb otg通讯总结 HiD通讯直接来取吧
  16. 解决论文写作排版中,两端对齐导致文字间距被word补过大的问题
  17. 在ASP.NET中轻松实现加密
  18. Windows下卸载CUDA10.2
  19. Python实现图像垂直翻转
  20. 解决rotatedRectangleIntersection计算目标检测旋转框IOU不准确问题C++、opencv

热门文章

  1. Lazada打造爆款秘籍
  2. 中关村“黑马程序员”训练营
  3. 阿里云服务器安装WordPress,搭建自己的博客网站
  4. java自动化测试语言高级之多线程编程
  5. 琳幼儿园同学-育扬牧童星辰✨ 中一班
  6. scada与MySQL连接_SCADA系统数据库连接功能设计及应用
  7. 利用 AI 跟踪和优化视频质量
  8. Python mariadb
  9. 汇编中的通用基础寄存器ax,bx,cx,dx等的含意及作用解释
  10. 青山清水静心情 下联是...