数据及具体讲解来源:
基于PyTorch搭建CNN实现视频动作分类任务

import torch
import torch.nn as nn
import torchvision.transforms as T
import scipy.io
from torch.utils.data import DataLoader,Dataset
import os
from PIL import Image
from torch.autograd import Variable
import numpy as np"""
加载数据
"""
#获得标签
label_mat = scipy.io.loadmat('./datasets/q3_2_data.mat')
#获得训练集标签
label_train = label_mat['trLb']
print(len(label_train))
#获得验证集标签
label_val = label_mat['valLb']
print(len(label_val))"""
通过Dataset类进行数据预处理
"""
class ActionDataset(Dataset):def __init__(self,root_dir,labels = [],transform=None):"""Args::param root_dir: 数据路径:param labels: 图片标签:param transform: 数据处理函数"""self.root_dir = root_dirself.transform = transformself.length = len(os.listdir(self.root_dir))self.labels = labelsdef __len__(self):  #返回数据数量return self.length*3    #一个视频片段包含3帧(3个图片)def __getitem__(self, idx): #图片处理及返回数据folder = idx//3+1   #判断该帧属于第几个视频中imidx = idx%3 + 1   #判断该帧在该视频中是第几帧folder = format(folder,'05d')   #将folder格式化,05d代表五位数,若不到五位用0填充imgname = str(imidx) + '.jpg'img_path = os.path.join(self.root_dir,folder,imgname)image = Image.open(img_path)"""当输入标签有值时,说明是训练集和验证集,输出的样本也是有标签的,若没有值,说明是测试集,输出的样本是没有标签的"""if len(self.labels)!=0:Label = self.labels[idx//3][0]-1#如果有对数据的处理先对数据进行处理if self.transform:image = self.transform(image)if len(self.labels)!=0:sample = {'image':image,'img_path':img_path,'Label':Label}else:sample = {'image':image,'img_path':img_path}return sampleimage_datast = ActionDataset(root_dir='./datasets/trainClips/',labels=label_train,transform=T.ToTensor())
# torchvision.transforms中定义了非常多对图像的预处理方法,这里使用的ToTensor方法为将0~255的RGB值映射到0~1的Tensor类型。
# #测试一下
# for i in range(3):
#     sample = image_datast[i]
#     print(sample['image'].shape)
#     print(sample['Label'])
#     print(sample['img_path'])"""
Dataloader类进行封装
注意:Windows不要用num_works
"""
#image_dataloader = DataLoader(image_datast,batch_size=4,shuffle=True)
# for i , sample in enumerate(image_dataloader):
#     #enumerate(iteration, start):返回一个枚举的对象
#     sample['image'] = sample['image']
#     print(sample[i,sample['image'].shape,sample['img_path'],'Label'])
#     if i == 6:
#         break
#训练集
image_dataset_train=ActionDataset(root_dir='./datasets/trainClips/',labels=label_train,transform=T.ToTensor())
image_dataloader_train = DataLoader(image_dataset_train, batch_size=32,shuffle=True)
#验证集
image_dataset_val=ActionDataset(root_dir='./datasets/valClips/',labels=label_val,transform=T.ToTensor())
image_dataloader_val = DataLoader(image_dataset_val, batch_size=32,shuffle=False)
#测试集:没有给定labels
image_dataset_test=ActionDataset(root_dir='./datasets/testClips/',labels=[],transform=T.ToTensor())
image_dataloader_test = DataLoader(image_dataset_test, batch_size=32,shuffle=False)"""
搭建模型
"""dtype = torch.FloatTensor # 这是pytorch所支持的cpu数据类型中的浮点数类型。print_every = 100   # 这个参数用于控制loss的打印频率,因为我们需要在训练过程中不断的对loss进行检测。def reset(m):   # 这是模型参数的初始化if hasattr(m, 'reset_parameters'):m.reset_parameters()#数据解释和处理
class Flatten(nn.Module):def forward(self, x):N, C, H, W = x.size() # 读取各个维度。return x.view(N, -1)  # -1代表除了特殊声明过的以外的全部维度。fixed_model_base = nn.Sequential(nn.Conv2d(3,8,kernel_size=7,stride=1),   ##3*64*64 -> 8*58*58nn.ReLU(inplace=True),nn.MaxPool2d(2, stride = 2),    # 8*58*58 -> 8*29*29nn.Conv2d(8, 16, kernel_size=7, stride=1), # 8*29*29 -> 16*23*23nn.ReLU(inplace=True),nn.MaxPool2d(2, stride = 2), # 16*23*23 -> 16*11*11Flatten(),nn.ReLU(inplace=True),nn.Linear(1936, 10)     # 1936 = 16*11*11
)
fixed_model = fixed_model_base.type(dtype)  #将模型数据转换成pytorch所支持的cpu数据类型中的浮点数类型。
# #测试:
# x = torch.randn(32, 3, 64, 64).type(dtype)
# x_var = Variable(x.type(dtype)) # 需要将其封装为Variable类型。
# ans = fixed_model(x_var)
# print(np.array(ans.size())) # 检查模型输出。
# np.array_equal(np.array(ans.size()), np.array([32, 10]))"""
训练步骤及模块
"""
optimizer = torch.optim.RMSprop(fixed_model_base.parameters(), lr = 0.0001)
loss_fn = nn.CrossEntropyLoss()def train(model,loss_fn,optimizer,dataloader,num_epoch = 1):for epoch in range(num_epoch):check_accuracy(fixed_model,image_dataloader_val)    #在验证集验证模型效果model.train()   #模型的.train()方法让模型进入训练模式,参数保留梯度,dropout层等部分正常工作for t,sample in enumerate(dataloader):x_var = Variable(sample['image'])y_var = Variable(sample['Label'].long())scores = model(x_var)   #得到输出loss = loss_fn(scores,y_var)if (t+1)%print_every ==0:print('t = %d, loss = %.4f' % (t + 1, loss.item()))#三步更新optimizer.zero_grad()loss.backward()optimizer.step()def check_accuracy(model,loader):num_correct = 0num_samples = 0model.eval()    # 模型的.eval()方法切换进入评测模式,对应的dropout等部分将停止工作。for t,sample in enumerate(loader):x_var = Variable(sample['image'])y_var = Variable(sample['Label'])scores = model(x_var)_,preds = scores.data.max(1)    # 找到可能最高的标签作为输出。num_correct += (preds.numpy() == y_var.numpy()).sum()num_samples += preds.size(0)acc = float(num_correct)/num_samplesprint('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))"""
训练并验证
"""
torch.random.manual_seed(54321)
fixed_model.cpu()
fixed_model.apply(reset)
#pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身
fixed_model.train()
train(fixed_model, loss_fn, optimizer,image_dataloader_train, num_epoch=5)
check_accuracy(fixed_model, image_dataloader_val)"""
测试
"""def predict_on_test(model, loader):model.eval()results = open('results.csv', 'w')  # 模型预测结果会被放在这里。count = 0results.write('Id' + ',' + 'Class' + '\n')for t, sample in enumerate(loader):x_var = Variable(sample['image'])scores = model(x_var)_, preds = scores.data.max(1)for i in range(len(preds)):results.write(str(count) + ',' + str(preds[i]) + '\n')count += 1results.close()return countcount = predict_on_test(fixed_model, image_dataloader_test)  # 放入你想要测试的训练集,然后打开文件去看一看结果吧。
print(count)

基于PyTorch搭建CNN实现视频动作分类任务代码详解相关推荐

  1. 基于PyTorch搭建CNN实现视频动作分类任务 有数据有代码 可直接运行

    目录 介绍 任务描述 数据集 运行环境 模型概述

  2. 基于Keras搭建CNN、TextCNN文本分类模型

    基于Keras搭建CNN.TextCNN文本分类模型 一.CNN 1.1 数据读取分词 1.2.数据编码 1.3 数据序列标准化 1.4 构建模型 1.5 模型验证 二.TextCNN文本分类 2.1 ...

  3. 基于CNN实现视频动作分类任务

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

  4. 做什么样的视频收益高?自媒体视频各分类领域收益详解

    做自媒体做什么领域视频的收益高呢?今天大周再来给你们分享一点干货,记得点赞收藏起来哦! 一.涨粉慢变现好 这类领域有旅游领域.同城领域.美妆领域.美食领域.开箱领域,这几个领域并不需要你有多少粉丝,足 ...

  5. 在PyTorch中进行双线性采样:原理和代码详解

    ↑ 点击蓝字 关注视学算法 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/257958558 编辑丨极市平台 在pytorch中的双线性采样(Bilinear Sa ...

  6. [机器学习与scikit-learn-15]:算法-决策树-分类问题代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址: 目录 第1章 scikit-learn对决策树的支持 1.1 决策树的基本原理 1.2 决策树的 ...

  7. 基于蚁群算法(ACO)的函数寻优代码详解

    前言   蚁群算法与遗传算法一并属于启发式算法,其原理有一定的相似性.   蚁群算法的仿生原理可以这样举例:在不远处的地上有一块奶糖,这时候你用手放个蚂蚁在地上,在无其他因素影响的情况下,这只蚂蚁会爬 ...

  8. CNN(卷积神经网络)在视频动作分类中的应用

    简介 Large-scale Video Classification with Convolutional Neural Networks Fusion Method Multi-resolutio ...

  9. Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)

    Kaggle猫狗大战--基于Pytorch的CNN网络分类:数据获取.预处理.载入(1) 第一次写CSDN博客,之前一直是靠着CSDN学学代码,这次不得不亲自上场了,就想着将学习的过程都记录下来.新人 ...

最新文章

  1. 产销平衡的运输问题上机实验matlab_MATLAB实验上机练习(三)
  2. js中写java集合代码,JS实现JAVA的List功能
  3. 1 分钟记住 docker 镜像和容器常用基本命令
  4. 信息学奥赛一本通(1051:分段函数)
  5. Leetcode 1219.黄金矿工
  6. 深度学习TF—3.神经网络全连接层
  7. 3S基础知识:MapX应用教程—创建地图对象
  8. [转载]关于wm系统同步时ActiveSync出现85010014错误号的解决办法
  9. 2021年笔迹鉴定收费?江西南昌收费标准是什么?
  10. 《可以量化的经济学》凯恩斯主义与…
  11. 物联网全景动态图谱2.0|PaaS物联网平台汇总(上篇)
  12. 如何简单的管理API
  13. 20届春秋招数据分析面筋分享
  14. npm scripts
  15. weixuan -奥利给turtle
  16. 外网连接腾讯云mysql
  17. Script Insertion -客户端脚本植入攻击
  18. 面试系列——Java工作6年面试拼多多和阿里经历附带面试题
  19. 博贤科技管理系统漏洞oday
  20. Mysql-DQL基本语法

热门文章

  1. C 结构体嵌套一级指针 二级指针 动态分配内存
  2. 【Linux系统编程学习】 动态库的制作与使用
  3. Leetcode 102. 二叉树的层次遍历
  4. 学生档案c语言编程,学生档案管理问题
  5. 右键添加git-bash
  6. LeetCode-46. Permutations
  7. 微信小程序把玩(三十三)Record API
  8. WorldWind源码剖析系列:可渲染对象类RenderableObject
  9. c# datagridviewcomboboxcell值无效的解决办法
  10. 这是我们的第一篇博客----偕行软件