视频理解TSM的训练与使用

tsm的github地址

总体评价:tsm是一个理解不难但效果优秀的视频理解模型,在我的视频分类任务中,其效果基本达到了使用要求。相比我在github上跑通的其他模型,tsm是最好的。百度团队在不久前也推出了pp-tsm,精度相比tsm提升了几个百分点,我也克隆并调试了,不过非常惭愧,训练模型没有跑通,以后有时间的话会再进行研究。

训练

训练方面我也是借鉴了其他优秀作者的建议,这里给出链接,大家可以参考他的步骤开始自己的训练。链接地址
先要强调的是,本人modality选择的是“RGB”,没有flow之类,感兴趣的可能要自己研究下了。
可以大概说一下tsm的训练原理,对于一个属于某类的视频,我们通过ffmpeg,或者opencv对视频进行抽帧,将一个视频的每一帧的图片根据排序存储至一个文件夹,在训练的采样阶段,模型对一个文件夹一定随机抽取n张(n默认为8)图片,进行concat操作,将concat后的tentor张量作为输入,视频转图片文件夹的文类作为标签,放入网络进行训练。

训练技巧:
1.更改num_segements:
num_segments即为对每一个视频转图片文件夹的采样张数,对于更多的采样,输入可以包含更多的特征信息,所以一般来说将这个参数增大可以提升模型的性能。
2.更改对图片信息的采样压缩方式:
在原始的tsm的源码中,对训练数据进行采样的是

    train_loader = torch.utils.data.DataLoader(TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,new_length=data_length,modality=args.modality,image_tmpl=prefix,transform=torchvision.transforms.Compose([train_augmentation,Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),normalize,]), dense_sample=args.dense_sample),batch_size=args.batch_size, shuffle=True,num_workers=args.workers, pin_memory=True,drop_last=True)

其中的train_augmention包括:

    def get_augmentation(self, flip=True):if self.modality == 'RGB':if flip:return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]),GroupRandomHorizontalFlip(is_flow=False)])

对与图片的裁减与缩放操作就是在Compose的GroupMultiScaleCrop中实现的。这里展示以下关键代码:

    def __call__(self, img_group):im_size = img_group[0].sizecrop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)for img in crop_img_group]return ret_img_group

我这里之贴出了部分,源码大家可以自己看一下,大概的内容就是,对一张原始的单张图片,在原图中以一定的偏移确定裁减的区域,裁减后再进行resize操作(默认为224x224)。
tsm的demo给出的是手势的识别,在这样的任务前提下,图像的大小以及图像的边缘信息似乎没有那么重要,然而,在更复杂任务的时候,一张图像的边缘也包含了重要的特征信息,且如果图像太小会损失重要的特征信息,根据这两点,我重新写了一个图片压缩的类(名字随意),其将图像resize到制定大小,并通过填充黑边保持原图像的形状。

class GroupScale_hyj(object):  def __init__(self,input_size):self.input_size = input_sizeself.interpolation = Image.BILINEAR# @classmethoddef _black_resize_img(self,ori_img):new_size = self.input_sizeori_img.thumbnail((new_size,new_size))w2,h2 = ori_img.sizebg_img = Image.new('RGB',(new_size,new_size),(0,0,0))if w2 == new_size:bg_img.paste(ori_img, (0, int((new_size - h2) / 2)))elif h2 == new_size:bg_img.paste(ori_img, (int((new_size - w2) / 2), 0))else:bg_img.paste(ori_img, (int((new_size - w2) / 2), (int((new_size - h2) / 2))))return bg_imgdef __call__(self,img_group):ret_img_group = [self._black_resize_img(img) for img in img_group]return ret_img_group

train_loader与val_loader均可使用,为了保证我的显存能够带动,我选取了图片大小为320,替换后的代码为:

    train_loader = torch.utils.data.DataLoader(TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,new_length=data_length,modality=args.modality,image_tmpl=prefix,transform=torchvision.transforms.Compose([GroupScale_hyj(input_size=320),GroupAugmentor(),  #img_augmentor for the train dataGroupRandomHorizontalFlip(is_flow=False),Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),normalize,]), dense_sample=args.dense_sample),batch_size=args.batch_size, shuffle=True,num_workers=args.workers, pin_memory=True,drop_last=True)  # prevent something not % n_GPUval_loader = torch.utils.data.DataLoader(TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,new_length=data_length,modality=args.modality,image_tmpl=prefix,random_shift=False,transform=torchvision.transforms.Compose([GroupScale_hyj(input_size=320),Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),normalize,]), dense_sample=args.dense_sample),batch_size=args.batch_size, shuffle=False,num_workers=args.workers, pin_memory=True)

再贴两张对比图(图一是tsm原始的图片缩放,图二是修改后的)

修改后的可以保存更多的特征信息。
3.保证采样的一串图片其对应的时间一致(个人觉得,欢迎指正):
我们是对一个视频转图片再抽取一定数量的帧(默认为8),如果我们是对一批数据训练的话,那我们应该要保证每一个文件夹抽取的图片所代表的时间长度是固定的,比如我们规定以2s为基本的时间长,那么每个文件夹抽取的第一张到最后一张所经历的时间应该接近2s,意思就是,我在2s的时间里,对该视频的行为进行分类。由此,当我们的视频数据集有不同的fps时,我们就要调整,使得抽取的一串图片经历时间都接近2s。

测试/使用

官方给出了一个手势识别的demo,想要成功运行的话,可以参考我前面给出的作者的博客,亲测有效。
不过更多的,我们想将自己的视频分类任务进行测试,而这方面的参考代码比较少。经过了之前的训练,其实我们需要的就是对读入的视频进行抽帧采样,将图片放入dataset中,经模型输出一个分类向量,将分类向量对应的种类名称写在视频流上显示就行.所以关键其实就是:对视频抽帧采样;初始化模型并加载训练参数;采样图片转成model能接受的输入格式(效果等同于TSN_DATASET)以下是本人针对打架检测的使用代码:

import os
import time
from ops.models import TSN
from ops.transforms import *
import cv2
from PIL import Imagearch = 'mobilenetv2'
num_class = 2
num_segments = 8
modality = 'RGB'
base_model = 'mobilenetv2'
consensus_type='avg'
dataset = 'ucf101'
dropout = 0.1
img_feature_dim = 256
no_partialbn = True
pretrain = 'imagenet'
shift = True
shift_div = 8
shift_place = 'blockres'
temporal_pool = False
non_local = False
tune_from = None#load model
model = TSN(num_class, num_segments, modality,base_model=arch,consensus_type=consensus_type,dropout=dropout,img_feature_dim=img_feature_dim,partial_bn=not no_partialbn,pretrain=pretrain,is_shift=shift, shift_div=shift_div, shift_place=shift_place,fc_lr5=not (tune_from and dataset in tune_from),temporal_pool=temporal_pool,non_local=non_local)model = torch.nn.DataParallel(model, device_ids=None).cuda()
resume = '/home/hyj/桌面/master_projects/temporal-shift-module-master/best_weights/mobilenet_360_93.916/ckpt.pth.tar' #  the last weights
checkpoint = torch.load(resume)
model.load_state_dict(checkpoint['state_dict'])
model.eval()#how to deal with the pictures
input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]
normalize = GroupNormalize(input_mean, input_std)
transform_hyj = torchvision.transforms.Compose([GroupScale_hyj(input_size=320),Stack(roll=(arch in ['BNInception', 'InceptionV3'])),ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),normalize,
])video_path = '/home/hyj/桌面/master_projects/temporal-shift-module-master/test_videos/fight5.mp4'pil_img_list = list()cls_text = ['nofight','fight']
cls_color = [(0,255,0),(0,0,255)]import timecap = cv2.VideoCapture(video_path) #导入的视频所在路径
start_time = time.time()
counter = 0
frame_numbers = 0
training_fps = 30
training_time = 2.5
fps = cap.get(cv2.CAP_PROP_FPS) #视频平均帧率
if fps < 1:fps = 30
duaring = int(fps * training_time / num_segments)
print(duaring)
# exit()state = 0
while cap.isOpened():ret, frame = cap.read()if ret:frame_numbers+=1print(frame_numbers)# print(len(pil_img_list))if frame_numbers%duaring == 0 and len(pil_img_list)<8:frame_pil = Image.fromarray(cv2.cvtColor(frame,cv2.COLOR_BGR2RGB))pil_img_list.extend([frame_pil])if frame_numbers%duaring == 0 and  len(pil_img_list)==8:frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))pil_img_list.pop(0)pil_img_list.extend([frame_pil])input = transform_hyj(pil_img_list)input = input.unsqueeze(0).cuda()out = model(input)print(out)output_index = int(torch.argmax(out).cpu())state = output_index#键盘输入空格暂停,输入q退出key = cv2.waitKey(1) & 0xffif key == ord(" "):cv2.waitKey(0)if key == ord("q"):breakcounter += 1#计算帧数if (time.time() - start_time) != 0:#实时显示帧数cv2.putText(frame, "{0} {1}".format((cls_text[state]),float('%.1f' % (counter / (time.time() - start_time)))), (50, 50),cv2.FONT_HERSHEY_SIMPLEX, 2, cls_color[state],3)cv2.imshow('frame', frame)counter = 0start_time = time.time()time.sleep(1 / fps)#按原帧率播放# time.sleep(2/fps)# observe the outputelse:breakcap.release()
cv2.destroyAllWindows()

打架识别运行效果(我的基准时间为2.5s,即行为进行2.5s后判断其分类,所以从视觉上会感觉到一定的延迟,ubuntu录视频软件不好找,直接手机录了)

最后提醒一下,当你的分类数目少于5时,需要将main.py中top5的代码去掉,不然会报错。欢迎各位讨论与建议。

视频理解TSM的训练与使用相关推荐

  1. 基于视频理解TSM和数据集Kinetics-400的视频行为识别

    基于视频理解TSM和数据集Kinetics-400的视频行为识别 基于视频理解TSM和数据集Kinetics-400的视频行为分类 基于视频理解TSM-mobilenetv2和数据集Kinetics- ...

  2. 基于视频理解TSM和数据集20bn-jester-v1的27类手势识别

    基于视频理解TSM-mobilenetv2和数据集20bn-jester-v1的27类手势识别 基于视频理解TSM-resnet50和数据集20bn-jester-v1的27类手势识别 基于视频理解T ...

  3. 【视频理解论文】——TSM:Temporal Shift Module for Efficient Video Understanding

    TSM: Temporal Shift Module for Efficient Video Understanding(ICCV2019) 这是一篇关于视频理解的文章,主要介绍了一种可以达到3DCN ...

  4. 自动分类打标签!飞桨TSM模型帮你做视频理解

    导读:目前互联网视频数据日益增多,用户观看短视频.小视频的时长也迅速增长,如何对海量的视频资源快速准确地分析.处理.归类是一个亟待解决的问题.视频理解技术可以多维度解析视频内容,理解视频语义,自动分类 ...

  5. VLM:Meta AI CMU提出任务无关视频语言模型视频理解预训练VLM,代码已开源!(ACL 2021)...

    关注公众号,发现CV技术之美 本文分享 ACL 2021 论文『VLM: Task-agnostic Video-Language Model Pre-training for Video Under ...

  6. 数行代码训练视频模型,PyTorch视频理解利器出炉

    本文转自机器之心. Facebook人工智能实验室在 PySlowFast 之后时隔两年,携 PyTorchVideo 重入战场. 视频作为当今最被广为使用的媒体形式,已逐渐占超过文字和图片,据了人们 ...

  7. ​MMIT冠军方案 | 用于行为识别的时间交错网络,商汤公开视频理解代码库

    作者 | 商汤 出品 | AI科技大本营(ID:rgznai100) 本文主要介绍三个部分: 一个高效的SOTA视频特征提取网络TIN,发表于AAAI2020 ICCV19 MMIT多标签视频理解竞赛 ...

  8. 管中窥“视频”,“理解”一斑 —— 视频理解概览

    ©PaperWeekly 原创 · 作者|Lingyun Zeng 学校|北京航空航天大学 研究方向|计算机视觉 本文通过对视频理解/分类(Video Understanding/Classifica ...

  9. AAAI 2020 时间交错网络 | ICCV19多标签视频理解冠军方案

    本文主要介绍三个部分: 一个高效的 SOTA 视频特征提取网络 TIN,发表于 AAAI 2020 ICCV19 MMIT 多标签视频理解竞赛冠军方案,基于 TIN 和 SlowFast 一个基于 P ...

最新文章

  1. zzzp0371 属于本人
  2. 【开源】WeChatRobot+WeChatHelper 制作自己的微信机器人
  3. 【JavaSE04】Java中循环语句for,while,do···while
  4. 基于 Laravel Route 的 ThinkSNS+ Component
  5. 对VOC目标检测数据进行增强
  6. 【MFC】MFC对话框类
  7. SAP BTP Kyma Runtime dashboard 打开报缺少缺陷的错误消息该如何解决
  8. Zigbee 电动智能窗帘系统 解决方案
  9. 随笔--互联网进化论
  10. 昨夜洪峰抵达主城,重庆人是这么过的......
  11. php $rs1- gt eof,PHP_PHP速成大法,简单介绍一下PHP的语法 1、嵌 - phpStudy
  12. spring和spring_Spring MVCGradle
  13. JavaScript网页特效---对联广告,网站对联广告
  14. 推荐:年度巨献:《Ubuntu桌面生存指南》(作者:ghosert)
  15. 已解决-内部版本7601 此windows副本不是正版
  16. blankcount函数python,统计函数第五讲:计数函数COUNT和COUNTBLANK
  17. 文件服务器不能打印,服务器不能用作打印服务器 - Windows Server | Microsoft Docs
  18. 树莓派操控SG90舵机
  19. 【Arduino】入门篇——人体红外自动报警
  20. Python 打印购物小票

热门文章

  1. 百度工具问题以下对URL规则的阐述,哪些是错误的
  2. 软件的各个版本和英文缩写
  3. 程序员从互联网转行公务员:工资一万多变四千,但过得美滋滋
  4. 忆阻器课题 读书笔记(二)
  5. 如何用三年时间获得十年工作经验?
  6. 数学菜菜的2023秋季风雨失业路
  7. ionic中的ToastController小弹窗用法。提示信息。toast长时间不消失解决方案
  8. 炎炎夏日适合在屋里学习深度学习
  9. 2020年第十届C/C++ A组第一场蓝桥杯省赛真题
  10. Milogs正式发布工作日志管理软件