一. 数据集介绍

  1. 数据集可以在kaggle上面下载地址.
  2. 识别的手势是26个英文字母,如下所示,图片中好像缺少了一个Z。
  3. 数据是csv格式的,第一列是label,其余的像素,也就是我们的图片数据,28*28大小的。

二. 构造DataLoader

  1. 由于数据是csv格式,和之前图片的有所不同,之前直接重写Dataset类就可以加载我们的数据,这次其实可以重写Dataset类,但是我这样做的时候,训练的时候准确率一直只有0.07左右,这次我们需要使用另外一种比较简单的方法,来构造数据集

  2. 利用TensorDataset函数将tensor数据变成TensorDataset数据,也就是将数据变成pytorch可以分批次加载的数据。

train = pd.read_csv('../input/sign-language-mnist/sign_mnist_train.csv',dtype=np.float32)    #读取csvy = train.label.values                           #label
x = train.loc[:,train.columns!= 'label'].values / 255     #图片数据train_x,test_x,train_y,test_y = train_test_split(x,y,test_size=0.2,random_state=42)     #分成数据集,测试集
print(test_x.shape)
train_x = torch.from_numpy(train_x)          #转换成tensor
train_y = torch.from_numpy(train_y).type(torch.LongTensor)    #转换成Longtensor,交叉熵损失函数的label需要long类型的数据test_x = torch.from_numpy(test_x)
test_y = torch.from_numpy(test_y).type(torch.LongTensor)train = TensorDataset(train_x,train_y)     #转换成Dataset,类似于TensorFlow的from_tensor_slices
test = TensorDataset(test_x,test_y)train_loader = DataLoader(train,batch_size=100,shuffle=False,drop_last=True)  # 构造DataLoader
test_loader = DataLoader(test,batch_size=100,shuffle=False,drop_last=True)

三. 构造网络

  1. 构造网络这比较简单,由于没有合适的GPU,只能使用一些比较简单的网络。
  2. 这里也可以使用残差网络,随便学一下什么是残差网络,残差网络在很深的网络中经常用到,非常重要的一个网络结构。
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.layers1 = nn.Sequential(nn.Conv2d(1,16,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(16),nn.ReLU(inplace=True))self.layers2 = nn.Sequential(nn.Conv2d(16,32,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2))self.layers3 = nn.Sequential(nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.layers4 = nn.Sequential(nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),)self.fc = nn.Sequential(nn.Linear(7*7*128,1024),nn.ReLU(inplace=True),nn.Linear(1024,100),nn.ReLU(inplace=True),nn.Linear(100,26))def forward(self, x):x = self.layers1(x)x = self.layers2(x)x = self.layers3(x)x = self.layers4(x)x = x.view(x.size(0),-1)x = self.fc(x)return x

四. 训练网络

  1. 在训练网络的时候,这次我们分成了训练集合测试集,和之前写有些不一样,但也都差不多,
  2. 每训练50个epoch,就在测试集上测试一下准确率,这样可以明确模型的好与坏,精度如何。
  3. 最后如果准确率大于0.983后,就退出训练,训练太久而准确率没有太多的提升,还不如提前结束训练。
error = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(cnn.parameters(),lr=0.1)for epochs in range(100):for i ,(img,label) in enumerate(train_loader):img = img.view(100,1,28,28)img = Variable(img)label = Variable(label)optimizer.zero_grad()output = cnn(img)loss = error(output,label)loss.backward()optimizer.step()if i %50 == 0:    #测试模型accuracy = 0for x,y in test_loader:x = x.view(100,1,28,28)x = Variable(x)out = cnn(x)pre = torch.max(out.data,1)[1]accuracy += (pre == y).sum()print('accuracy:',accuracy.item()/(5491))if(accuracy.item()/(5491)>0.975):torch.save(cnn.state_dict(),'model.pth')if(accuracy.item()/5491 > 0.983):break

最后打印一下训练的log,训练到最后准确率基本上没有变化了。

完整代码可在GitHub上面下载地址.

Thank for your reading !!!

Pytorch手势识别相关推荐

  1. 【PyTorch深度学习项目实战100例】—— Python+OpenCV+MediaPipe手势识别系统 | 第2例

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

  2. Pytorch+opencv 手势识别

    配置 opencv安装 使用清华源 pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python opencv的使用 cv ...

  3. 树莓派实时(30fps)手势识别,从数据集采集开始,全部流程开源

    目录结构 1.背景介绍 2.数据采集 3.网络设计 4.网络训练 5.网络部署 6.总结 1.背景介绍 最近采购了一块新的树莓派,迫不及待的想要在树莓派上实现一个实时的手势识别.从算法的角度讲,并不是 ...

  4. Jetson nano (4GB B01) 系统安装,官方Demo测试 (目标检测、手势识别)

    Jetson nano (4GB B01) 系统安装,官方Demo测试 (目标检测.手势识别) 此文确保你可以正确搭建jetson nano环境,并跑通官方"hello AI world&q ...

  5. 计算机视觉研究院手把手教你深度学习的部署(手势识别,源码已开源)

    计算机视觉研究院专栏 作者:Edison_G 今天我们继续基于姿态估计的运动计数APP开发! 公众号ID|ComputerVisionGzq 学习群|扫码在主页获取加入方式 关注并星标 从此不迷路 计 ...

  6. 【目标检测】你想知道的手势识别都在这里 【YOLO】网络

    基于YOLO+ResNet50的手势识别 目录 基于YOLO+ResNet50的手势识别 写在前面 (一)项目背景以及系统环境 1.1 项目背景 1.2 硬件环境 1.3 操作系统 1.4 主要界面 ...

  7. 基于CNN的动态手势识别:Real-time Hand Gesture Detection and Classification Using Convolutional Neural Networks

    Real-time Hand Gesture Detection and Classification Using Convolutional Neural Networks论文解读 1. 概述 2. ...

  8. 人工智能AI:TensorFlow Keras PyTorch MXNet PaddlePaddle 深度学习实战 part1

    日萌社 人工智能AI:TensorFlow Keras PyTorch MXNet PaddlePaddle 深度学习实战 part1 人工智能AI:TensorFlow Keras PyTorch ...

  9. 初学入门YOLOv5手势识别之制作并训练自己的数据集

    随着短视频vlog时代的到来,自动驾驶技术.人脸识别门禁系统.智慧视频监控.AI机器人等贴近人们日常生活的视频信息量的暴增,视频目标检测的研究具有无比的现实研究意义与未来行业潜力.视频是由一系列具有时 ...

最新文章

  1. linux的简单面试题,收集的一些简单的UNIX/Linux面试题
  2. 【linux回炉 档案权限与目录配置】
  3. JAVA笔记18-容器之二增强的for循环(不重要)
  4. opencv画框返回坐标 python_python opencv鼠标事件实现画框圈定目标获取坐标信息
  5. springboot读取自定义properties文件
  6. 在Apache配置反向代理即实现输出内容替换
  7. 【python数字信号处理】——循环卷积(也叫圆圈卷积)
  8. 【Python CheckiO 题解】Bigger Price
  9. 汇编jnl_汇编指令集
  10. 作者:​邓波(1973-),男,博士,北京系统工程研究所研究员。
  11. 数据统计告诉你,程序员是不是35岁就退休
  12. Create an offline installation of Visual Studio 2017 RC
  13. 教你在CentOS 8上安装和配置Redmine项目管理系统
  14. 看微软“第四代模块化数据中心”宣传片之后的思考
  15. Leetcode799. 香槟塔
  16. windows批处理:start的用法
  17. “自动修复”无法修复你的电脑-SATAFIRM S11-固态硬盘坏了
  18. 使用百度在某个网站内进行搜索
  19. 我的世界服务器头像文件,端游我的世界怎么换头像,端游我的世界怎么换头像框...
  20. 鼻子上爱出油,还有黑头、粉刺怎么办???

热门文章

  1. IT人物之搜狗公司COO茹立云 听学霸分享成长故事
  2. 记住 逆境并不是尽头 而是更好的结果的一个转角而已。
  3. 实现一个安卓学习助手app
  4. 简单工厂模式--女娲造人造啥做啥
  5. 微信公众号开发---机器人
  6. DC初级摄友必学摄影技巧(转贴) 1
  7. Oracle11gR2(二)-图形安装
  8. 网络图片URL转化为Bitmap对象
  9. 有哪些十分惊艳的书值得推荐?
  10. html制作满天星,干花满天星如何制作