数据集下载
https://download.csdn.net/download/weixin_32759777/12526217

大兄弟的地址
https://blog.csdn.net/ZOUZHEN_ID/article/details/83958772

训练数据


import torchimport torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
from load_data import Traffic# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")'''
使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
argparse是python的一个包,用来解析输入的参数
如:python mnist.py --outf model  (意思是将训练的模型保存到model文件夹下,当然,你也可以不加参数,那样的话代码最后一行torch.save()就需要注释掉了)python mnist.py --net model/net_005.pth(意思是加载之前训练好的网络模型,前提是训练使用的网络和测试使用的网络是同一个网络模型,保证权重参数矩阵相等)
'''
parser = argparse.ArgumentParser()parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints')  # 模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)")  # 模型加载路径
opt = parser.parse_args()  # 解析得到你在路径中输入的参数,比如 --outf 后的"model"或者 --net 后的"model/net_005.pth",是作为字符串形式保存的# Load training and testing datasets.
# ROOT_PATH = "./traffic"
train_data_dir = "/home/chenyang/PycharmProjects/detect_traffic_sign/BelgiumTSC_Training"
# test_data_dir = os.path.join(ROOT_PATH, "BelgiumTSC_Training/Testing")'''
定义LeNet神经网络,进一步的理解可查看Pytorch入门,里面很详细,代码本质上是一样的,这里做了一些封装
'''class LeNet(nn.Module):'''该类继承了torch.nn.Modul类构建LeNet神经网络模型'''def __init__(self):super(LeNet, self).__init__()  # 这一个是python中的调用父类LeNet的方法,因为LeNet继承了nn.Module,如果不加这一句,无法使用导入的torch.nn中的方法,这涉及到python的类继承问题,你暂时不用深究# 第一层神经网络,包括卷积层、线性激活函数、池化层self.conv1 = nn.Sequential(     # input_size=(1*28*28):输入层图片的输入尺寸,我看了那个文档,发现不需要天,会自动适配维度nn.Conv2d(3, 32, 5, 1, 2),   # padding=2保证输入输出尺寸相同:采用的是两个像素点进行填充,用尺寸为5的卷积核,保证了输入和输出尺寸的相同nn.ReLU(),                  # input_size=(6*28*28):同上,其中的6是卷积后得到的通道个数,或者叫特征个数,进行ReLu激活nn.MaxPool2d(kernel_size=2, stride=2), # output_size=(6*14*14):经过池化层后的输出)# 第二层神经网络,包括卷积层、线性激活函数、池化层self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5),  # input_size=(6*14*14):  经过上一层池化层后的输出,作为第二层卷积层的输入,不采用填充方式进行卷积nn.ReLU(),            # input_size=(16*10*10): 对卷积神经网络的输出进行ReLu激活nn.MaxPool2d(2, 2)    # output_size=(16*5*5):  池化层后的输出结果)# 全连接层(将神经网络的神经元的多维输出转化为一维)self.fc1 = nn.Sequential(nn.Linear(64 * 5 * 5, 128),  # 进行线性变换nn.ReLU()                    # 进行ReLu激活)# 输出层(将全连接层的一维输出进行处理)self.fc2 = nn.Sequential(nn.Linear(128, 84),nn.ReLU())# 将输出层的数据进行分类(输出预测值)self.fc3 = nn.Linear(84, 62)# 定义前向传播过程,输入为xdef forward(self, x):x = self.conv1(x)x = self.conv2(x)# nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维x = x.view(x.size()[0], -1)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 超参数设置
EPOCH = 20   # 遍历数据集次数(训练模型的轮数)
BATCH_SIZE = 3     # 批处理尺寸(batch_size):关于为何进行批处理,文档中有不错的介绍
LR = 0.001        # 学习率:模型训练过程中每次优化的幅度# 定义数据预处理方式(将输入的类似numpy中arrary形式的数据转化为pytorch中的张量(tensor))
# transform = transforms.ToTensor()
# # transform = torch.FloatTensor
transform = transforms.Compose([transforms.Resize((28, 28)),transforms.CenterCrop(28),transforms.ToTensor()])# 定义训练数据集(此处是加载MNIST手写数据集)
trainset = Traffic(root=train_data_dir, # 如果从本地加载数据集,对应的加载路径train=True,     # 训练模型download=False,  # 是否从网络下载训练数据集transform=transform  # 数据的转换形式
)# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(trainset,                # 加载测试集batch_size=BATCH_SIZE,   # 最小批处理尺寸shuffle=True,            # 标识进行数据迭代时候将数据打乱
)def model_train():# 定义损失函数loss function 和优化方式(采用SGD)net = LeNet().to(device)criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,通常用于多分类问题上optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # 优化函数for epoch in range(EPOCH):sum_loss = 0.0# 数据读取(采用python的枚举方法获得标签和数据,这一部分可能和numpy相关)for i, data in enumerate(trainloader):inputs, labels = data# labels = [torch.LongTensor(label) for label in labels]# 将输入数据和标签放入构建的图中 注:图的概念可在pytorch入门中查inputs, labels = inputs.to(device), labels.to(device)# 梯度清零optimizer.zero_grad()# forward + backward  注: 这一部分是训练神经网络的核心outputs = net(inputs)loss = criterion(outputs, labels)loss.backward() # 反向自动求导optimizer.step() # 进行优化# 每训练100个batch打印一次平均losssum_loss += loss.item()if i % 48 == 0:print('[%d, %d] loss: %.03f'% (epoch + 1, i + 1, sum_loss / 100))sum_loss = 0.0# 每跑完一次epoch测试一下准确率# with torch.no_grad():#     correct = 0#     total = 0# for i, data in enumerate(testloader):# for data in testloader:#     images, labels = data#     images, labels = images.to(device), labels.to(device)#     outputs = net(images)#     # 取得分最高的那个类#     _, predicted = torch.max(outputs.data, 1)#     total += labels.size(0)#     correct += (predicted == labels).sum()# print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))torch.save(net.state_dict(), "new{}_{}.pth".format(opt.outf, epoch + 1))
# 训练
if __name__ == "__main__":model_train()

加载数据


import torch.utils.data as dataimport os.path
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Imageroot = "/home/zlab/zhangshun/torch1/data_et/"# -----------------ready the dataset--------------------------
def default_loader(path):return Image.open(path).convert('RGB')# 参考自定义
class MyDataset(Dataset):# 构造函数带有默认参数def __init__(self,epoch=0,transform=None, target_transform=None, loader=default_loader):path_list = os.listdir("/home/chenyang/PycharmProjects/openpose_pruning/openpose_openface_net/save_model_weight")imgs = []path_list.remove(".pth")for one_image in path_list[epoch*64:(epoch+1)*64]:data=torch.load("/home/chenyang/PycharmProjects/openpose_pruning/openpose_openface_net/save_model_weight/"+one_image)imgs.append(("/home/chenyang/PycharmProjects/coco2017/train2017/"+one_image[:-4], data.get("pafs"),data.get("heatmaps")))  # imgs中包含有图像路径和标签self.imgs = imgsself.transform = transformself.target_transform = target_transformself.loader = loaderdef __getitem__(self, index):fn, label1,label2= self.imgs[index]# 调用定义的loader方法img = self.loader(fn)if self.transform is not None:img = self.transform(img)return img, label1,label2def __len__(self):return len(self.imgs)class Traffic(data.Dataset):'''Traffic Dataset.'''def __init__(self, root, train=True, transform=None, target_transform=None, download=False):data_dir="/home/chenyang/PycharmProjects/detect_traffic_sign/BelgiumTSC_Training/Training"self.root = os.path.expanduser(root)self.transform = transformself.target_transform = target_transformself.train = train  # training set or test setself.loader=default_loaderdirectories = [d for d in os.listdir(data_dir)if os.path.isdir(os.path.join(data_dir, d))]self.datas = []for d in directories:label_dir = os.path.join(data_dir, d)file_names = [os.path.join(label_dir, f)for f in os.listdir(label_dir) if f.endswith(".ppm")]for f in file_names:self.datas.append((f,int(d)))def __getitem__(self, index):fn, label1 = self.datas[index]# 调用定义的loader方法img = self.loader(fn)if self.transform is not None:img = self.transform(img)return img, label1def __len__(self):return len(self.datas)if __name__ == '__main__':# transform = transforms.ToTensor()# # transform = torch.FloatTensortransform = transforms.Compose([transforms.Resize((32, 32)),transforms.CenterCrop(32),transforms.ToTensor()])trainset = Traffic(root="",  # 如果从本地加载数据集,对应的加载路径train=True,  # 训练模型download=False,  # 是否从网络下载训练数据集transform=transform  # 数据的转换形式)# 定义训练批处理数据trainloader = torch.utils.data.DataLoader(trainset,  # 加载测试集batch_size=10,  # 最小批处理尺寸shuffle=True,  # 标识进行数据迭代时候将数据打乱)

帮助一个大兄弟修复完善了一下他的代码相关推荐

  1. 根据一个大图片自动生成相应小图片的代码

    我的一个项目中用的 using System; using System.IO; using System.Drawing; namespace Compoment {     /**//// < ...

  2. 硬盘只剩下一个大分区数据恢复图文教程

    赛门铁克的Ghost是一个伟大的软件,给我们系统安装备份带来极大便利.由Ghost派生出来的克隆版操作系统安装方式被大多数朋友采用.便利工具也是双刃剑,由于一些朋友对磁盘.分区的概念不是太了解熟悉,经 ...

  3. 哪位大兄弟有用 cMake 开发Android ndk的

    一直用 Android studio 开发ndk,但是gradle支持的不是很好,只有experimental 版本支持 配置各种蛋疼.主要每次新建一个module都要修改配置半天. 之前也看到过go ...

  4. 底薪80万挖来一个大公司高管

    底薪80万挖来一个大公司高管, 看上去很值,有可能亏大了! 老板都是求贤若渴,看见人才就两眼发光,想都不想就把80万处以12个月,每月六七万的发,干不干活,都有80万一年,谁还认真干? 可以这样做: ...

  5. 2021级C语言大作业 - 合成一个大西瓜

    分享21级同学大一上学期用C语言(及少量C++)实现的合成一个大西瓜游戏.由于同学们刚学了三个月的编程,实现还不够完善,工程代码.图片音乐素材可以从百度网盘下载: 链接:https://pan.bai ...

  6. 对于一个大数据应用项目/产品的落地,可以大致总结为五大步骤阶段?

    对于一个大数据应用项目/产品的落地,可以大致总结为五大步骤阶段: 数据规划.数据治理.数据应用.迭代实施.商业价值. 第一阶段:数据规划 一个成功的大数据项目,需要有一个良好的开端,即做好数据规划阶段 ...

  7. 个是云计算,一个大数据,一个人工智能,

    我今天要讲这三个话题,一个是云计算,一个大数据,一个人工智能,我为什么要讲这三个东西呢?因为这三个东西现在非常非常的火,它们之间好像互相有关系,一般谈云计算的时候也会提到大数据,谈人工智能的时候也会提 ...

  8. Android热更新十:自己写一个Android热修复

    很早之前就想深入的研究和学习一下热修复,由于时间的原因一直拖着,现在才执笔弄起来. Android而更新系列: Android热更新一:JAVA的类加载机制 Android热更新二:理解Java反射 ...

  9. 【网络文摘】一个大神程序员的使命感究竟应该是什么

    来源:一个大神程序员的使命感究竟应该是什么 工作了五年的工程师,算不算高级开发者?归类开发者不是简单地看工作年限,因为经验这种东西千金难换但又一文不值. 我们现在工作的行业很奇怪.明明每年都有新的从业 ...

最新文章

  1. 基于DOS命令打war包
  2. POJ 1449 amp; ZOJ 1036 Enigma(简单枚举)
  3. SAP Fiori Elements原理介绍之类型为Value Help的Smart Field工作原理
  4. WordPress同Kyma成功建立连接后,存放在mysql里的Kyma entry
  5. 引入Spring集成
  6. Android程序结构
  7. 防火墙问题 Linux系统 /etc/sysconfig/路径下无iptables文件
  8. WCF服务实例激活类型编程与开发(转)
  9. ANOI 2009 【同类分布】
  10. httpd的三种模式比较
  11. [Web Chart系列之一]Web端图形绘制SVG,VML, HTML5 Canvas 技术比较
  12. D-LINK二层交换机Asymmetric VLAN配置
  13. confluence统计_【漏洞预警】confluence远程代码执行漏洞(CVE-2019-3396)
  14. 读取mysql表名称_JAVA动态读取mysql表的字段名索引
  15. 云迁移实践:VMware虚拟机迁移到AWS
  16. 合并两个有序链表(Java)
  17. Pegasus读取传感器AD的值
  18. ios开发-- URL Schemes 使用详解
  19. 业务需求与解决方案管理机制
  20. 端云协同,打造更易用的AI计算平台

热门文章

  1. 在Eclipse里Validating非常缓慢
  2. 测试wifi软件 最大容量,使用VulcanCompact应用层测试仪评估测试WiFi6无线路由器的最大支持设备数及并发连接数...
  3. java 自动装箱自动拆箱_自动装箱和自动拆箱
  4. 计算机程序设计员_第二届北京大工匠计算机程序设计员、网络与信息安全管理员挑战赛:一场互联网“战场”的巅峰对决...
  5. kali linux提示安装系统失败,kali“安装系统”失败分析及解决
  6. 禅道测试套件怎么用_优质单元测试的十大标准,你有遵循吗?
  7. 北京交通大学计算机科学与技术研究生导师,熊轲_北京交通大学研究生导师信息...
  8. java 边界_Java数组边界问题
  9. 如何解决Office2016安装时提示:错误1406。安装程序无法将值写入注册表项\.xlsx
  10. 【 C 】作用域、链接属性、存储类型、static 关键字简介及总结