图像多标签分类例子

import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from tensorboardX import SummaryWriterimport seaborn as sns
from sklearn.metrics import confusion_matrix'''数据加载'''
#选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#对三种数据集进行不同预处理,对训练数据进行加强
data_transforms = {'train': transforms.Compose([transforms.RandomRotation(30),transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
}#数据目录
data_dir = "/DATA/wanghongzhi/17flowers"#获取两个数据集
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),  #要习惯python这种语法data_transforms[x]) for x in ['train', 'valid']}
traindataset = image_datasets['train']
validdataset = image_datasets['valid']batch_size = 8
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,shuffle=True, num_workers=4) for x in ['train', 'valid']}
print(dataloaders)
traindataloader = dataloaders['train']
validdataloader = dataloaders['valid']dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}'''定义网络结构'''
class Net(nn.Module):def __init__(self,model):super(Net,self).__init__()self.features = model.features# for p in self.parameters():#     p.requires_grad = Falseself.classifier = nn.Sequential(nn.Linear(25088, 4096,bias=True),nn.ReLU(inplace=True),nn.Dropout(p=0.5,inplace=False),nn.Linear(4096, 4096,bias=True),nn.ReLU(inplace=True),nn.Dropout(p=0.5,inplace=False),nn.Linear(4096, 102,bias=True))def forward(self,x):x = self.features(x)x = x.view(x.shape[0], -1)x = self.classifier(x)return xnet = models.resnet50().to(device)net.load_state_dict(torch.load('/home/wanghongzhi/zuoye/resnet50.pth'))'''参数设定'''
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),lr=0.0001,momentum=0.9)'''定义根据loss列表绘制loss曲线函数'''
def hua_loss(loss):l=len(loss)#x=list(range(1,l+1))x=range(1,l+1)# 设置图片大小plt.figure(figsize=(20,8),dpi=80) # figsize设置图片大小,dpi设置清晰度plt.title("Train-Epoch-Loss",fontsize=25)plt.xlabel("Epoch",fontsize=20)plt.ylabel("Loss",fontsize=20)plt.plot(x,loss)x_major_locator=MultipleLocator(2)  #x轴刻度为1的倍数y_major_locator=MultipleLocator(0.15) #y轴刻度为0.01的倍数ax=plt.gca() #ax为两条坐标轴的实例ax.xaxis.set_major_locator(x_major_locator)ax.yaxis.set_major_locator(y_major_locator)#保存#plt.savefig("./t1.png")plt.show()'''先定义验证集检验''' #测试集和验证集代码一模一样
def valid_model(model, criterion):best_acc = 0.0print('-' * 10)running_loss = 0.0running_corrects = 0model = model.to(device)for inputs, labels in validdataloader:inputs = inputs.to(device)labels = labels.to(device)model.eval()with torch.no_grad():outputs = model(inputs)loss = criterion(outputs, labels)print('outputs:',outputs)print('labels:',labels)_, preds = torch.max(outputs, 1)running_loss += loss.item()running_corrects += torch.sum(preds == labels).item()epoch_loss = running_loss / dataset_sizes['valid']print(running_corrects)epoch_acc = running_corrects / dataset_sizes['valid']print('{} Loss: {:.4f} Acc: {:.4f}'.format('valid', epoch_loss, epoch_acc))print('-' * 10)print()#val_loss.append(epoch_loss)'''训练模型'''
def train_model(model, criterion, optimizer, num_epochs=5):#since = time.time()best_acc = 0.0train_loss=[]#val_loss=[]for epoch in range(num_epochs):if (epoch+1)%5==0:       #每五个epoch就用该模型验证一次结果          valid_model(model, criterion)print('-' * 10)print('Epoch {}/{}'.format(epoch+1, num_epochs))running_loss = 0.0running_corrects = 0model = model.to(device)for inputs, labels in traindataloader:inputs = inputs.to(device)labels = labels.to(device)model.train()optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()_, preds = torch.max(outputs, 1)running_loss += loss.item()  #加起来用来计算每个epoch的lossrunning_corrects += torch.sum(preds == labels).item() #item()取出张量中的值,或者(predicted==labels).sum().item()epoch_loss = running_loss / dataset_sizes['train']print(dataset_sizes['train'])  #训练集总数print(running_corrects) #正确预测个数epoch_acc = running_corrects / dataset_sizes['train']best_acc = max(best_acc,epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format('train', epoch_loss, epoch_acc)) print()train_loss.append(epoch_loss)hua_loss(train_loss)print('Best val Acc: {:4f}'.format(best_acc)) return model'''开始训练'''
epochs = 5
model = train_model(net, criterion, optimizer, epochs)

参考了博客:Pytorch实现鲜花分类(102 Category Flower Dataset)

pytorch实现图像分类代码实例相关推荐

  1. 【项目实战课】人人免费可学!基于Pytorch的图像分类简单任务数据增强实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的图像分类简单任务数据增强实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的 ...

  2. 基于pytorch实现图像分类——理解自动求导、计算图、静态图、动态图、pytorch入门

    1. pytorch入门 什么是PYTORCH? 这是一个基于Python的科学计算软件包,针对两组受众: 替代NumPy以使用GPU的功能 提供最大灵活性和速度的深度学习研究平台 1.1 开发环境 ...

  3. NLP: 0基础应用T5模型进行文本翻译代码实例~

    文章目录 前言 一.目标文本是什么? 二.模型调用步骤 1.引入库 2.导入模型,本文使用 t5-base 3.使用分词器对目标文本进行分词 4.对刚刚生成的分词结果进行目标语言的生成工作 5.对生成 ...

  4. 一步步读懂Pytorch Chatbot Tutorial代码(三) - 创建字典

    文章目录 自述 有用的工具 代码出处 目录 代码 Load and trim data 类 class _ _ init _ _ 初始化实例变量 for word in sentence.split( ...

  5. pytorch实现图像分类,训练集准确率很高,测试集准确率总是很低

    在使用pytorch运行图像分类的代码的时候,发现测试集准确率总是只有30%左右, 但是训练集准确率基本可以达到80%以上,那么存在的问题可能是一下几个方面导致的: 1.学习率设置得太高,可以尽量将学 ...

  6. pytorch车牌识别代码

    我不太了解pytorch车牌识别代码,但我可以提供一些关于它的基本信息. PyTorch是一种被广泛用于图像分类.语音识别和自然语言处理的深度学习框架.它可以用来构建.训练和测试复杂的模型,其中包括车 ...

  7. Pytorch CIFAR10图像分类 数据加载与可视化篇

    Pytorch CIFAR10图像分类 数据加载与可视化篇 文章目录 Pytorch CIFAR10图像分类 数据加载与可视化篇 1.数据读取 2. 查看数据(格式,大小,形状) 3. 查看图片 np ...

  8. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  9. php 3 3公派算法代码,PHP常见算法合集代码实例

    许多人都说 算法是程序的核心,一个程序的好于差,关键是这个程序算法的优劣,下面是一些常用的算法和实例,大家可以好好学习下 一.文件夹遍历 function allFile($path = __DIR_ ...

最新文章

  1. boost::test::string_cast相关的测试程序
  2. Mschart图表制作
  3. WPF学习笔记(三)
  4. qt android webview,qt browser 加载一个webview过程
  5. GitHub改为token验证后,如何提交代码?
  6. Q101:真实地模拟一个玻璃酒杯(Wine Glass)(回旋曲面)
  7. sql server 查询工具_分享一款开源的SQL查询优化工具--EverSQL
  8. 【转】心等久了就会死心
  9. ipython 安装_IPYTHON安装.DOC
  10. 大数据Hadoop学习(一)入门
  11. 一、从0开始——黑客学习路线
  12. 如何取得/etiantian文件的权限对应的数字内容,如-rw-r--r-- 为644,要求使用命令取...
  13. qt绘画事件-设置背景图片
  14. Android车辆运动轨迹大数据采集最佳实践
  15. 2023最新车道线综述!近五年文章全面盘点(几何建模/机器学习/深度学习)
  16. 微分中值定理—罗尔中值定理
  17. 实战技法 - 短线操盘 (3)
  18. linux基础09——nl
  19. CLA not signed yet
  20. 【Proteus仿真】步进电机转速数码管显示

热门文章

  1. c语言中如何让诊断代码右移_如何检测和诊断生产中的慢代码
  2. Kubernetes集群上的Apache Ignite和Spring第1部分:Spring Boot应用程序
  3. Java:使用Toxiproxy模拟各种连接问题
  4. 注释嵌套注释_注释梦Night
  5. Java Spring Security示例教程中的2种设置LDAP Active Directory身份验证的方法
  6. permgen_打破PermGen神话
  7. JavaFX真实世界应用程序:欧洲电视网广播联盟
  8. 敏捷中gwt含义_在您的GWT应用程序中添加JSON功能
  9. JDBC教程– ULTIMATE指南(PDF下载)
  10. Katas编写的Java教程:Mars Rover