Garbage Classification with a ResNet-50 Network
I. Introduction:
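The classic ResNet was proposed in 2015 by Kaiming He's team in the paper "Deep Residual Learning for Image Recognition".
ResNet addresses the "degradation" problem of deep neural networks: as a network is made deeper, its fit gets worse rather than better (the training error rises as well), so the problem is not caused by overfitting.
ResNet is also called a residual network, and it is built from residual blocks:
A residual block consists of several cascaded convolutional layers and a shortcut connection. The outputs of the two paths are added together and passed through a ReLU activation to give the output of the block. Multiple residual blocks can be chained to build a deeper network.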
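The add-then-activate rule described above can be illustrated with a small sketch in plain Python. The helper names (`residual_block`, `zero_branch`, `halve_branch`) and the toy branches are assumptions made for illustration only; a real block uses convolutions and batch normalization as its residual branch.

```python
def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def residual_block(x, residual_fn):
    """Return ReLU(F(x) + x), where residual_fn plays the role of F."""
    fx = residual_fn(x)
    return relu([f + s for f, s in zip(fx, x)])


def zero_branch(x):
    """A residual branch that contributes nothing: F(x) = 0."""
    return [0.0 for _ in x]


def halve_branch(x):
    """A toy residual branch: F(x) = -0.5 * x."""
    return [-0.5 * v for v in x]


x = [1.0, 2.0, 3.0]
# With F(x) = 0 the block passes a non-negative input through unchanged.
identity_out = residual_block(x, zero_branch)
halved_out = residual_block(x, halve_branch)
# Blocks can be chained: the output of one is the input of the next.
stacked_out = residual_block(residual_block(x, halve_branch), halve_branch)
print(identity_out, halved_out, stacked_out)
```

When the residual branch outputs zeros, the block reduces to the identity for non-negative inputs, which is one way to see why stacking more such blocks need not make the fit worse.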
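There are two residual block designs.
The left figure is for shallower networks such as ResNet-18/34. The right figure, also called the "bottleneck" building block, is for deeper networks such as ResNet-50/101/152; its purpose is to reduce the number of parameters.
The paper gives ResNets of five different depths.
In ResNet-18/34 the convolution kernels in each residual block are 3×3 and 3×3, in that order; in ResNet-50/101/152 they are 1×1, 3×3, and 1×1.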
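To make the parameter saving concrete, the sketch below counts convolution weights for the two designs. It assumes, purely for illustration, that both blocks take and produce 256 channels and that the bottleneck reduces to 64 channels internally (the factor of 4 used by `extention` in the code later in this post); biases and batch-norm parameters are ignored.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Number of weights in a bias-free 2D convolution."""
    return in_channels * out_channels * kernel_size * kernel_size


def basic_block_weights(channels):
    """Two 3x3 convolutions that keep the channel count."""
    return 2 * conv_weights(channels, channels, 3)


def bottleneck_weights(channels, reduction=4):
    """1x1 reduce, 3x3 at the reduced width, 1x1 expand."""
    mid = channels // reduction
    return (conv_weights(channels, mid, 1)
            + conv_weights(mid, mid, 3)
            + conv_weights(mid, channels, 1))


basic = basic_block_weights(256)
bottleneck = bottleneck_weights(256)
print("basic block:", basic)
print("bottleneck :", bottleneck)
print("ratio      : %.1f" % (basic / bottleneck))
```

Under these assumptions the basic block needs 1,179,648 weights and the bottleneck 69,632, roughly a 17-fold reduction at the same input and output width.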
The paper also gives the network structure of the 34-layer ResNet.
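The depth in each network's name can be checked by counting weighted layers: the first convolution, the convolutions inside the residual blocks, and the final fully connected layer. The sketch below uses the per-stage block counts listed in the paper; the dictionary and function names are made up for this example, and the 1×1 projection convolutions on the shortcuts are not counted.

```python
# (blocks per stage, convolutions per block) for each variant
CONFIGS = {
    "ResNet-18": ([2, 2, 2, 2], 2),
    "ResNet-34": ([3, 4, 6, 3], 2),
    "ResNet-50": ([3, 4, 6, 3], 3),
    "ResNet-101": ([3, 4, 23, 3], 3),
    "ResNet-152": ([3, 8, 36, 3], 3),
}


def weighted_layers(blocks_per_stage, convs_per_block):
    """First conv + all convs inside the residual blocks + final fc layer."""
    return 1 + convs_per_block * sum(blocks_per_stage) + 1


depths = {name: weighted_layers(*cfg) for name, cfg in CONFIGS.items()}
for name, depth in depths.items():
    print(name, depth)
```

ResNet-34 and ResNet-50 share the same stage layout `[3, 4, 6, 3]`; the difference in depth comes only from the two-layer versus three-layer block.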
II. Implementing garbage classification
1. Prepare the dataset:
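The loading code in the next step expects each split (`data/train/`, `data/val/`) to contain four class folders named `Hazardous waste`, `Kitchen waste`, `Other garbage`, and `Recyclable garbage`. A quick check of that layout before training might look like the sketch below; `count_images` is a hypothetical helper, not part of the original project.

```python
import os

# Folder names taken from the loading code; the list index is the label.
CLASS_NAMES = ["Hazardous waste", "Kitchen waste",
               "Other garbage", "Recyclable garbage"]


def count_images(split_dir, class_names=CLASS_NAMES):
    """Return {class name: number of files} for one split directory."""
    counts = {}
    for name in class_names:
        class_dir = os.path.join(split_dir, name)
        if not os.path.isdir(class_dir):
            raise FileNotFoundError("missing class folder: " + class_dir)
        counts[name] = len(os.listdir(class_dir))
    return counts


if __name__ == "__main__":
    for split in ("data/train", "data/val"):
        if os.path.isdir(split):
            print(split, count_images(split))
```

Printing the counts also shows at a glance whether the classes are badly imbalanced.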
2. Load the dataset:
class garbage_datasets(Dataset):
    # Class folder names; the position in the list is the label
    classes = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']

    def __init__(self, filepath):
        self.images = []
        self.labels = []
        self.transform = transform
        # Read every image into memory, one class folder at a time
        for label, name in enumerate(self.classes):
            for filename in tqdm(os.listdir(filepath + name)):
                # Convert to RGB so grayscale/RGBA files match the 3-channel Normalize
                image = Image.open(filepath + name + '/' + filename).convert('RGB')
                image = image.resize((224, 224))
                image = self.transform(image)
                self.images.append(image)
                self.labels.append(label)
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

train_data = garbage_datasets('data/train/')
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = garbage_datasets('data/val/')
val_loader = DataLoader(val_data, batch_size=batch_size)
3. Build the network:
class Bottleneck(nn.Module):
    # The block outputs planes * extention channels
    extention = 4

    def __init__(self, inplanes, planes, stride, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces the channels (and carries the stride)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 conv at the reduced width
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv expands the channels again
        self.conv3 = nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.extention)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Kept from the original code; the reference ResNet has no ReLU here
        # and activates only after the shortcut addition below
        out = self.relu(out)
        # Project the shortcut when the shape changes
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_class):
        super(ResNet, self).__init__()
        self.inplane = 64
        self.block = block
        self.layers = layers
        # Stem: 7x7 conv followed by 3x3 max pooling
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of residual blocks
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
        # Assumes a 224x224 input, which gives a 7x7 feature map here
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None
        # The first block of a stage needs a projection shortcut
        # whenever the resolution or the channel count changes
        if stride != 1 or self.inplane != plane * block.extention:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, plane * block.extention, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(plane * block.extention))
        conv_block = block(self.inplane, plane, stride=stride, downsample=downsample)
        block_list.append(conv_block)
        self.inplane = plane * block.extention
        # The remaining blocks keep the shape, so they use identity shortcuts
        for i in range(1, block_num):
            block_list.append(block(self.inplane, plane, stride=1))
        return nn.Sequential(*block_list)

model = ResNet(Bottleneck, [3, 4, 6, 3], 4)
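The network above hard-codes `nn.AvgPool2d(7)` and a `512*4 = 2048`-input `fc` layer, so it relies on the feature map being 7×7 after `stage4`. The sketch below traces the spatial size through the layers with the usual convolution output-size formula, in plain Python with no PyTorch needed. `conv_out` and `trace_resnet50` are names invented for this illustration; the kernel, stride, and padding values are read off the code above.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_resnet50(size=224):
    """Return (layer name, channels, spatial size) after each major layer."""
    shapes = []
    size = conv_out(size, kernel=7, stride=2, padding=3)
    shapes.append(("conv1", 64, size))
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("maxpool", 64, size))
    stages = [("stage1", 64, 1), ("stage2", 128, 2),
              ("stage3", 256, 2), ("stage4", 512, 2)]
    for name, planes, stride in stages:
        # Only the first 1x1 conv of a stage is strided; the 3x3 convs
        # use padding 1 and keep the size.
        size = conv_out(size, kernel=1, stride=stride, padding=0)
        shapes.append((name, planes * 4, size))
    # AvgPool2d(7): the stride defaults to the kernel size.
    size = conv_out(size, kernel=7, stride=7, padding=0)
    shapes.append(("avgpool", 2048, size))
    return shapes


for name, channels, size in trace_resnet50(224):
    print("%-8s %4d x %3d x %3d" % (name, channels, size, size))
```

For a 224×224 input the sizes run 112, 56, 56, 28, 14, 7 and finally 1, so the flattened vector has 2048 elements as `fc` expects. Calling `trace_resnet50(448)` ends at 2×2 instead, which would not match `fc`; this is why the images are resized to 224×224 before being fed to the model.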
4. Train the model:
def train(epoch):
    model.train()
    print("epoch:", epoch + 1)
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss = running_loss + loss.item()
    # Average over the number of batches (batch_idx is zero-based,
    # so dividing by it would be off by one and fail for a single batch)
    print('train loss: %.3f' % (running_loss / len(train_loader)))
    # Save the weights after every epoch
    torch.save(model.state_dict(), './model1.pth')
5. Validate the model:
def val():
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The index of the largest logit is the predicted class
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on validation set: %d %% ' % (100 * correct / total))
    return correct / total
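The validation function reports a single overall accuracy. With four classes it can also help to see the accuracy per class, since one weak class can hide behind a good average. The following is a minimal sketch in plain Python; `per_class_accuracy` is a hypothetical helper, and the label and prediction lists are made-up values for illustration, not results from this model.

```python
from collections import Counter

CLASS_NAMES = ["Hazardous waste", "Kitchen waste",
               "Other garbage", "Recyclable garbage"]


def per_class_accuracy(labels, predictions, num_classes=4):
    """Return a list with the accuracy of each class (None if a class is absent)."""
    total = Counter(labels)
    correct = Counter(l for l, p in zip(labels, predictions) if l == p)
    return [correct[c] / total[c] if total[c] else None
            for c in range(num_classes)]


# Made-up example data: true labels and predicted labels
labels = [0, 0, 1, 1, 2, 2, 3, 3]
predictions = [0, 1, 1, 1, 2, 0, 3, 3]

accuracies = per_class_accuracy(labels, predictions)
overall = sum(l == p for l, p in zip(labels, predictions)) / len(labels)
for name, acc in zip(CLASS_NAMES, accuracies):
    print(name, acc)
print("overall", overall)
```

In the real loop, the `labels` and `predicted` tensors of each batch could be converted with `.tolist()` and accumulated into two Python lists before calling the helper.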
6. Test the model:
def test(imgpath):
    font = {'color': 'red', 'size': 20, 'family': 'Times New Roman', 'style': 'italic'}
    # Class names in label order
    class_names = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Convert to RGB so the tensor always has 3 channels
    o_img = Image.open(imgpath).convert('RGB')
    o_img1 = o_img.resize((224, 224))
    img = transform(o_img1)
    # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    img = img.unsqueeze(0).to(device)
    print(img.shape)
    model = ResNet(Bottleneck, [3, 4, 6, 3], 4)
    # The training script saves its weights as model1.pth
    model.load_state_dict(torch.load("model1.pth", map_location=device))
    model = model.to(device)
    # Use the BatchNorm running statistics at inference time
    model.eval()
    with torch.no_grad():
        output = model(img)
    _, predict = torch.max(output, dim=1)
    label = class_names[predict.item()]
    print(label)
    plt.imshow(o_img)
    plt.text(0, -6.0, label, fontdict=font)
    plt.show()
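The test function only shows the winning class. The four raw outputs of the network are logits, and a softmax turns them into probabilities, which gives a rough idea of how confident the prediction is. The sketch below does this in plain Python; `softmax` and `top_prediction` are illustrative helpers, and the logits are made-up numbers rather than real model output.

```python
import math

CLASS_NAMES = ["Hazardous waste", "Kitchen waste",
               "Other garbage", "Recyclable garbage"]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits, class_names=CLASS_NAMES):
    """Return (class name, probability) of the highest-scoring class."""
    probs = softmax(logits)
    index = max(range(len(probs)), key=probs.__getitem__)
    return class_names[index], probs[index]


# Made-up logits for one image
name, confidence = top_prediction([2.0, 0.5, -1.0, 0.1])
print("%s (%.1f%%)" % (name, 100 * confidence))
```

With PyTorch the same probabilities can be obtained from the model output with `torch.softmax(output, dim=1)`.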
Source code:
import torch.nn as nn
import torch
import numpy as np
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import os
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

batch_size = 8
# ImageNet mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
torch.cuda.empty_cache()

class garbage_datasets(Dataset):
    # Class folder names; the position in the list is the label
    classes = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']

    def __init__(self, filepath):
        self.images = []
        self.labels = []
        self.transform = transform
        # Read every image into memory, one class folder at a time
        for label, name in enumerate(self.classes):
            for filename in tqdm(os.listdir(filepath + name)):
                # Convert to RGB so grayscale/RGBA files match the 3-channel Normalize
                image = Image.open(filepath + name + '/' + filename).convert('RGB')
                image = image.resize((224, 224))
                image = self.transform(image)
                self.images.append(image)
                self.labels.append(label)
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

train_data = garbage_datasets('data/train/')
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = garbage_datasets('data/val/')
val_loader = DataLoader(val_data, batch_size=batch_size)

class Bottleneck(nn.Module):
    # The block outputs planes * extention channels
    extention = 4

    def __init__(self, inplanes, planes, stride, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.extention)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Kept from the original code; the reference ResNet has no ReLU here
        # and activates only after the shortcut addition below
        out = self.relu(out)
        # Project the shortcut when the shape changes
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_class):
        super(ResNet, self).__init__()
        self.inplane = 64
        self.block = block
        self.layers = layers
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
        # Assumes a 224x224 input, which gives a 7x7 feature map here
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None
        # The first block of a stage needs a projection shortcut
        # whenever the resolution or the channel count changes
        if stride != 1 or self.inplane != plane * block.extention:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, plane * block.extention, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(plane * block.extention))
        conv_block = block(self.inplane, plane, stride=stride, downsample=downsample)
        block_list.append(conv_block)
        self.inplane = plane * block.extention
        for i in range(1, block_num):
            block_list.append(block(self.inplane, plane, stride=1))
        return nn.Sequential(*block_list)

model = ResNet(Bottleneck, [3, 4, 6, 3], 4)
# is_available must be called; the bare function object is always truthy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Resume from an earlier checkpoint only if one exists
if os.path.exists("model1.pth"):
    model.load_state_dict(torch.load("model1.pth", map_location=device))
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(epoch):
    model.train()
    print("epoch:", epoch + 1)
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss = running_loss + loss.item()
    # Average over the number of batches (batch_idx is zero-based)
    print('train loss: %.3f' % (running_loss / len(train_loader)))
    torch.save(model.state_dict(), './model1.pth')

def val():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on validation set: %d %% ' % (100 * correct / total))
    return correct / total

if __name__ == '__main__':
    acc_list = []
    epoch_list = []
    for epoch in range(5):
        train(epoch)
        acc = val()
        acc_list.append(acc)
        epoch_list.append(epoch + 1)
    # Plot validation accuracy against the epoch number
    plt.plot(epoch_list, acc_list)
    plt.ylabel("ACC")
    plt.xlabel("Epoch")
    plt.show()
Test source code:
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# Same preprocessing as in training (ImageNet mean and standard deviation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

class Bottleneck(nn.Module):
    # The block outputs planes * extention channels
    extention = 4

    def __init__(self, inplanes, planes, stride, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.extention)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Kept from the original code; the reference ResNet has no ReLU here
        # and activates only after the shortcut addition below
        out = self.relu(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_class):
        super(ResNet, self).__init__()
        self.inplane = 64
        self.block = block
        self.layers = layers
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None
        if stride != 1 or self.inplane != plane * block.extention:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, plane * block.extention, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(plane * block.extention))
        conv_block = block(self.inplane, plane, stride=stride, downsample=downsample)
        block_list.append(conv_block)
        self.inplane = plane * block.extention
        for i in range(1, block_num):
            block_list.append(block(self.inplane, plane, stride=1))
        return nn.Sequential(*block_list)

def test(imgpath):
    font = {'color': 'red', 'size': 20, 'family': 'Times New Roman', 'style': 'italic'}
    # Class names in label order
    class_names = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Convert to RGB so the tensor always has 3 channels
    o_img = Image.open(imgpath).convert('RGB')
    o_img1 = o_img.resize((224, 224))
    img = transform(o_img1)
    # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    img = img.unsqueeze(0).to(device)
    print(img.shape)
    model = ResNet(Bottleneck, [3, 4, 6, 3], 4)
    # The training script saves its weights as model1.pth
    model.load_state_dict(torch.load("model1.pth", map_location=device))
    model = model.to(device)
    # Use the BatchNorm running statistics at inference time
    model.eval()
    with torch.no_grad():
        output = model(img)
    _, predict = torch.max(output, dim=1)
    label = class_names[predict.item()]
    print(label)
    plt.imshow(o_img)
    plt.text(0, -6.0, label, fontdict=font)
    plt.show()

if __name__ == "__main__":
    test('data/test/Hazardous waste/2.jpg')
The final accuracy on the validation set reaches about 70%.
A few successfully classified test images are attached.