ResNet网络训练与验证(二)
import os
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import model
'''
加载数据集
1.根路径 dir/train dir/val
2.数据类型 type=train val
'''
def get_dataLoader(dir,batch_size,type=None):#训练集if type=="train":#转换为tensortransform=transforms.Compose([transforms.ToTensor()])#制作数据集train_dataset=datasets.ImageFolder(os.path.join(dir,"train\\"),transform=transform)#加载数据集为loadertrain_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)return train_loaderelif type=="val":#转换为tensortransform=transforms.Compose([transforms.ToTensor()])#制作数据集val_dataset=datasets.ImageFolder(os.path.join(str,'val\\'),transform=transform)#加载数据集为loaderval_loader=DataLoader(val_dataset,batch_size=batch_size,shuffle=True)return val_loaderfrom torch import optim
from torch import nn as nn
import torch
from tqdm import tqdmif __name__ == '__main__':#设置超参数epoch_num=100lr=0.001batch_size=64#数据集根目录str=r"E:\data"#首先获取数据集train_loader=get_dataLoader(str,batch_size,"train")val_loader=get_dataLoader(str,batch_size,"val")#调用gpudevice=torch.device("cuda" if torch.cuda.is_available() else "cpu")#调用模型model=model.resnet18().to(device)#设置loss和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=lr)for epoch in range(1,epoch_num+1):#开始加载数据进行批次训练 处理后的数据按照批次加载如模型中for i,(img,label) in enumerate(tqdm(train_loader)):#将数据和标签加入到设备中img,label=img.to(device),label.to(device)#进入训练模式model.train()#梯度归零optimizer.zero_grad()#前向传播output=model(img)#计算lossloss=criterion(output,label)#反向传播loss.backward()#更新梯度optimizer.step()#每十个批次记录一下 acc和lossif i%10==0:correct=0total=0#对output进行处理,返回的值为batch_size行,类别列_,predicted=torch.max(output.data,1)#计算acc,label的格式为批次个标签值[1,1,1,1,1,1,1,1]total+=label.size(0)correct+=(predicted==label).sum()acc=(correct/total)print("[epoch:%d] iter:%d acc:%.3f loss:%.3f"%(epoch,i*batch_size,acc*100,loss.item()))#经历一个epoch使用val验证一下模型效果with torch.no_grad:correct = 0total = 0for img,label in tqdm(val_loader):#模型进入验证模式model.eval()#将图像,标签送入设备中img,label=img.to(device),label.to(device)#将图片送入模型中output=model(img)# 对output进行处理,返回的值为batch_size行,类别列_, predicted = torch.max(output.data, 1)# 计算acc,label的格式为批次个标签值[1,1,1,1,1,1,1,1]total += label.size(0)correct += (predicted == label).sum()acc = (correct / total)print("Val’s acc:%.3f " % acc * 100)
ResNet网络训练与验证(二)相关推荐
- ResNet网络的训练和预测
ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...
- ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练
1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...
- [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码
环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...
- NeurIPS 2022|基于神经微分方程理论可以帮助我们训练更加深层次的ResNet网络
原文链接:https://www.techbeat.net/article-info?id=4204 作者:seven_ 本文的重点研究对象是在视觉领域占据统治地位的残差神经网络(ResNets),R ...
- 二值网络训练--A Empirical Study of Binary Neural Networks' Optimisation
A Empirical Study of Binary Neural Networks' Optimisation ICLR2019 https://github.com/mi-lad/studyin ...
- Tensorflow 2.* 网络训练(二) fit(x, y, batch_size, epochs, verbose, validation_split, initial_epoch... )
在完成数据集合,网络搭建.以及训练编译设置以后,最后就是要开始训练(拟合)网络 tf.keras.Model.fit 如下fit的参数是相对比较多的,且参数间相互关系较为复杂 fit(x=None, ...
- 1. Resnet网络详解
一.ResNet网络介绍 ResNet是2015年有微软实验室提出的,题目: 作者是何凯明等,这四个都是华人. 当年ResNet斩获了当年各种第一名.基本上只要它参加的比赛,全都是第一名. 我们来看一 ...
- 神经网络模型训练简记(二)
神经网络模型训练简记(二) 内容简述 三.机器视觉网络模型分类及简介 3.2目标检测 3.2.1RCNN 3.2.2SPPNet 3.2.3Fast RCNN 3.2.4Faster RCNN 3.2 ...
- bottleneck resnet网络_关于ResNet网络的一点理解(网络结构、building block 及 “bottleneck” building block)...
[时间]2018.10.05 [题目]关于ResNet网络的一点理解(网络结构.building block 及 "bottleneck" building block) 概述 本 ...
最新文章
- html frameset
- java swing setborder_Swing编程边框(Border)的用法总结
- 第十五章 shell正则表达式
- 《You Only Look Once: Unified, Real-Time Object Detection》YOLO一种实时目标检测方法 阅读笔记(未完成版)
- 给Android程序员的一些面试建议,附带学习经验
- java情书_Java情书已写好,就差妹子了!
- mysql 消息队列_MYSQL模拟消息队列(转载) | 学步园
- HTML简单实例加表单的显示效果
- sklearn.metrics.roc_curve
- C语言中常见的内存相关的Bugs
- activiti 启动tomcat乱码_Activiti 流程图片显示乱码问题分析与解决
- 模拟银行排队叫号机 2011.04.18
- Rust:Match语句详解
- 2022紫光展锐数字芯片提前批笔试
- OTU的定义与解读----了解笔记
- MAC Book Pro 使用 libmodbus
- 得到网页的最新更新时间
- python读取grib格式数据
- PostgreSQL 的安装以及在安装过程中遇到的问题及解决方法
- 大饼震荡不变,新平台搭建?