因为临近期末,老师让我们做一个期末项目,因为平时上课都在划水,所以这也是我第一个自己写的深度学习项目,写上博客留作纪念。

数据集是这样的

总共有7000张图片

首先先导入相关包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import cv2

定义超参数`

BATCH_SIZE=32 #定义超参数,每次处理32张图片
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu" )#检测电脑上是否有GPU,如果有就使用GPU,如果没有就使用cpu
EPOCHS=20 #将数据集训练20轮

对图片做预处理

Pipline=transforms.Compose([transforms.ToTensor( ),#将图片转化为tensortransforms.Normalize((0.1307),(0.3081))#降低模型复杂度,官网提供的数据]
)#对图像作相应处理

将数据集按照3:1划分训练集和测试集并分到不同文件夹中

import shutil
PATH='C:\\Users\\admin\\Desktop\\deeplearning\\PalmBigDataBase\\'
for i in range(1,386):for j in range(1,10):if j%3!=0:shutil.move(PATH+'P_F_'+str(i)+'_'+str(j)+'.bmp','C:\\Users\\admin\\Desktop\\train') #将一个文件夹下的图片分到不同文件夹elif j%3==0:shutil.move(PATH+'P_F_'+str(i)+'_'+str(j)+'.bmp','C:\\Users\\admin\\Desktop\\test')

加载数据集

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import numpy as np
import os
import re
import cv2
train_set=[]
train_label=[]
test_set=[]
train_dataset=[]
test_label=[]
test_dataset=[]
for filename in os.listdir('C:\\Users\\admin\\Desktop\\train'):imgPath='C:\\Users\\admin\\Desktop\\train'img=cv2.imread(imgPath+'\\'+filename,0)#循环读取该目录下的所有图片train_label.append(int(re.findall(r"\d+",filename)[0]))#将图片名的第一个数字作为该图片的类train_set.append(Pipline(img).numpy().tolist())#对图片作pipline操作后转化为数组
train_dataset=TensorDataset(torch.tensor(train_set),torch.tensor(train_label))#对两个列表进行压缩后作为训练集
for filename in os.listdir('C:\\Users\\admin\\Desktop\\test'):imgPath='C:\\Users\\admin\\Desktop\\test'img=cv2.imread(imgPath+'\\'+filename,0)test_label.append(int(re.findall(r"\d+",filename)[0]))test_set.append(Pipline(img).numpy().tolist())
test_dataset=TensorDataset(torch.tensor(test_set),torch.tensor(test_label))#对两个列表进行压缩后作为测试集
TrainLoader=DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)#加载训练集
TestLoader=DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=True)#加载测试集

关于TensorDataset函数,这篇博客讲的很详细,看了就能懂
https://blog.csdn.net/qq_40211493/article/details/107529148

构建cnn网络

class net(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(1,10,5)#进行2d卷积,因为是灰度图像,所以初试通道是1。输入通道为1,输出通道是10,进行5*5的2d卷积self.conv2=nn.Conv2d(10,20,3)self.fc1=nn.Linear(20*60*60,500)self.fc2=nn.Linear(500,386)#要分386类def forward(self,x):input_size=x.size(0) #batch_size)x=self.conv1(x) # 输入:batch *1*128*128 输出 batch *10*124*124x=F.relu(x) #size保持不变x=F.max_pool2d(x,2,2) #输入 batch *10*124*124 输出batch*10*62*62,对图片进行压缩,减少运算。x=self.conv2(x) #输入 batch*10*62*62 输出 batch*20*60*60x=F.relu(x)#激活层,不改变图片的shape,每次卷积之后进行一次激活,输出一个非线性函数,增强神经元的表达能力x=x.view(input_size,-1) #将图片转化为一维线性x=self.fc1(x)#输入 batch*30*60*60 输出 batch*500x=F.relu(x)x=self.fc2(x)#输入 batch*500 输出 batch*386out_put=F.log_softmax(x,dim=1)#计算损失函数,输出概率最大的类别 return out_put

创建优化器

model=net().to(DEVICE)#创建模型,将模型部署到设备上
optimizer=optim.Adam(model.parameters())#对参数进行优化

对模型进行训练和测试

def train_model(model,optimizer,epoch,device,TrainLoader):model.train()#模型训练for (batch_index ,data) in enumerate(TrainLoader):x_data,label=datax_data =x_data.to(device)label =label.to(device)optimizer.zero_grad()#梯度初始化为0 output=model(x_data)loss=F.cross_entropy(output,label)#计算损失loss.backward()optimizer.step()if batch_index %300 ==0:print("Train Epoch:{} \t Loss:{:.6f}".format(epoch,loss.item()))
def test(model,device,TestLoader):model.eval()  #模型验证correct=0.0  #正确率test_loss=0.0  #测试损失with torch.no_grad():  #不会计算梯度,也不会进行反向传播for(batch_index ,data) in enumerate(TestLoader):y_data,label=datay_data=y_data.to(device)  #部署到DEVICE上label =label.to(device)output=model(y_data) #测试数据test_loss+=F.cross_entropy(output,label).item()  #计算测试损失pred=output.max(1,keepdim=True)[1]  #[0]值  [1]索引  找到概率值最大的下标
#           pred=torch.max(output,dimm=1)
#           pred=output.argmax(dim=1)correct+=pred.eq(label.view_as(pred)).sum().item()  #累计正确的值test_loss/=len(TestLoader.dataset)print("Test -- Average loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss,100.0*correct/len(TestLoader.dataset)))

训练

for epoch in  range(1,EPOCHS+1):train_model(model,optimizer,epoch,DEVICE,TrainLoader)test(model,DEVICE,TestLoader)

运行结果展示

祝大家天天开心,万事如意

基于CNN对掌纹图片进行分类相关推荐

  1. 初次跑CNN进行掌纹识别遇到的问题

    本意是想大概跑一下CNN对掌纹进行识别分类的代码,了解一下流程和框架.基本内容参考基于CNN对掌纹图片进行分类. 1-1.RuntimeError: mat1 and mat2 shapes cann ...

  2. 掌纹与掌静脉融合matlab代码,一种基于人脸和掌纹掌静脉识别的身份比对方法与流程...

    本发明涉及生物识别技术领域,具体为一种基于人脸和掌纹掌静脉识别的身份比对方法. 背景技术: 生物特征是指人体所固有的生理特征或行为特征,生理特征包括指纹.人脸.虹膜.掌静脉等,行为特征有声纹.步态以及 ...

  3. Opencv-Python提取掌纹图片ROI

    参考文章:基于图像相位及方向特征的掌纹识别的c++实现(二) - 知乎 (zhihu.com) 代码已经通过测试,可以跑通. 完整源码如下: import math import cv2 import ...

  4. python基于opencv工具掌纹主线提取

    我们将在这篇文章中使用Python和OpenCV库来找出我们手掌中的主线. 首先,让我们读取原始图像: import cv2 image = cv2.imread("palm.jpg&quo ...

  5. CNN简单实战:PyTorch搭建CNN对猫狗图片进行分类

    在上一篇文章:CNN训练前的准备:PyTorch处理自己的图像数据(Dataset和Dataloader),大致介绍了怎么利用pytorch把猫狗图片处理成CNN需要的数据,今天就用该数据对自己定义的 ...

  6. 基于深度学习的近红外掌纹识别原型系统设计与实现

    基于深度学习的近红外掌纹识别原型系统设计与实现 一.绪论 二.深度学习知识 三.Tensorflow 四.卷积神经网络 五.掌纹识别理论 掌纹图像采集 掌纹图像预处理 掌纹特征提取 掌纹特征匹配 掌纹 ...

  7. CNN非接触掌纹识别改进过程(二)

    1. 学习率衰减 这是之前在实习的时候进行的图像处理课题提供的思路. 学习率衰减策略可以实现两个目的:在训练初期较高的学习率使得网络模型更快的梯度下降,训练后期较小的学习率使得算法更容易收敛到最优值. ...

  8. 掌纹与掌静脉融合matlab代码,手形、掌纹和掌静脉多特征融合识别

    1引言当前,众多研究学者在基于手部特征的生物特征识别技术进行了大量的研究,包括指纹识别[1-2].手形识别[3-4].掌纹识别[5-6].静脉识别[7-8]等,在理论研究和实际系统应用都取得了一定的成 ...

  9. 嵌入式掌纹掌脉识别门禁控制系统的设计与实现

    本篇博文是对中科大硕士论文<嵌入式掌纹掌脉识别门禁控制系统的设计与实现>的总结. 1国内外研究进展 有以ARM为处理器的应用,目前大多采用三星的S3C2410和S3C2440,同时搭载FL ...

最新文章

  1. 机翻降重?掩饰抄袭?SCI期刊上的这些「奇言怪语」,不少来自中国作者
  2. Atitit.js跨域解决方案attilax大总结 后台java php c#.net的CORS支持
  3. IIS环境下如何批量添加、修改、删除绑定的域名
  4. shopify在哪里填写html,[Shopify开店教程]添加嵌入代码
  5. 前端面试之前要准备的那些事
  6. Redis源码阅读-Adlist双向链表
  7. Discuz = 7.2 SQL注入漏洞详情
  8. 数字电子技术基础(四):门电路(二极管)
  9. 电脑编程就业找哪方面
  10. 网卡驱动程序igb和ixgbe
  11. 机器学习-马尔可夫模型与隐马尔可夫模型
  12. pci 1751 java_PCI-1751快速安装使用手册.PDF
  13. 利用Vue制作一个商品管理页面(第二部分,小完结)
  14. 汤姆猫代码python_IOS 汤姆猫核心代码
  15. 在线会议中人脸面部轮廓图像提取(三)——Dlib库人脸面部轮廓图像特征提取
  16. 出差忘带电脑脑袋炸裂?鼓捣了下个人云,真香
  17. H.265的各种帧(详解):接入图像
  18. 如何有效的阅读一本书
  19. Raspbian 教学系统安装、配置流程
  20. 请问怎么设置默认浏览器

热门文章

  1. 云贝餐饮连锁V2 v2.5.6 外卖/店内/预定/排号 餐饮外卖扫码点餐 智慧新零售
  2. yaml文件的语法及注意事项
  3. 使用西瓜视频xgplayer播放MP4、m3u8、flv(直播、点播);videojs
  4. 分享一个免费清理苹果电脑Mac磁盘空间方法
  5. mysql对韵母分组,音的分组教案
  6. H.266/VVC技术学习:色度联合编码(JCCR)技术
  7. 帮别人开车,交通肇事应负怎样的刑事责任?
  8. BC2.0 以太坊应用技术交流01(不炒币,不传销)
  9. 手把手教你如何在湾区买房子比别人低13万
  10. 扎心了,程序员的2017到2019