常见图像处理的任务

(1)分类

给定一幅图像,我们用计算机模型预测图片中有什么对象 。

(2)分类加定位

不仅需要我们知道图片中的对象是什么,还要在对象的附近画一个边框,确定该对象所处的位置。

(3)语义分割

区分到图中每一个像素点,而不仅仅是矩形框框住。

(4)目标检测

目标检测简单来说就是回答图片里面有什么?分别在哪里?(把它们用矩形框框住)。

(5)实例分割

实例分割是目标检测和语义分割的结合。相对目标检测的边界框,实例分割可精确到物体的边缘;相对语义分割,实例分割需要标注出图上同一物体的不同个体

图像定位

对于单纯的分类问题,比较容易理解,给定一幅图片,我们输出一个标签类别,我们已经跟熟悉。

而定位有点复杂,需要输出四个数字(x,y,w,h),图像中某一个点的坐标(x,y),以及图像的宽度和高度,有了这四个数字,我们可以很容易地找到物体的边框。

简单定位网络架构

本质是回归问题,使用L2损失进行优化。

Oxford-IIIT数据集

The Oxford-IIIT Pet Dataset是一个宠物图像数据集,包含37种宠物,每种宠物200张左右宠物图片,该数据集同时包含宠物分类、头部轮廓标注和语义分割信息

代码实战

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
from lxml import etree  # etree网页解析模块
from matplotlib.patches import Rectangle    # Rectangle画矩形
import glob # 获取所有路径
from PIL import Image   # 读取图像

单张图片演示

BATCH_SIZE = 8
print('一、图片解析演示')
pil_img = Image.open(r'dataset/images/Abyssinian_1.jpg')
np_img = np.array(pil_img)  # 转换成ndarray形式
print(np_img.shape)
plt.imshow(np_img)
plt.show()# 打开对应的xml图像定位信息文件
xml = open(r'dataset/annotations/xmls/Abyssinian_1.xml').read()
# 使用Etree库解析xml文件
sel = etree.HTML(xml)   # 首先创建一个解析对象
width = sel.xpath('//size/width/text()')[0]   # 使用xpath方法获取它的位置。
# //表示从根目录查找,width/text()表示查找width标签里的文本,返回对象是一个列表,所以要切片处理
height = sel.xpath('//size/height/text()')[0]
print(width,' ',height)# 查找头部所在的像素位置
xmin = sel.xpath('//bndbox/xmin/text()')[0]
ymin = sel.xpath('//bndbox/ymin/text()')[0]
xmax = sel.xpath('//bndbox/xmax/text()')[0]
ymax = sel.xpath('//bndbox/ymax/text()')[0]# 将获取到的文本转换成整数。
#转换数据类型,因为matplotlib只能接受整形标注它的位置
width = int(width)
height = int(height)
xmin = int(xmin)
ymin = int(ymin)
xmax = int(xmax)
ymax = int(ymax)plt.imshow(np_img)
rect = Rectangle((xmin, ymin), (xmax-xmin), (ymax-ymin), fill=False, color='blue')   # 实例化Rectangle对象,需要标注它的一些位置
# 参数1是xy,即最小值所在的点;第二、三个参数是width、height;fill=False表示不需要填充矩形。
ax = plt.gca()  # 获取当前坐标系
ax.axes.add_patch(rect) # 在当前坐标系上添加矩形框。
plt.show()


在原始数据集中,各张图片大小不一,我们在输入模型时,想要把它变成固定的大小。但是,图像改变尺寸后,对应的xmin和ymin就不对了,因为xmin和ymin是相对于原先图片的大小。实际上,我们可以将它转换为一个比值就可以了

# 例如:
img = pil_img.resize((224,224))
xmin = (xmin/width)*224
ymin = (ymin/height)*224
xmax = (xmax/width)*224
ymax = (ymax/height)*224plt.imshow(img)
rect = Rectangle((xmin, ymin), (xmax-xmin), (ymax-ymin), fill=False, color='red')
ax = plt.gca()  # 获取当前坐标系
ax.axes.add_patch(rect) # 在当前坐标系上添加矩形框。
plt.show()

输出是比值,使用比值作为目标值。


创建输入

images = glob.glob('dataset/images/*.jpg')  # 返回类型是列表。返回在images目录下所有以 jpg 结尾的文件的路径
xmls = glob.glob('dataset/annotations/xmls/*.xml')len(images) #7390
len(xmls) # 3686
# 这说明了数据集并没有对全部的图片进行标注"""
我们不知道对哪些图片做了标注,为了取出标注的图片;
我们要将这些被标注数据的文件名 ;
也即'dataset/annotations/xmls\\Abyssinian_1.xml'中的Abyssinian_1取出来;
然后使用文件名对原有的图片进行一个筛选。
"""xmls_names = [x.split('\\')[-1].split('.xml')[0] for x in xmls] # 获取到了所有被标注图片的文件名
# xmls_names = [x.split('\\')[-1].replace('.xml','') for x in xmls] 这种办法和上面的效果一样len(xmls_names) # 3686# 根据标注图片的文件名对所有图片进行一个筛选
imgs = [img for img in images if img.split('\\')[-1].split('.jpg')[0] in xmls_names]
len(imgs) #3686# 在创建输入之前,要保证图片和标注信息是一一对应的
print('len(imgs)==len(xmls_names)?:',len(imgs)==len(xmls_names))
print('imgs[:5]:\n',imgs[:5])
print('xmls[:5]:\n',xmls[:5])

将xml文件转换成标签的格式:下面我们需要将xml文件给它转换成标签的形式,在转换之前,我们首先要明确一点,我们的目标值不再是xmin这个实际的值,因为每一张图片的大小都是不一的,这个时候我们只是取出它的一个比例值。我们的预测值是头部宽高度所占的比值。

# 为了将xml列表文件里的数值解析出来,我们专门定义一个toLabel函数
def to_labels(path):xml = open(r'{}'.format(path)).read() # 用格式化形式,加r,防止转义。sel = etree.HTML(xml) # 创建选择器width = int(sel.xpath('//size/width/text()')[0]) height = int(sel.xpath('//size/height/text()')[0])xmin = int(sel.xpath('//bndbox/xmin/text()')[0])ymin = int(sel.xpath('//bndbox/ymin/text()')[0])xmax = int(sel.xpath('//bndbox/xmax/text()')[0])ymax = int(sel.xpath('//bndbox/ymax/text()')[0])return [xmin/width, ymin/height, xmax/width, ymax/height]
labels = [to_labels(path) for path in xmls]
labels[0],type(labels)

# zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
out1_label, out2_label, out3_label, out4_label = list(zip(*labels))len(out1_label), len(out2_label), len(out3_label), len(out4_label)


 划分数据集

index = np.random.permutation(len(imgs))
# img和label都是一一对应的,但是顺序乱序了。
images = np.array(imgs)[index]
labels = np.array(labels)[index]
out1_label = np.array(out1_label).astype(np.float32).reshape(-1, 1)[index]
out2_label = np.array(out2_label).astype(np.float32).reshape(-1, 1)[index]
out3_label = np.array(out3_label).astype(np.float32).reshape(-1, 1)[index]
out4_label = np.array(out4_label).astype(np.float32).reshape(-1, 1)[index]labels = labels.astype(np.float32) # 由于之前做了除法,最好转换成浮点型数据,这样有利于模型不报错
labels.shape # (3686,4)"""
out1_label = out1_label.astype(np.float32)
out2_label = out2_label.astype(np.float32)
out3_label = out3_label.astype(np.float32)
out4_label = out4_label.astype(np.float32)
"""i = int(len(imgs)*0.8) # 训练集比例train_images = images[:i]
train_labels = labels[:i]
out1_train_label = out1_label[:i]
out2_train_label = out2_label[:i]
out3_train_label = out3_label[:i]
out4_train_label = out4_label[:i]test_images = images[i:]
test_labels = labels[i:]
out1_test_label = out1_label[i: ]
out2_test_label = out2_label[i: ]
out3_test_label = out3_label[i: ]
out4_test_label = out4_label[i: ]

创建输入模型

transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])# 创建一个DataSet类
class Oxford_dataset(data.Dataset):def __init__(self, img_paths, out1_label, out2_label,out3_label, out4_label, transform):self.imgs = img_pathsself.out1_label = out1_labelself.out2_label = out2_labelself.out3_label = out3_labelself.out4_label = out4_labelself.transforms = transformdef __getitem__(self, index):img = self.imgs[index]  # 切出来是一条路径out1_label = self.out1_label[index]out2_label = self.out2_label[index]out3_label = self.out3_label[index]out4_label = self.out4_label[index]pil_img = Image.open(img)imgs_data = np.asarray(pil_img, dtype=np.uint8)if len(imgs_data.shape) == 2: # 如果不是rgb图像,就多增加一个维度imgs_data = np.repeat(imgs_data[:, :, np.newaxis], 3, axis=2)img_tensor = self.transforms(Image.fromarray(imgs_data))else:img_tensor = self.transforms(pil_img)return (img_tensor,out1_label,out2_label,out3_label,out4_label)def __len__(self):return len(self.imgs)train_dataset = Oxford_dataset(train_images, out1_train_label,out2_train_label, out3_train_label,out4_train_label, transform)test_dataset = Oxford_dataset(test_images, out1_test_label,out2_test_label, out3_test_label,out4_test_label, transform)train_dl = data.DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,
)test_dl = data.DataLoader(test_dataset,batch_size=BATCH_SIZE,
)(imgs_batch,
out1_batch,
out2_batch,
out3_batch,
out4_batch) = next(iter(train_dl))imgs_batch.shape, out1_batch.shape

plt.figure(figsize=(12, 8))
for i,(img, label1, label2,label3,label4,) in enumerate(zip(imgs_batch[:2],out1_batch[:2], out2_batch[:2], out3_batch[:2], out4_batch[:2])):img = (img.permute(1,2,0).numpy() + 1)/2  # permute交换维度。plt.subplot(2, 3, i+1)plt.imshow(img)xmin, ymin, xmax, ymax = label1*224, label2*224, label3*224, label4*224, rect = Rectangle((xmin, ymin), (xmax-xmin), (ymax-ymin), fill=False, color='red')ax = plt.gca()ax.axes.add_patch(rect)


创建定位模型 

resnet = torchvision.models.resnet101(pretrained=True)  # 使用卷积基提取特征。使用预训练参数作为初始化参数
"""
resnet101里面包含很多层,conv、batch.........
最后是avgpool和fc全连接层。
avgpool之前的层都是我们需要的"""in_f = resnet.fc.in_features    # 全连接层的输入
print(in_f) #2048resnet.children()   # 会返回所有层的生成器
list(resnet.children()) # 使用list将它返回
print(len(list(resnet.children()))) # 一共包含10个子层
print(list(resnet.children())[-1])  # 看看最后一层
list(resnet.children())[:-1]    # 这些层才是我们需要的,帮助我们提取特征。
conv_base = nn.Sequential(*list(resnet.children())) # *代表解包。这样相当于每一层都列在了Sequential里面class Net(nn.Module):def __init__(self):super(Net, self).__init__() # 继承父类属性self.conv_base = nn.Sequential(*list(resnet.children())[:-1])# 使用全连接模型。分别输出4个坐标值self.fc1 = nn.Linear(in_f, 1)   # 输出是1,因为我们要输出一个标量值。self.fc2 = nn.Linear(in_f, 1)self.fc3 = nn.Linear(in_f, 1)self.fc4 = nn.Linear(in_f, 1)def forward(self, x):x = self.conv_base(x)   # 提取特征x = x.view(x.size(0), -1)x1 = self.fc1(x)x2 = self.fc2(x)x3 = self.fc3(x)x4 = self.fc4(x)return x1, x2, x3, x4model = Net()
if torch.cuda.is_available():model.to('cuda')
loss_fn = nn.MSELoss()  # 回归问题,并不是分类问题。误差函数取平均绝对误差,MSE损失函数from torch.optim import lr_scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)   #每次间隔7步,学习速率衰减0.1def fit(epoch, model, trainloader, testloader):total = 0running_loss = 0model.train()for x, y1, y2, y3, y4 in trainloader:if torch.cuda.is_available():x, y1, y2, y3, y4 = (x.to('cuda'), y1.to('cuda'), y2.to('cuda'),y3.to('cuda'), y4.to('cuda'))       y_pred1, y_pred2, y_pred3, y_pred4 = model(x)loss1 = loss_fn(y_pred1, y1)loss2 = loss_fn(y_pred2, y2)loss3 = loss_fn(y_pred3, y3)loss4 = loss_fn(y_pred4, y4)loss = loss1 + loss2 + loss3 + loss4optimizer.zero_grad()loss.backward()optimizer.step()with torch.no_grad():running_loss += loss.item()exp_lr_scheduler.step()epoch_loss = running_loss / len(trainloader.dataset)test_total = 0test_running_loss = 0 model.eval()with torch.no_grad():for x, y1, y2, y3, y4 in testloader:if torch.cuda.is_available():x, y1, y2, y3, y4 = (x.to('cuda'), y1.to('cuda'), y2.to('cuda'),y3.to('cuda'), y4.to('cuda'))y_pred1, y_pred2, y_pred3, y_pred4 = model(x)loss1 = loss_fn(y_pred1, y1)loss2 = loss_fn(y_pred2, y2)loss3 = loss_fn(y_pred3, y3)loss4 = loss_fn(y_pred4, y4)loss = loss1 + loss2 + loss3 + loss4test_running_loss += loss.item()epoch_test_loss = test_running_loss / len(testloader.dataset)print('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'test_loss: ', round(epoch_test_loss, 3),)return epoch_loss, epoch_test_loss

开始训练

epochs = 10 # 总共训练十次train_loss = []
test_loss = []for epoch in range(epochs):epoch_loss, epoch_test_loss = fit(epoch, model, train_dl, test_dl)train_loss.append(epoch_loss)test_loss.append(epoch_test_loss)plt.figure()
plt.plot(range(1, len(train_loss)+1), train_loss, 'r', label='Training loss')
plt.plot(range(1, len(train_loss)+1), test_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.show()


模型保存

PATH = 'location_model.pth'
torch.save(model.state_dict(), PATH) # 保存的是权重,可训练参数。plt.figure(figsize=(8, 24))
imgs, _, _, _, _ = next(iter(test_dl))  # _占位符。意思是我不需要你实际的位置,我们要自己去预测。
imgs = imgs.to('cuda')  # 将图片添加到显卡
out1, out2, out3, out4 = model(imgs)    # 进行预测,返回四个坐标值(头部位置)
for i in range(6):plt.subplot(6, 1, i+1)plt.imshow(imgs[i].permute(1,2,0).cpu().numpy())    # 放到cpu上xmin, ymin, xmax, ymax = (out1[i].item()*224,   # out[i]代表第i个batch的第一个位置out2[i].item()*224,out3[i].item()*224,out4[i].item()*224)rect = Rectangle((xmin, ymin), (xmax-xmin), (ymax-ymin), fill=False, color='red')ax = plt.gca()ax.axes.add_patch(rect)plt.show()

PyTorch 11—简单图像定位相关推荐

  1. PyTorch入门(三)--实现简单图像分类器

    实现简单图像分类器 1. 数据加载 1.1 常用公共数据集加载 1.2 私人数据集加载方法 2. 定义神经网络 3. 定义权值更新与损失函数 4. 训练与测试神经网络 5. 神经网络的保存与载入 本篇 ...

  2. 【项目实战课】基于Pytorch的SRGAN图像超分辨实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的SRGAN图像超分辨实战>.所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战 ...

  3. ECCV 2020 Spotlight | 图像定位上的细粒化区域相似性自监督

    ©PaperWeekly · 作者|葛艺潇 学校|香港中文大学博士生 研究方向|图像检索.图像生成等 本文介绍一篇我们发表于 ECCV 2020 的论文,很荣幸该论文被收录为 spotlight pr ...

  4. ECCV2020 Spotlight | 图像定位上的细粒化区域相似性自监督

    本文转载自知乎,作者为香港中文大学MMLab博士生葛艺潇,已获作者授权转载. https://zhuanlan.zhihu.com/p/169596514 本文介绍一篇我们发表于ECCV 2020的论 ...

  5. PyTorch搭建简单神经网络实现回归和分类

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 安装 PyTorch 会安装两个模块,一个是torch,一个 torchvision, tor ...

  6. 学习C++项目——一个基于C++11简单易用的轻量级网络编程框架 1

    一个基于C++11简单易用的轻量级网络编程框架 一.项目下载.导入.编译和运行   现在准备深入学习 C++ ,先肝一个项目,这个项目是<一个基于C++11简单易用的轻量级网络编程框架>, ...

  7. pyautogui脱离屏幕基于图片的图像定位

    用过pyautogui的同学应该都知道,locateOnScreen可以传入一张图片与当前屏幕(截屏)进行匹配,但是我的需求是能基于已经截屏的图片来进行图像定位,看了pyautogui的源码没有相关的 ...

  8. 简单WIFI定位分析与比较(文末有手机采集RSSI软件推荐!!!)

    导师是做室内定位的,最近让我复现一篇论文中基于图像和WiFi融合定位的实现方法,目前图像定位已经简单的实现了,采用HOG+SVM.然鹅定位效果并不理想,但也能得出个大概的位置. 下面就开始搞基于WIF ...

  9. UA PHYS515A 电磁理论V 电磁波与辐射11 简单辐射问题 电偶极子的辐射

    UA PHYS515A 电磁理论V 电磁波与辐射11 简单辐射问题 电偶极子的辐射 一对带等量相反电量的点电荷构成一对电偶极子,假设电量为qqq,两个点电荷的距离为aaa,dipole moment为 ...

最新文章

  1. cv2.threshold() 阈值:使用Python,OpenCV进行简单的图像分割
  2. gorm软删除_gorm踩坑:软删除与某个字段的唯一性
  3. Python2 Python3 爬取赶集网租房信息,带源码分析
  4. mysql修改主键生成策略信息_常用Hibernate 主键生成策略
  5. [Nginx]nginx 配置实例-负载均衡
  6. JS的三大组成(ES,DOM,BOM)
  7. ecms 列表模板php,帝国CMS列表页模板list.var分别调用年月日(显示个性时间日期)
  8. 【转】字符编码笔记:ASCII、Unicode、UTF-8 和 Base64
  9. utilities(matlab)—— PSNR 值的计算
  10. 关于Java书籍的最佳阅读顺序
  11. 关于2020年各省市GDP和各省人均GDP的探索
  12. imx6ul mqs 音频爆破音
  13. nginx 的proxy 时间讲解
  14. 曾鸣分享:阿里集团及阿里眼里的电子商务(瑞士信贷中国投资年会)
  15. Apache Velocity 模板语言 特殊字符${ $!{ 原样输出问题 转义符 # ! 无效
  16. 分享几个好用的导航导航网站
  17. c语言一维数组n个元素求和,C++编程一维数组元素求和?
  18. iMac2021 在重新安装mac os系统后,电脑账户创建失败
  19. IDEA显示树状目录结构
  20. php extraxt,php中关于extract方法的使用和建议

热门文章

  1. snipaste截图软件编辑时修改方框粗细
  2. 【YOLOv5 数据集划分】训练和验证、训练验证和测试(train、val)(train、val、test)
  3. (USB:VCP+HID复合设备与系统配置)
  4. 人人自媒体的时代,程序员该如何利用好自己的优势?我记住了这些神器...
  5. bilibili缓存文件在哪里_不要再胡乱清理手机内存,花1分钟删掉这些文件夹,释放大量空间...
  6. [笑语天下]风景、照片与评论古今
  7. Windows登录日志详解
  8. Journal of Electronic Imaging 投稿分享
  9. 如何在控制台创建文件夹
  10. 使用aria2为网盘下载加速