基于Pytorch的热轧钢带表面缺陷分类挑战
文章目录
- 1. 简介
- 2. 数据集
- 2.1 数据集选取
- 2.2 数据集处理
- 2.3 数据集加载
- 3. 训练模型
- 3.1 网络结构
- 3.2 损失函数和优化方式
- 4. 训练及参数调试
- 5. 测试
1. 简介
实现一个完整的图像分类任务,大致需要五个步骤:
- 选择开源框架
目前常用的深度学习框架主要有caffe、tesorflow、pytorch、mxnet、keras、paddlepaddle等。 - 构建并读取数据集
构建或获取数据集,根据选择开源框架进行数据集读取。 - 训练模型搭建
选择合适的网络模型、损失函数以及优化方式,完成整体的训练模型搭建。 - 训练并调试参数
通过训练选定合适参数。 - 测试准确率
在测试集上验证模型的最终性能。
本次实战选择pytorch开源框架,按照上述步骤实现一个基本的图像分类任务,并详细阐述其中的细节。
2. 数据集
2.1 数据集选取
表面缺陷检测是生产制造过程中必不可少的一步,尤其在带钢原料钢卷的轧制工艺过程中形成的表面缺陷是造成废、次品的主要原因,因此必须加强对带钢表面缺陷检测,通过缺陷检测,对于加强轧制工艺管理,剔除废品等都有重要的意义。
本次实战选择的数据库为由东北大学(NEU)发布的热轧钢带表面缺陷数据库,收集了热轧钢带的六种典型表面缺陷,即轧制氧化皮(RS),斑块(Pa),开裂(Cr),点蚀表面( PS),内含物(In)和划痕(Sc)。该数据库包括1,800个灰度图像:六种不同类型的典型表面缺陷,每一类缺陷包含300个样本。
数据库下载地址 NEU-CLS
提取码:175m
下面展示了6中缺陷样本的图像
2.2 数据集处理
首先需要将数据集分类处理成pytorch可以读取的形式,即是将缺陷图像按类别放置在不同的文件夹中。代码如下:
import os
import shutil### 数据集根目录
root_dir = '数据集绝对地址'### 数据集转移目录
shutil_dir = '处理数据集绝对地址'all_images = os.listdir(root_dir) #读取所有文件images_classes= ['Cr', 'In', 'Pa', 'PS', 'RS', 'Sc']for img in all_images:img_shutil_dir = os.path.join(shutil_dir, str(images_classes.index(img[0:2])))if not os.path.isdir(img_shutil_dir):os.mkdir(img_shutil_dir)shutil.copyfile(os.path.join(root_dir, img), os.path.join(img_shutil_dir, img))
运行后,数据集形式如下:每个文件夹中放置的是同类型的缺陷图像。
2.3 数据集加载
在这一步,需要实现数据集的加载和数据集划分,数据集加载运用ImageFolder()
和DataLoader()
, 数据集划分运用random_spilt()
,同时实现数据集加载时的数据增强。
数据增强介绍:数据增强
Pytorch常用图像处理和数据增强方法:Pytorch
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transformstrain_transform = transforms.Compose([transforms.RandomResizedCrop(200),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])dataset = torchvision.datasets.ImageFolder(shutil_dir, transform=train_transform) #全部训练用例
'''按照8 :2 比例切分数据集为训练集和验证集train_dataset 为训练集,valid_dataset为验证集
'''
train_size = int(0.8*len(dataset))
valid_size = len(dataset)-train_sizetrain_dataset, valid_dataset = Data.random_split(dataset, [train_size, valid_size])train_data = Data.DataLoader(train_dataset, batch_size=1, shuffle=True)
valid_data = Data.DataLoader(valid_dataset, batch_size=1, shuffle=False)
本例中的Normalize使用的参数为在ImageNet数据集上计算得到的方差和均值,实际使用时需要重新计算。参考链接:pytorch标准化。
3. 训练模型
3.1 网络结构
常用的图像分类网络有VGG、ResNet、ResNext、DenseNet、Inception、ShuffleNet等,
参考链接:
图像分类:常用分类网络结构(附论文下载)
常用的分类网络
在本次实战中,主要选取了ResNet-50经典网络做为训练模型,
import torchvision
import torch.nn as nnbasic_model = torchvision.models.resnet50(pretrained=True)class resnet_classifier(nn.Module):def __init__(self, classnumber=21):super(resnet_classifier, self).__init__()self.features = nn.Sequential(*list(basic_model.children())[:-1])fc_features = basic_model.fc.in_featuresself.classifier = nn.Linear(fc_features, classnumber, bias=False)def forward(self, x):features = self.features(x)features = torch.flatten(features, 1)classifier = self.classifier(features)return classifier
3.2 损失函数和优化方式
损失函数选择标准的交叉熵损失函数(详细介绍损失函数)
优化方式选择Adam优化(详细介绍优化方式)
4. 训练及参数调试
在训练中,在网络结构中加载了预训练模型,可以加快训练速度和提升训练精度,初始学习率设置为1e-4, 在网络结构的特征层和分类层采取不同的学习率,分类层的学习率为特征层的10倍,学习率调整策略为指数衰减。(参考链接学习率调整)
model = resnet_classifier()
train_params = [{'params':model.features.parameters(), 'lr':lr},{'params':model.classifier.parameters(),'lr':10*lr}]
optimizer = torch.optim.Adam(train_params)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)
训练和测试代码:
def training(self,epoch):train_loss = 0.0self.model.train()tbar = tqdm(self.train_data)num_img_tr = len(self.train_data)for i, sample in enumerate(tbar):img, label = sampleif self.cuda:img = img.cuda()self.optimizer.zero_grad()output = self.model(img)loss = self.Loss(output.cpu(), label)loss.backward()self.optimizer.step()self.scheduler.step()train_loss += loss.item()### 记录训练过程 监控loss值tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.batch_size + img.data.shape[0]))print('Loss: %.3f' % train_loss)def validation(self, epoch):self.model.eval()tbar = tqdm(self.valid_data, desc='\r')test_loss = 0.0train_acc_sum = 0.0num_img_tr = len(self.valid_data) * self.batch_sizefor i, sample in enumerate(tbar):img, label = sampleif self.cuda:img = img.cuda()with torch.no_grad():output = model(img)loss = self.loss(output.cpu(), label)test_loss += loss.item()tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))# Add batch sample into evaluatortrain_acc_sum += (output.cpu().argmax(dim=1) == label).sum().cpu().item()### 监控验证过程 记录正确率 accuracy = train_acc_sum / num_img_trself.writer.add_scalar('test/total_loss_epoch', test_loss, epoch)self.writer.add_scalar('accuracy', accuracy, epoch)
5. 测试
选用不同的模型和训练参数,对比训练精度,对模型或者超参数进行调整优化。
基于Pytorch的热轧钢带表面缺陷分类挑战相关推荐
- 基于Pytorch实现猫狗分类
基于Pytorch实现猫狗分类 一.环境配置 二.数据集准备 三.猫狗分类的实例 四.实现分类预测测试 五.参考资料 一.环境配置 1.环境使用 Anaconda 2.配置Pytorch pip in ...
- Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)
Kaggle猫狗大战--基于Pytorch的CNN网络分类:数据获取.预处理.载入(1) 第一次写CSDN博客,之前一直是靠着CSDN学学代码,这次不得不亲自上场了,就想着将学习的过程都记录下来.新人 ...
- 基于Pytorch的猫狗分类
无偿分享~ 猫狗二分类文件下载地址 在下一章说 猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类.刚开始我找到了也 ...
- 基于pytorch的简单图片分类问题实现
pytorch中基于简单图片分类问题的实现大致可以分为以下几个步骤: 1.建立处理图片的神经网络,提前设置好损失函数(图片分类问题一般使用交叉熵损失函数),以及优化器. 2.在每一个学习的步骤中,将训 ...
- 【代码实战】基于pytorch实现中文文本分类任务
点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 来自 | 知乎 地址 | https://zhuanlan.zhihu.com/p/73176 ...
- Python基于PyTorch实现BP神经网络ANN分类模型项目实战
说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 在人工神经网络的发展历史上,感知机(Multilayer Per ...
- 【项目实战课】基于Pytorch的DANet自然图像降噪实战
欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的DANet自然图像降噪实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战讲 ...
- 【项目实战课】基于Pytorch的EnlightenGAN自然图像增强实战
欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的EnlightenGAN自然图像增强实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...
- 【项目实战课】基于Pytorch的SiameseFC通用目标跟踪实战
欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的SiameseFC目标跟踪实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实 ...
最新文章
- 在医学图像分析中使用ICP算法进行点云配准
- 数据科学-通过数据探索了解我们的特征
- java ssh 下载excel,SSH整合WEB导出EXCEL案例
- atexit注册进程终止处理函数
- 实现div可以调整高度(div实现resize)
- MongoDB事实:商品硬件上每秒插入80000次以上
- void和void *
- 计算机网络(十八)-以太网
- 高性能开发,别点,发际线要紧!
- java多线程的安全_java-多线程的安全问题
- oracle 添加登陆文件路径
- Andriod广播注册接收过程简析
- 交待给你的事办完了,就不能回个话么?
- 域名微信拦截html代码,微信域名拦截查询网页源码——一个非常实用的微信域名检测工具实现...
- 正确的配置Android开发环境-让你的C盘不在爆红
- 【微信小程序中的股票分时图、K线图的源代码解析】
- LSI Logic 1068 SAS 磁盘阵列卡配置教程
- 联想计算机设置恢复出厂,联想电脑一键恢复出厂设置使用方法
- 今日头条快手等大厂刨根问底之APP启动流程篇
- 实现windows与ubuntu的之间的复制与粘贴
热门文章
- 2012年10月管理计算机系统,2010年10月全国高等教育自学考试管理系统中计算机应用真题...
- 免费的样机素材,拿走不谢
- Java中pop和poll区别
- iOS--UI之导航控制器与标签控制器
- 内蒙古2019年经济发展“稳”字当头
- 正确使用 Adobe 系列全家桶工具的教程(2021.2.20)
- 目标检测“A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection”
- 如何准确找客户?有什么好的办法找客户吗?
- c语言读取一张hdr图片,在Photoshop中调出人物照片高质量的HDR效果
- 【Pytorch学习】复现DCGAN训练生成动漫头像