pytorch0.4版的CNN对minist分类
卷积神经网络(Convolutional Neural Network, CNN)是深度学习技术中极具代表的网络结构之一,在图像处理领域取得了很大的成功,在国际标准的ImageNet数据集上,许多成功的模型都是基于CNN的。
卷积神经网络CNN的结构一般包含这几个层:
- 输入层:用于数据的输入
- 卷积层:使用卷积核进行特征提取和特征映射
- 激励层:由于卷积也是一种线性运算,因此需要增加非线性映射
- 池化层:进行下采样,对特征图稀疏处理,减少数据运算量。
- 全连接层:通常在CNN的尾部进行重新拟合,减少特征信息的损失
- 输出层:用于输出结果
用pytorch0.4 做的cnn网络做的minist 分类,代码如下:
1 import torch 2 import torch.nn as nn 3 import torch.nn.functional as F 4 import torch.optim as optim 5 from torchvision import datasets, transforms 6 from torch.autograd import Variable 7 8 # Training settings 9 batch_size = 64 10 11 # MNIST Dataset 12 train_dataset = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True) 13 test_dataset = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor()) 14 15 # Data Loader (Input Pipeline) 16 train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True) 17 test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False) 18 19 class Net(nn.Module): 20 def __init__(self): 21 super(Net, self).__init__() 22 # 输入1通道,输出10通道,kernel 5*5 23 self.conv1 = nn.Conv2d(1, 10, kernel_size=5) # 定义conv1函数的是图像卷积函数:输入为图像(1个频道,即灰度图),输出为 10张特征图, 卷积核为5x5正方形 24 self.conv2 = nn.Conv2d(10, 20, kernel_size=5) # # 定义conv2函数的是图像卷积函数:输入为10张特征图,输出为20张特征图, 卷积核为5x5正方形 25 self.mp = nn.MaxPool2d(2) 26 # fully connect 27 self.fc = nn.Linear(320, 10) 28 29 def forward(self, x): 30 # in_size = 64 31 in_size = x.size(0) # one batch 32 # x: 64*10*12*12 33 x = F.relu(self.mp(self.conv1(x))) 34 # x: 64*20*4*4 35 x = F.relu(self.mp(self.conv2(x))) 36 # x: 64*320 37 x = x.view(in_size, -1) # flatten the tensor 38 # x: 64*10 39 x = self.fc(x) 40 return F.log_softmax(x,dim=0) 41 42 43 model = Net() 44 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) 45 46 def train(epoch): 47 for batch_idx, (data, target) in enumerate(train_loader): 48 data, target = Variable(data), Variable(target) 49 optimizer.zero_grad() 50 output = model(data) 51 loss = F.nll_loss(output, target) 52 loss.backward() 53 optimizer.step() 54 if batch_idx % 200 == 0: 55 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 56 epoch, batch_idx * len(data), len(train_loader.dataset), 57 100. * batch_idx / len(train_loader), loss.item())) 58 59 60 def test(): 61 test_loss = 0 62 correct = 0 63 for data, target in test_loader: 64 data, target = Variable(data), Variable(target) 65 output = model(data) 66 # sum up batch loss 67 #test_loss += F.nll_loss(output, target, size_average=False).item() 68 test_loss += F.nll_loss(output, target, reduction = 'sum').item() 69 # get the index of the max log-probability 70 pred = output.data.max(1, keepdim=True)[1] 71 correct += pred.eq(target.data.view_as(pred)).cpu().sum() 72 73 test_loss /= len(test_loader.dataset) 74 print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 75 test_loss, correct, len(test_loader.dataset), 76 100. * correct / len(test_loader.dataset))) 77 78 79 if __name__=="__main__": 80 for epoch in range(1, 4): 81 train(epoch) 82 test()
运行效果如下:
Train Epoch: 1 [0/60000 (0%)] Loss: 4.163342 Train Epoch: 1 [12800/60000 (21%)] Loss: 2.689871 Train Epoch: 1 [25600/60000 (43%)] Loss: 2.553686 Train Epoch: 1 [38400/60000 (64%)] Loss: 2.376630 Train Epoch: 1 [51200/60000 (85%)] Loss: 2.321894Test set: Average loss: 2.2703, Accuracy: 9490/10000 (94%)Train Epoch: 2 [0/60000 (0%)] Loss: 2.321601 Train Epoch: 2 [12800/60000 (21%)] Loss: 2.293680 Train Epoch: 2 [25600/60000 (43%)] Loss: 2.377935 Train Epoch: 2 [38400/60000 (64%)] Loss: 2.150829 Train Epoch: 2 [51200/60000 (85%)] Loss: 2.201805Test set: Average loss: 2.1848, Accuracy: 9658/10000 (96%)Train Epoch: 3 [0/60000 (0%)] Loss: 2.238524 Train Epoch: 3 [12800/60000 (21%)] Loss: 2.224833 Train Epoch: 3 [25600/60000 (43%)] Loss: 2.240626 Train Epoch: 3 [38400/60000 (64%)] Loss: 2.217183 Train Epoch: 3 [51200/60000 (85%)] Loss: 2.357141Test set: Average loss: 2.1426, Accuracy: 9723/10000 (97%)
转载于:https://www.cnblogs.com/www-caiyin-com/p/9955779.html
pytorch0.4版的CNN对minist分类相关推荐
- 基于CNN的MINIST手写数字识别项目代码以及原理详解
文章目录 项目简介 项目下载地址 项目开发软件环境 项目开发硬件环境 前言 一.数据加载的作用 二.Pytorch进行数据加载所需工具 2.1 Dataset 2.2 Dataloader 2.3 T ...
- TensorFlow实现多层感知机MINIST分类
TensorFlow实现多层感知机MINIST分类 TensorFlow 支持自动求导,可以使用 TensorFlow 优化器来计算和使用梯度.使用梯度自动更新用变量定义的张量.本文将使用 Tenso ...
- 使用CNN做文本分类——将图像2维卷积换成1维
使用CNN做文本分类from __future__ importdivision, print_function, absolute_importimporttensorflow as tfimpor ...
- CNN在文本分类的应用(内有代码实现) 论文Convolutional Neural Networks for Sentence Classification
一.CNN文本分类简介 文本分类是NLP领域的一个重要子任务,文本分类的目标是自动的将文本打上已经定义好的标签,常见的文本分类任务有: 用户评论的情感识别 垃圾邮件过滤 用户查询意图识别 新闻分类 由 ...
- 文本分类(下) | 卷积神经网络(CNN)在文本分类上的应用
正文共3758张图,4张图,预计阅读时间18分钟. 1.简介 原先写过两篇文章,分别介绍了传统机器学习方法在文本分类上的应用以及CNN原理,然后本篇文章结合两篇论文展开,主要讲述下CNN在文本分类上的 ...
- 【论文复现】使用CNN进行文本分类
今天要写的是关于NLP领域的一个关键问题:文本分类. 相对应的论文是:Convolutional Neural Networks for Sentence Classification 参考的博客为: ...
- 论文复现:用 CNN 进行文本分类
前一篇文章中我们学习了 CNN 的基础结构,并且知道了它是计算机视觉领域的基础模型,其实 CNN 不仅仅可以用于计算机视觉,还可以用于处理自然语言处理问题,今天就来看看如何用 CNN 处理文本分类任务 ...
- linux tf2 中文,tf2+cnn+中文文本分类优化系列(2)
1 前言 接着上次的tf2+cnn+中文文本分类优化系列(1),本次进行优化:使用多个卷积核进行特征抽取.之前是使用filter_size=2进行2-gram特征的识别,本次使用filter_size ...
- CNN对句子分类(tensorflow)
卷积神经网络是一种特殊的深层的神经网络模型,它的特殊性体现在两个方面,一方面它的神经元间的连接是非全连接的, 另一方面同一层中某些神经元之间的连接的权重是共享的(即相同的).它的非全连接和权值共享的网 ...
最新文章
- 第三十二课.脉冲神经网络SNN
- 苹果数据线不能充电_20亿个不能用的苹果充电器,库克,你的这波强制“环保”翻车了...
- Adtran加入SDN大潮,剑指运营商SDN转型
- AMQP Connection 127.0.0.1:5672] ERROR [o.s.a.rabbit.connection.CachingConnectionFactory] CachingConn
- 历史上的今天:阿帕网退役;Quintus 收购 Mustang;同步电流磁芯存储器获得专利...
- NodeJs——子进程
- MATLAB的sum函数
- android矢量图 内存大,Android内存控制小技巧-使用矢量图来节省你的内存并简化你的开发。...
- Python中字符的匹配
- AdaBoost算法实例详解
- 【OpenCV 例程200篇】220.对图像进行马赛克处理
- SpringBoot添加压力测试
- 圣诞要到了~教你用Python制作一个表白神器——照片墙,祝你成功
- 电脑上不去无线网如何解决
- oracle cpu使用率高怎么排查解决,OracleCPU占用率较高的处理方法
- android 手机震动功能吗,Android编程实现手机震动功能的方法
- 金庸的「射雕三部曲」,其实还有一个隐藏的第一部
- Python:dbus监控U盘插拔
- android畅言作业平台,畅言作业平台学生端
- vertical-align 垂直对齐方式
热门文章
- 短信发送:webservice调用第三方接口发送短信
- 自己写的计算时间坐标的代码
- 为什么EClipse不显示错误
- STRUTS模拟试题
- mybatis 知识1
- python实现轨迹回放供应_运动轨迹回放 百度地图api示例源码
- 攻防演练中的业务逻辑漏洞及检测思路
- 【Numpy学习记录】np.transpose讲解
- python离线安装国内镜像OpenCV
- 运行报错error: (-215:Assertion failed) !ssize.empty() in function 'cv::resize'