First, you need a PyTorch model. Here I use GoogLeNet as the example.

We can use the ImageNet-pretrained model that torchvision provides.

import torch
import torchvision

googlenet = torchvision.models.googlenet(pretrained=True)
input = torch.randn(2, 3, 224, 224)
out = googlenet(input)  # on first run the pretrained weights are downloaded (you will see it in the console); afterwards the cached model is used directly

If you want to train your own model, keep reading; otherwise you can stop here.

I used the framework provided by Luo Hao to train a cat/dog classifier. Below are the GitHub repo for Luo Hao's framework and his Bilibili tutorial video:

https://github.com/michuanhaohao/deep-person-reid
https://www.bilibili.com/video/BV1Pg4y1q7sN

Step 1: prepare the cat/dog dataset.

Link: https://pan.baidu.com/s/1HBewIgKsFD8hh3ICOnnTwA
Extraction code: ktab
While you're at it, you can also download my source code directly:
Link: https://pan.baidu.com/s/1l6mrSpbfNSOsbmw2FT0zYw
Extraction code: r799
Source: https://blog.csdn.net/qq_43391414/article/details/118462136

Next, modify Luo Hao's framework.

First, write the GoogLeNet network: create a new file GoogLeNet.py in the models folder.

from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['GoogLeNet']


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=2, loss={'xent'}, **kwargs):
        super(GoogLeNet, self).__init__()
        self.loss = loss
        googlenet = torchvision.models.googlenet(pretrained=True)
        # self.base = googlenet
        self.base = nn.Sequential(*list(googlenet.children())[:-2])
        self.classifier = nn.Linear(1024, num_classes)
        self.feat_dim = 1024  # feature dimension

    def forward(self, x):
        x = self.base(x)
        # x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        y = self.classifier(f)
        return y
        # The ReID-style branches below are kept from the original template,
        # but they are unreachable because of the early return above.
        if not self.training:
            return f
        y = self.classifier(f)
        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        elif self.loss == {'ring'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


if __name__ == "__main__":
    input = torch.randn(2, 3, 224, 224)
    model = GoogLeNet()
    out = model(input)
    aaa = 100
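A quick note on the backbone surgery above: list(googlenet.children())[:-2] drops torchvision GoogLeNet's final dropout and fc layers but keeps the global average pooling, so the flattened feature is 1024-dimensional (hence feat_dim = 1024 and the commented-out avg_pool2d). Here is a small sanity check you can run, assuming a torchvision version where the pretrained model's aux branches are disabled by default (they are in recent releases); wrapping the children in nn.Sequential also skips GoogLeNet's internal transform_input normalization, which this tutorial ignores:

import torch
import torchvision
from torch import nn

g = torchvision.models.googlenet(pretrained=True)
# drop the last two children (dropout, fc); the global avgpool is kept
base = nn.Sequential(*list(g.children())[:-2])
feat = base(torch.randn(2, 3, 224, 224))
print(feat.shape)  # expected: torch.Size([2, 1024, 1, 1]) -> flattens to 1024 per image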

Then update the corresponding configuration in models/__init__.py.

At the very top, add:

from .GoogLeNet import *

and add the following entry to the model factory dict:

'GoogLeNet': GoogLeNet,
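For reference, here is a minimal sketch of what the edited models/__init__.py could look like. I'm assuming the factory dict in your copy of the framework is named __factory and already contains entries such as 'resnet50'; adjust the names to match your version:

from __future__ import absolute_import

from .ResNet import *
from .GoogLeNet import *   # new import at the top

__factory = {
    'resnet50': ResNet50,
    'GoogLeNet': GoogLeNet,  # new entry
}

def get_names():
    return list(__factory.keys())

def init_model(name, *args, **kwargs):
    if name not in __factory.keys():
        raise KeyError("Unknown model: {}".format(name))
    return __factory[name](*args, **kwargs)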

Step 2: modify the dataset-related code.

Add a new class in data_manager.py:

class CatDog(object):
    dataset_dir = 'dog_cat_dataset'

    def __init__(self, root='E:\\workspace\\dataset', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.test_dir = osp.join(self.dataset_dir, 'test')
        self.class_num = 2

        self._check_before_run()

        train, num_train_imgs = self._process_dir(self.train_dir)
        test, num_test_imgs = self._process_dir(self.test_dir)
        num_total_imgs = num_train_imgs + num_test_imgs

        print("=> Dog Cat dataset loaded")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # images")
        print("  ------------------------------")
        print("  train    | {:8d}".format(num_train_imgs))
        print("  test     | {:8d}".format(num_test_imgs))
        print("  ------------------------------")
        print("  total    | {:8d}".format(num_total_imgs))
        print("  ------------------------------")

        self.train = train
        self.test = test

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.test_dir):
            raise RuntimeError("'{}' is not available".format(self.test_dir))

    def _process_dir(self, dir_path):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        dataset = []
        for img_path in img_paths:
            img_path = img_path.replace('\\', '/')
            # file names look like dog.123.jpg / cat.456.jpg, so the class name
            # is the part of the file name before the first dot
            class_name = img_path.split(".")[0].split("/")[-1]
            if class_name not in ["dog", "cat"]:
                continue
            if class_name == "dog":
                class_index = 0
            elif class_name == "cat":
                class_index = 1
            dataset.append((img_path, class_index))
        num_imgs = len(dataset)
        return dataset, num_imgs

Then modify the dataset factory in the same file and add a new entry:


__img_factory = {
    'market1501': Market1501,
    'cuhk03': CUHK03,
    'dukemtmcreid': DukeMTMCreID,
    'msmt17': MSMT17,
    'cat_dog': CatDog,
}

A few notes:

root is the path that contains dataset_dir, i.e. the dataset folder above. Take a look at my file structure below; splitting the images into a training set and a test set any way you like is fine, but if you use my code, keep the folder names the same as mine.
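The original post showed the folder layout as a screenshot. Based on dataset_dir and _process_dir above, the expected structure is as follows (file names follow the Kaggle dog.N.jpg / cat.N.jpg convention; the numbers here are only placeholders):

E:\workspace\dataset\
└── dog_cat_dataset\
    ├── train\
    │   ├── cat.0.jpg
    │   ├── dog.0.jpg
    │   └── ...
    └── test\
        ├── cat.1000.jpg
        ├── dog.1000.jpg
        └── ...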

Next, add the following code in dataset_loader.py:

class DogCatDataset_test(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, class_index = self.dataset[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img_path.split("/")[-1], img, class_index


class DogCatDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, class_index = self.dataset[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, class_index

Training and testing script:

from __future__ import print_function, absolute_import
import os
import sys
import time
import datetime
import argparse
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

import data_manager
from dataset_loader import ImageDataset, DogCatDataset, DogCatDataset_test
import transforms as T
import models
from losses import CrossEntropyLabelSmooth, DeepSupervision, CrossEntropy_loss
from utils import AverageMeter, Logger, save_checkpoint
from eval_metrics import evaluate
from optimizers import init_optim

parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
# Datasets
parser.add_argument('--root', type=str, default='E:\\workspace\\dataset', help="root path to data directory")
parser.add_argument('-d', '--dataset', type=str, default='cat_dog',choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=4, type=int,help="number of data loading workers (default: 4)")
parser.add_argument('--height', type=int, default=224, help="height of an image (default: 224)")
parser.add_argument('--width', type=int, default=224, help="width of an image (default: 224)")
parser.add_argument('--split-id', type=int, default=0, help="split index")
# CUHK03-specific setting
parser.add_argument('--cuhk03-labeled', action='store_true',help="whether to use labeled images, if false, detected images are used (default: False)")
parser.add_argument('--cuhk03-classic-split', action='store_true',help="whether to use classic split by Li et al. CVPR'14 (default: False)")
parser.add_argument('--use-metric-cuhk03', action='store_true',help="whether to use cuhk03-metric (default: False)")
# Optimization options
parser.add_argument('--optim', type=str, default='adam', help="optimization algorithm (see optimizers.py)")
parser.add_argument('--max-epoch', default=60, type=int,help="maximum epochs to run")
parser.add_argument('--start-epoch', default=0, type=int,help="manual epoch number (useful on restarts)")
parser.add_argument('--train-batch', default=128, type=int,help="train batch size")
parser.add_argument('--test-batch', default=1, type=int, help="test batch size")
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,help="initial learning rate")
parser.add_argument('--stepsize', default=20, type=int,help="stepsize to decay learning rate (>0 means this is enabled)")
parser.add_argument('--gamma', default=0.1, type=float,help="learning rate decay")
parser.add_argument('--weight-decay', default=5e-04, type=float,help="weight decay (default: 5e-04)")
# Architecture
parser.add_argument('-a', '--arch', type=str, default='GoogLeNet',
                    # default='resnet50',
                    choices=models.get_names())
# Miscs
parser.add_argument('--print-freq', type=int, default=10, help="print frequency")
parser.add_argument('--seed', type=int, default=1, help="manual seed")
parser.add_argument('--resume', type=str,
                    # default='E:/workspace/classify/checkpoint_ep60.pth',
                    metavar='PATH')
parser.add_argument('--evaluate',
                    # default=1,
                    action='store_true', help="evaluation only")
parser.add_argument('--eval-step', type=int, default=-1,help="run evaluation for every N epochs (set to -1 to test after training)")
parser.add_argument('--start-eval', type=int, default=0, help="start to evaluate after specific epoch")
parser.add_argument('--save-dir', type=str, default='log_resnet_dog')
parser.add_argument('--use-cpu', action='store_true', help="use cpu")
parser.add_argument('--gpu-devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')

args = parser.parse_args()


def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_img_dataset(
        root=args.root, name=args.dataset, split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
    )

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = True if use_gpu else False

    trainloader = DataLoader(
        DogCatDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch, shuffle=True, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=True,
    )

    testloader = DataLoader(
        DogCatDataset_test(dataset.test, transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=2, loss={'xent'}, use_gpu=use_gpu)
    print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters()) / 1000000.0))

    # criterion = CrossEntropyLabelSmooth(num_classes=dataset.class_num, use_gpu=use_gpu)
    criterion = CrossEntropy_loss(num_classes=dataset.class_num, use_gpu=use_gpu)
    optimizer = init_optim(args.optim, model.parameters(), args.lr, args.weight_decay)
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
    start_epoch = args.start_epoch

    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint)
        # start_epoch = checkpoint['epoch']

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test(model, testloader, use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        if args.stepsize > 0: scheduler.step()

        if use_gpu:
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        save_checkpoint(state_dict, 0, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth'))

        # if (epoch + 1) > args.start_eval and args.eval_step > 0 and (epoch + 1) % args.eval_step == 0 or (
        #         epoch + 1) == args.max_epoch:
        #     print("==> Test")
        #     # rank1 = test(model, testloader, use_gpu)
        #     # is_best = rank1 > best_rank1
        #     # if is_best:
        #     #     best_rank1 = rank1
        #     #     best_epoch = epoch + 1
        #
        #     if use_gpu:
        #         state_dict = model.module.state_dict()
        #     else:
        #         state_dict = model.state_dict()
        #
        #     save_checkpoint(state_dict, 0, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))


def train(epoch, model, criterion, optimizer, trainloader, use_gpu):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    end = time.time()
    for batch_idx, (imgs, class_index) in enumerate(trainloader):
        if use_gpu:
            imgs, class_index = imgs.cuda(), class_index.cuda()

        # measure data loading time
        data_time.update(time.time() - end)

        outputs = model(imgs)
        if isinstance(outputs, tuple):
            loss = DeepSupervision(criterion, outputs, class_index)
        else:
            loss = criterion(outputs, class_index)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        losses.update(loss.item(), class_index.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
                      data_time=data_time, loss=losses))


def test(model, testloader, use_gpu):
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        for batch_idx, (file_name, imgs, class_index) in enumerate(testloader):
            if use_gpu: imgs = imgs.cuda()

            end = time.time()
            result = model(imgs)
            batch_time.update(time.time() - end)

            print(file_name, result, "   true", class_index)
            # .cpu() added so this also works when imgs is on the GPU
            np.save(file_name[0], imgs.cpu().numpy())

    return 0


if __name__ == '__main__':
    main()
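One dependency worth calling out: the script imports CrossEntropy_loss from losses.py, which is not part of the stock deep-person-reid losses (the repo ships CrossEntropyLabelSmooth). If your copy of losses.py does not define it, the following is a minimal sketch of a plain cross-entropy wrapper with a matching constructor signature; my actual implementation may differ slightly:

import torch.nn as nn

class CrossEntropy_loss(nn.Module):
    """Plain cross-entropy loss, constructor-compatible with CrossEntropyLabelSmooth."""
    def __init__(self, num_classes, use_gpu=True):
        super(CrossEntropy_loss, self).__init__()
        self.num_classes = num_classes
        self.use_gpu = use_gpu
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        # inputs: raw logits of shape (batch, num_classes); targets: class indices
        return self.criterion(inputs, targets)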

In this line, change the default to your own data folder:

parser.add_argument('--root', type=str, default='E:\\workspace\\dataset', help="root path to data directory")

After training you will get a .pth checkpoint file.

How do you test the .pth file?

You only need to change two arguments in the script: set the commented-out default of --resume to the path of your trained checkpoint, and enable --evaluate (the commented-out default=1 in the argument list above marks exactly these two spots).

Then rerun the script. A command-line alternative is sketched below.
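If you prefer not to edit the defaults, you can pass the same settings on the command line. A sketch, assuming the training/testing script above is saved as train.py and the checkpoint from epoch 60 is used (both names are placeholders, adjust them to your own files):

python train.py -d cat_dog -a GoogLeNet --root E:\workspace\dataset --evaluate --resume log_resnet_dog\checkpoint_ep60.pth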

In the printed output, the first item is the file name, followed by the inference scores; the cat class has index 1 (as shown by the last column, the ground-truth label). You can see that the second score is the larger one, which means the prediction is correct.

My code is available at https://gitee.com/feboreigns/classify
