skimage教程
skimage开发文档
机器学习新手工程师常犯的6大错误http://www.duozhishidai.com/article-12200-1.html
深度学习需掌握的知识https://blog.csdn.net/duozhishidai/article/details/87301056
从数据集的建立,到模型的建立,到训练,到预测!!!
中间遇到很多坑
数据建立:MSELoss损失函数要求独热编码,CrossEntropyLosss损失函数要求预测端是字符型,label要求长整型,函数将label转成独热编码
模型建立:全连接时需将输入的batch_size与其它维度共4维转成2维
训练:没什么好说的
预测:也没什么好说的,总之实现对输入图片的预测功能

附代码:
训练代码,其中数据集解码函数直接复制的

import cv2
import numpy as np
import struct
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable
import torch
from torch.utils.data import Dataset,DataLoader,TensorDatasetclass Net(nn.Module):def __init__(self):super(Net,self).__init__()body = []body.append(nn.Conv2d(1,16,3))body.append(nn.BatchNorm2d(16))body.append(nn.ReLU(True))body.append(nn.Conv2d(16,32,3))body.append(nn.BatchNorm2d(32))body.append(nn.ReLU(True))body.append(nn.MaxPool2d(kernel_size=2,stride=2))body.append(nn.Conv2d(32,64,3))body.append(nn.BatchNorm2d(64))body.append(nn.ReLU(True))body.append(nn.Conv2d(64,128,3))body.append(nn.BatchNorm2d(128))body.append(nn.ReLU(True))body.append(nn.MaxPool2d(kernel_size=2,stride=2))tail = []tail.append(nn.Linear(128*4*4,1024))tail.append(nn.ReLU(True))tail.append(nn.Linear(1024,128))tail.append(nn.ReLU(True))tail.append(nn.Linear(128,10))self.body = nn.Sequential(*body)self.tail = nn.Sequential(*tail)def forward(self,x):ret = self.body(x)#手动四维转二维ret = ret.view(ret.size(0),-1)ret = self.tail(ret)return retdef decode_train(idx3_ubyte_file):"""解析idx3文件的通用函数:param idx3_ubyte_file: idx3文件路径:return: 数据集"""# 读取二进制数据bin_data = open(idx3_ubyte_file, 'rb').read()# 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽offset = 0fmt_header = '>iiii'magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)print ('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))# 解析数据集image_size = num_rows * num_colsoffset += struct.calcsize(fmt_header)fmt_image = '>' + str(image_size) + 'B'images = np.empty((num_images, num_rows, num_cols))for i in range(num_images):if (i + 1) % 10000 == 0:print ('已解析 %d' % (i + 1) + '张')images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))offset += struct.calcsize(fmt_image)return imagespath_train = 'train-images.idx3-ubyte'
image_list = decode_train(path_train)def decode_label(idx1_ubyte_file):"""解析idx1文件的通用函数:param idx1_ubyte_file: idx1文件路径:return: 数据集"""# 读取二进制数据bin_data = open(idx1_ubyte_file, 'rb').read()# 解析文件头信息,依次为魔数和标签数offset = 0fmt_header = '>ii'magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)print ('魔数:%d, 图片数量: %d张' % (magic_number, num_images))# 解析数据集offset += struct.calcsize(fmt_header)fmt_image = '>B'labels = np.empty(num_images)for i in range(num_images):if (i + 1) % 10000 == 0:print ('已解析 %d' % (i + 1) + '张')labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]offset += struct.calcsize(fmt_image)return labelspath_label = 'train-labels.idx1-ubyte'
label_list = decode_label(path_label)def data_loader(images,labels):#独热编码第二个参数需要对其进行转置# labels = labels.reshape(labels.shape[0],1)images = torch.from_numpy(images)labels = torch.from_numpy(labels)# labels_map = torch.zeros((labels.shape[0],10))#独热编码,函数第1个参数是轴,用于定位1是定位列,第二个参数是位置,具体那一列,第三个参数是给予该位置赋值# labels_map = labels_map.scatter_(1,labels.long(),1)dataset = TensorDataset(images,labels)return Data.DataLoader(dataset,batch_size=100,shuffle=True,num_workers=1)def train(data):net = Net()optimizer = torch.optim.SGD(net.parameters(),lr=0.001)loss_function = torch.nn.CrossEntropyLoss()for  i in range(20):for item in data:data_x = Variable(item[0]).float().unsqueeze(0).view(100,1,28,28)data_y = Variable(item[1])prediction = net(data_x)loss = loss_function(prediction,data_y.long())optimizer.zero_grad()loss.backward()optimizer.step()print(loss)torch.save(net,'minist.pkl')def main():data = data_loader(image_list,label_list)train(data)if __name__=='__main__':main()

预测代码:

import cv2
import numpy as np
import torch
from minis import Net
from torch.autograd import Variabledef main():model_path = 'minist.pkl'net = torch.load(model_path)image_path = r'../test1.jpg'image = cv2.imread(image_path,0)image = cv2.resize(image,(28,28))image = cv2.threshold(image,156,255,cv2.THRESH_BINARY_INV)#thres = cv2.threshold(image,127,255,cv2.THRESH_BINARY)image = image[1]thres = torch.from_numpy(np.asarray(image))thres = thres.unsqueeze(0).unsqueeze(0)thres = Variable(thres)prediction = net(thres.float())loss_function = torch.nn.CrossEntropyLoss()min_loss = 100pre = Nonefor i in range(10):val = np.asarray([i])val = Variable(torch.from_numpy(val))loss = loss_function(prediction,val.long())if min_loss>loss:min_loss = losspre = iprint(pre)if __name__=='__main__':main()

深度学习之torch(一)MINIST手写字符分类相关推荐

  1. 深度学习导论(5)手写数字识别问题步骤

    深度学习导论(5)手写数字识别问题步骤 手写数字识别分类问题具体步骤(Training an handwritten digit classification) 加载数据 显示训练集中的图片 定义神经 ...

  2. [Python人工智能] 三十.Keras深度学习构建CNN识别阿拉伯手写文字图像

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN). ...

  3. [Python图像识别] 四十七.Keras深度学习构建CNN识别阿拉伯手写文字图像

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  4. Keras深度学习实战(37)——手写文字识别

    Keras深度学习实战(37)--手写文字识别 0. 前言 1. 手写文字识别相关背景 1.1 Connectionist temporal classification (CTC) 1.2 解码 C ...

  5. 深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

    大家好,我是微学AI,今天给大家带来手写OCR识别的项目.手写的文稿在日常生活中较为常见,比如笔记.会议记录,合同签名.手写书信等,手写体的文字到处都有,所以针对手写体识别也是有较大的需求.目前手写体 ...

  6. 深度学习 第三章 tensorflow手写数字识别

    深度学习入门视频-唐宇迪 (笔记加自我整理) 深度学习 第三章 tensorflow手写数字识别 1.tensorflow常见操作 这里使用的是tensorflow1.x版本,tensorflow基本 ...

  7. Java软件研发工程师转行之深度学习(Deep Learning)进阶:手写数字识别+人脸识别+图像中物体分类+视频分类+图像与文字特征+猫狗分类

    本文适合于对机器学习和数据挖掘有所了解,想深入研究深度学习的读者 1.对概率基本概率有所了解 2.具有微积分和线性代数的基本知识 3.有一定的编程基础(Python) Java软件研发工程师转行之深度 ...

  8. 深度学习之基于GAN实现手写数字生成

    在弄毕设的时候,室友的毕设是基于DCGAN实现音乐的自动生成.那是第一次接触对抗神经网络,当时听室友的描述就是两个CNN,一个生成一个监测,在互相博弈. 最近我关注的一个大神在弄有关于GAN的东西,所 ...

  9. 手把手写深度学习(8)——用LSTM生成手写英文文章

    前言:本系列前文介绍了用GANs生成手写数字,生成手写数字的任务是一件非常简单.入门的事情,因为MNIST数据集提供的,像素点非常低,最后生成的效果也非常模糊.要知道,高分辨率的生成一直是深层生成问题 ...

  10. tensorflow 中文字体训练集_深度学习与TensorFlow:自建手写字体数据集上的模型测试...

    在上一篇文章中,我们使用mnist数据集去做了一个识别的小型神经网络,在今天的这篇文章里,我们将要通过使用自建数据集去检验上一篇文章的模型,从而真正的可以去应用神经网络. 先解决上一篇文章中一些不完美 ...

最新文章

  1. Windows server 2008 远程桌面建立
  2. eureka之EurekaInstanceConfig接口的作用
  3. 笔记本windows7设置WIFI教程(超详细)
  4. 微信小程序云开发之云函数的创建与环境配置
  5. 将MfgTool工具改造为自己的烧写工具
  6. 华栖云科技图形图像视音频算法岗面试经验
  7. 简单的BBcode parsing
  8. android开发的学习路线
  9. [e袋购APP]高校物业管理的特点
  10. node mysql 坑_菜鸟Node.js MySQL教程遇到的坑
  11. c++实验题:设计两个酒店管理员客房管理的类:一个是Person类,要求储存房号、客户姓名和身份证号的信息;另一个类是Client类,要求新增客户的订房、退房和消费金额等信息,并给出相关测试算法。
  12. 5面阿里,终获offer(Java后端)
  13. (转)国企,私企与外企利弊通观--关键时刻给应届毕业生及时点拨5
  14. MIUI13来了,米粉们还期待吗?
  15. 使用layer.open打开自定义弹窗,获取表单内容发送到后端
  16. 车联网是什么_车联网有什么用_车联网功能介绍
  17. 将图片和文字写到pdf文件中
  18. Android版本Oppo电视,oppo电视助手app下载
  19. 大牧絮叨设计模式:工厂方法模式
  20. TP-Link WR703N升级64M内存+外接SMA天线+刷OpenWRT(2)升级内存

热门文章

  1. 刺客信条3免uplay破解补丁
  2. 车牌拍照系统上传服务器,服务器端车牌拍照识别
  3. Java性能优化的35种方法
  4. Java面向对象编程三大特征-多态
  5. 【ffplay播放器】ffplay 播放器整体架构
  6. 单片机学习(四)——ESP8266(最全教程和说明)
  7. 设计原则:单一职责(SRP)原则
  8. 500G JAVA视频网盘分享 JEECG开源社区
  9. Java JDK 动态代理(AOP)使用及实现原理分析
  10. php中wamp具体指的是,phpwamp和wampserver有什么关系,为什么名字都带wamp这几个字母,两者有区别吗?我应该用哪个?...