import os
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import model
'''
加载数据集
1.根路径 dir/train dir/val
2.数据类型 type=train val
'''
def get_dataLoader(dir,batch_size,type=None):#训练集if type=="train":#转换为tensortransform=transforms.Compose([transforms.ToTensor()])#制作数据集train_dataset=datasets.ImageFolder(os.path.join(dir,"train\\"),transform=transform)#加载数据集为loadertrain_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)return train_loaderelif type=="val":#转换为tensortransform=transforms.Compose([transforms.ToTensor()])#制作数据集val_dataset=datasets.ImageFolder(os.path.join(str,'val\\'),transform=transform)#加载数据集为loaderval_loader=DataLoader(val_dataset,batch_size=batch_size,shuffle=True)return val_loaderfrom torch import optim
from torch import nn as nn
import torch
from tqdm import tqdmif __name__ == '__main__':#设置超参数epoch_num=100lr=0.001batch_size=64#数据集根目录str=r"E:\data"#首先获取数据集train_loader=get_dataLoader(str,batch_size,"train")val_loader=get_dataLoader(str,batch_size,"val")#调用gpudevice=torch.device("cuda" if torch.cuda.is_available() else "cpu")#调用模型model=model.resnet18().to(device)#设置loss和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=lr)for epoch in range(1,epoch_num+1):#开始加载数据进行批次训练 处理后的数据按照批次加载如模型中for i,(img,label) in enumerate(tqdm(train_loader)):#将数据和标签加入到设备中img,label=img.to(device),label.to(device)#进入训练模式model.train()#梯度归零optimizer.zero_grad()#前向传播output=model(img)#计算lossloss=criterion(output,label)#反向传播loss.backward()#更新梯度optimizer.step()#每十个批次记录一下 acc和lossif i%10==0:correct=0total=0#对output进行处理,返回的值为batch_size行,类别列_,predicted=torch.max(output.data,1)#计算acc,label的格式为批次个标签值[1,1,1,1,1,1,1,1]total+=label.size(0)correct+=(predicted==label).sum()acc=(correct/total)print("[epoch:%d] iter:%d  acc:%.3f loss:%.3f"%(epoch,i*batch_size,acc*100,loss.item()))#经历一个epoch使用val验证一下模型效果with torch.no_grad:correct = 0total = 0for img,label in tqdm(val_loader):#模型进入验证模式model.eval()#将图像,标签送入设备中img,label=img.to(device),label.to(device)#将图片送入模型中output=model(img)# 对output进行处理,返回的值为batch_size行,类别列_, predicted = torch.max(output.data, 1)# 计算acc,label的格式为批次个标签值[1,1,1,1,1,1,1,1]total += label.size(0)correct += (predicted == label).sum()acc = (correct / total)print("Val’s acc:%.3f " % acc * 100)

ResNet网络训练与验证(二)相关推荐

  1. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  2. ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练

    1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...

  3. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  4. NeurIPS 2022|基于神经微分方程理论可以帮助我们训练更加深层次的ResNet网络

    原文链接:https://www.techbeat.net/article-info?id=4204 作者:seven_ 本文的重点研究对象是在视觉领域占据统治地位的残差神经网络(ResNets),R ...

  5. 二值网络训练--A Empirical Study of Binary Neural Networks' Optimisation

    A Empirical Study of Binary Neural Networks' Optimisation ICLR2019 https://github.com/mi-lad/studyin ...

  6. Tensorflow 2.* 网络训练(二) fit(x, y, batch_size, epochs, verbose, validation_split, initial_epoch... )

    在完成数据集合,网络搭建.以及训练编译设置以后,最后就是要开始训练(拟合)网络 tf.keras.Model.fit 如下fit的参数是相对比较多的,且参数间相互关系较为复杂 fit(x=None, ...

  7. 1. Resnet网络详解

    一.ResNet网络介绍 ResNet是2015年有微软实验室提出的,题目: 作者是何凯明等,这四个都是华人. 当年ResNet斩获了当年各种第一名.基本上只要它参加的比赛,全都是第一名. 我们来看一 ...

  8. 神经网络模型训练简记(二)

    神经网络模型训练简记(二) 内容简述 三.机器视觉网络模型分类及简介 3.2目标检测 3.2.1RCNN 3.2.2SPPNet 3.2.3Fast RCNN 3.2.4Faster RCNN 3.2 ...

  9. bottleneck resnet网络_关于ResNet网络的一点理解(网络结构、building block 及 “bottleneck” building block)...

    [时间]2018.10.05 [题目]关于ResNet网络的一点理解(网络结构.building block 及 "bottleneck" building block) 概述 本 ...

最新文章

  1. html frameset
  2. java swing setborder_Swing编程边框(Border)的用法总结
  3. 第十五章 shell正则表达式
  4. 《You Only Look Once: Unified, Real-Time Object Detection》YOLO一种实时目标检测方法 阅读笔记(未完成版)
  5. 给Android程序员的一些面试建议,附带学习经验
  6. java情书_Java情书已写好,就差妹子了!
  7. mysql 消息队列_MYSQL模拟消息队列(转载) | 学步园
  8. HTML简单实例加表单的显示效果
  9. sklearn.metrics.roc_curve
  10. C语言中常见的内存相关的Bugs
  11. activiti 启动tomcat乱码_Activiti 流程图片显示乱码问题分析与解决
  12. 模拟银行排队叫号机 2011.04.18
  13. Rust:Match语句详解
  14. 2022紫光展锐数字芯片提前批笔试
  15. OTU的定义与解读----了解笔记
  16. MAC Book Pro 使用 libmodbus
  17. 得到网页的最新更新时间
  18. python读取grib格式数据
  19. PostgreSQL 的安装以及在安装过程中遇到的问题及解决方法
  20. 大饼震荡不变,新平台搭建?

热门文章

  1. Kernel 4.9的BBR拥塞控制算法与锐速
  2. Exchange Server 2013 DAG高可用部署(一)-前期准备
  3. 知乎Markdown文件中的公式问题记录
  4. html 文字段落编辑,美化html段落文本 Ⅰ
  5. .netcore-线程池饿死问题分析(CPU空闲,并发量大时请求超时)
  6. 【转】四年记——身在中小企业
  7. servlce和tomcat
  8. Go开发的两个小应用
  9. 【人脸识别】解析MS-Celeb-1M人脸数据集及FaceImageCroppedWithAlignment.tsv文件提取
  10. 向量空间模型简介及算法