pytorch 车型分类代码 (1)

  • 可以跑通,代码简洁
  • 学习自用
  • 持续更新

1 数据

1.1 文件配置

数据需要按照下面的方式组织,在数据文件夹中需要有train和val文件夹,然后每个文件夹内中,每个类别放在一个文件夹即可,可以直接利用pytorch自带的datasets.ImageFolder函数加载数据,image有点像imagenet 分类配置。* 这种凡方式也有弊端,一般业务场景数据是通过一个image_list.txt 文件加载,其中包含图片地址和label,调整类别映射的数字不方便*

1.2 数据增强

- 使用torch自带的transform函数,过程如下
- 随机裁剪
- 水平翻转
- 归一化(像素值编程0到1之间,并且转成torch张量)
- 标准化(减去均值,除以方差)```python
train_transforms = transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # 随机裁剪到256*256transforms.RandomRotation(degrees=15),  # 随机旋转transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.CenterCrop(size=224),  # 中心裁剪到224*224transforms.ToTensor(),  # 转化成张量,#归一化transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])  # 标准化])
```

1.3数据加载

1.3.1 使用ImageFolder加载数据

 train_datasets = datasets.ImageFolder(TRAIN_DATASET_DIR, transform=train_transforms)train_dataloader = torch.utils.data.DataLoader(train_datasets,batch_size=TRAIN_BATCH_SIZE,shuffle=True,**kwargs)
  • 查看数据类别映射

    • 通过 datasets.ImageFolder 可以查看类别映射 -
   print(train_datasets.class_to_idx)>> {'SUV': 0, 'bus': 1, 'family sedan': 2, 'fire engine': 3, 'heavy truck': 4, 'jeep': 5, 'minibus': 6, 'racing car': 7, 'taxi': 8, 'truck': 9}print(train_datasets.classes)>>['SUV', 'bus', 'family sedan', 'fire engine', 'heavy truck', 'jeep', 'minibus', 'racing car', 'taxi', 'truck']

1.3.2 自己定义数据结构

class MY_DATASET(Dataset):def __init__(self, train, train_data_list_txt, val_data_list_txt, transform,target_transform):self.train = trainself.train_data_list_txt = train_data_list_txtself.val_data_list_txt = val_data_list_txtself.transform = transformself.target_transform = target_transformif self.train:self.data_file = train_data_list_txtelse:self.data_file = val_data_list_txtself.data_list = self.read_data_list(self.data_file)def read_data_list(self, data_path):print("data_path", data_path)assert os.path.exists(data_path)with open(data_path, "r") as fp:print("data_path", data_path)data_list = fp.readlines()return data_listdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is index of the target class."""img_path, target = self.data_list[index].strip().split("  ")target = int(target)# doing this so that it is consistent with all other datasets# to return a PIL Imageif not os.path.exists(img_path):print("{} is not exists!!!!!!!!!!!".format(img_path))raiseimg = cv2.imread(img_path)img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))if self.transform is not None:img = self.transform(img)if self.target_transform is not None:self.target = self.target_transform(target)return img, targetdef __len__(self):return len(self.data_list)

1.4 数据结构


2 代码

"""
仅用于学习********类别说明**********>
0,巴士,bus
1,出租车,taxi
2,货车,truck
3,家用轿车,family sedan
4,面包车,minibus
5,吉普车,jeep
6,运动型多功能车,SUV
7,重型货车,heavy truck
8,赛车,racing car
9,消防车,fire engine
"""
import torch
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from  torchvision import models
import torch.nn as nn
from torch import optim#****************setting*******************
NUM_CLASSES = 10
BATCH_SIZE = 32
NUM_EPOCHS= 25
#下载地址:https://download.pytorch.org/models/resnet50-19c8e357.pth
PRETRAINED_MODEL = './resnet50-19c8e357.pth'
MODEL_SAVE_PATH = 'trained_models/vehicle-10_record.pth'#数据集的存放位置
TRAIN_DATASET_DIR = r'F:/car_class10_data/train'
VALID_DATASET_DIR = r'F:/car_class10_data/val'
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 128
DROPOUT_RATE = 0.3
show_interval_num = 10
epochs = 20#此处数据是分别放在10个文件夹中#****************设置数据增强方式**************************
#针对训练集train_data
train_transforms = transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),#随机裁剪到256*256transforms.RandomRotation(degrees=15),#随机旋转transforms.RandomHorizontalFlip(),#随机水平翻转transforms.CenterCrop(size=224),#中心裁剪到224*224transforms.ToTensor(),#转化成张量,#归一化transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])#标准化
])
#针对测试集,test data,测试就不需要随机中心裁剪了,直接resize到224*224
test_valid_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
#*************************通过pytorch自带的dataload加载数据**************************
#关于dataload 可以查看 https://blog.csdn.net/weixin_40123108/article/details/85099449
# ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图,详细的可以去了解这个类ImageFolder,主要关注__getitem__函数,该函数会根据索引返回每张图和label
# 关于torch.utils.data.DataLoader,数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
# 在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
# 可以参考学习 https://zhuanlan.zhihu.com/p/28200166    https://www.jb51.net/article/184042.htm
train_datasets = datasets.ImageFolder(TRAIN_DATASET_DIR, transform=train_transforms)
train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
train_data_size = len(train_datasets)
valid_datasets = datasets.ImageFolder(VALID_DATASET_DIR,transform=test_valid_transforms)
valid_dataloader = torch.utils.data.DataLoader(valid_datasets, batch_size=TEST_BATCH_SIZE, shuffle=True)
valid_data_size = len(valid_datasets)
#****************************可以通过运行test_data函数查看数据类型**************************************def test_data():print("train_dataloade len", len(train_dataloader))for images, labels in train_dataloader:print(labels)print("label len", len(labels))img = images[0]img = img.numpy()img = np.transpose(img, (1, 2, 0))plt.imshow(img)plt.show()break
#*******************使用预训练模型 resnet50进行fineturn**************************
#修改最后一层fc
def resnet50():model = models.resnet50(pretrained=True)for param in model.parameters():param.requires_grad = Falsefc_inputs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(fc_inputs, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, 10),nn.LogSoftmax(dim=1))return model
#********************定义损失函数和优化器*********************
loss_func = nn.NLLLoss()
optimizer = optim.Adam(resnet50.parameters())
#**********************定义训练和验证过程***************************
def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % show_interval_num == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))
#****************************定义main*****************************
def main():use_cuda = torch.cuda.is_available()device = torch.device("cuda" if use_cuda else "cpu")kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}train_transforms = transforms.Compose([transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # 随机裁剪到256*256transforms.RandomRotation(degrees=15),  # 随机旋转transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.CenterCrop(size=224),  # 中心裁剪到224*224transforms.ToTensor(),  # 转化成张量,#归一化transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])  # 标准化])# 针对测试集,test data,测试就不需要随机中心裁剪了,直接resize到224*224test_valid_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])train_datasets = datasets.ImageFolder(TRAIN_DATASET_DIR, transform=train_transforms)train_dataloader = torch.utils.data.DataLoader(train_datasets,batch_size=TRAIN_BATCH_SIZE,shuffle=True,**kwargs)valid_datasets = datasets.ImageFolder(VALID_DATASET_DIR, transform=test_valid_transforms)valid_dataloader = torch.utils.data.DataLoader(valid_datasets,batch_size=TEST_BATCH_SIZE,shuffle=True,**kwargs)model = resnet50().to(device)optimizer = optim.Adam(resnet50.parameters(), lr=0.001)#***********************print flops and params**************************for epoch in range(1, epochs + 1):train(model, device, train_dataloader, optimizer, epoch)test_acc = test(model, device, valid_dataloader)# report intermediate resultprint('test accuracy %g', test_acc)# report final resultprint('Final result is %g', test_acc)
def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)# sum up batch losstest_loss += F.nll_loss(output, target, reduction='sum').item()# get the index of the max log-probabilitypred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset), accuracy))return accuracyif __name__ == "__main__":test_data()pass

3结果

pytorch 车型分类代码相关推荐

  1. pytorch bert文本分类_一起读Bert文本分类代码 (pytorch篇 四)

    Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了.这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码exa ...

  2. 《pytorch车型细分类网络》的源码

    说明:<pytorch车型细分类网络>.这篇文章代码有错误.我稍微调整了一下,可以正常跑了. 标题:pytorch动手实践:pytorch车型细分类网络 1)讲解,代码,主要参考知乎文章& ...

  3. pytorch 音频分类_Pytorch中音频的神经风格转换

    pytorch 音频分类 They've been some really interesting applications of style transfer. It basically aims ...

  4. Pytorch音频分类

    pytorch实现音频分类代码 这两天学习了下pytorch,动手练习练习 数据集:来源是KAGGLE的一个音频分类的比赛 数据集介绍:(需要梯子)https://urbansounddataset. ...

  5. 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)

    实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...

  6. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn 项目GitHub地址--> https://github.com/ ...

  7. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  8. php修改新闻分类代码,完整的新闻无限级分类代码,可添加,删除,移动,修改

    //连接数据库教程 $link = mysql教程_connect('localhost','root','密码') or die(mysql_error()); mysql_select_db('s ...

  9. Resnet的pytorch官方实现代码解读

    Resnet的pytorch官方实现代码解读 目录 Resnet的pytorch官方实现代码解读 前言 概述 34层网络结构的"平原"网络与"残差"网络的结构图 ...

最新文章

  1. Appium+python自动化19-iOS模拟器(iOS Simulator)安装自家APP
  2. LaTeX 简介与安装
  3. mysql需要vc_VC连接MySql
  4. 无法创建 SSIS 运行时对象。请验证 DTS.dll 是否可用及是否已注册。此向导无法继续而将终止。 (SQL
  5. [BZOJ 2434][Noi2011]阿狸的打字机(AC自动机+树状数组+dfs序)
  6. UNIX环境高级编程——Linux终端设备详解
  7. python发微信工资条_帮公司财务妹子写了个“群发工资条”的Python脚本!
  8. python3,判断,循环练习1
  9. 做B/S的朋友注意了。。。(又一先进的武器出现了)
  10. HDU 5950 Recursive sequence
  11. Session 工作原理
  12. linux 播放器系统,Linux 中的十大开源视频播放器
  13. 人工智能生物学深度解析,附源代码
  14. 企业全面运营管理沙盘模拟心得_企业经营沙盘模拟心得总结
  15. Scratch3.0 保存缩略图
  16. HDB3码编码规则通俗易懂讲解
  17. 使用阿里云服务器搭建自己的个人网站
  18. opengl光照效果之点光源
  19. 计算机内存条如何区分频率,什么是内存条的频率? 怎么看内存条频率?
  20. 蓝桥杯C/C++B组历届真题刷题【合集】

热门文章

  1. 麦克风音频服务器未响应,耳机和麦克风都没坏,插上电脑后为什么不能语音聊天?...
  2. Ruckus 无线路由器被曝多个严重漏洞
  3. 原创:tar 解压安装zabbix-agentyum源安装zabbix-agent
  4. 潘凯:C++对象布局及多态实现的探索(十)
  5. 《软件测试技术大全》一书的目录
  6. 2023年我们从M2E项目得到的教训
  7. 基于51单片机额温枪非接触红外人体测温仪原理图PCB
  8. 2008台北英特尔信息技术峰会主题演讲精选-Shane Wall
  9. 相机存储卡出现错误数据,照片还能恢复吗?
  10. Vue配合jQuery+Ajax使用