本文代码完全借鉴pytorch中文手册

'''我们找到数据集,对数据做预处理,定义我们的模型,调整超参数,测试训练,再通过训练结果对超参数进行调整或者对模型进行调整。'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim  #实现各种优化算法的库
from torchvision import datasets,transformsBATCH_SIZE=512  #大概需要2G的显存
EPOCHS=20       #总共训练20次
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu") #让torch判断是否使用GPU#对数据进行预处理
transforms=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
#准备数据集,路径是关于py文件的相对路径
trainset=datasets.MNIST(root='./MNIST_data',train=True,download=False,transform=transforms)#加载数据集
train_loader=torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,shuffle=True)#准备测试集
testset=datasets.MNIST(root='./MNIST_data',train=True,download=False,transform=transforms)#加载测试集
test_loader=torch.utils.data.DataLoader(testset,batch_size=BATCH_SIZE,shuffle=True)#定义卷积神经网络
class ConvNet(nn.Module):def __init__(self):super().__init__()#batch*1*28*28(每次会送入batch个样本,输入通道数1(黑白图像)),图像分辨率是28*28)#下面的卷积层Conv2d的第一个参数指输入通道数,第二个参数指输出通道数,第三个参数指卷积核的大小self.conv1=nn.Conv2d(1,10,5)self.conv2=nn.Conv2d(10,20,3)#下面的全连接层Linear的第一个参数指输入通道数,第二个参数指输出通道数self.fc1=nn.Linear(20*10*10,500) #输入通道数是2000,输出通道数是500self.fc2=nn.Linear(500,10)  #输入通道数是500,输出通道数是10,即10分类def forward(self,x):in_size=x.size(0)  #在本例中in_size=512,也就是BATCH_SIZE的值。输入的x可以看成是512*1*28*28的张量out=self.conv1(x)  #batch*1*28*28 -> batch*10*24*24(28×28的图像经过一次核为5×5的卷积,输出变为24×24)out=F.relu(out)    #batch*10*24*24(激活函数ReLU不改变形状)out=F.max_pool2d(out,2,2)#batch*10*24*24 -> batch*10*12*12(2×2的池化层会减半)out=self.conv2(out) #batch*10*12*12 -> batch*20*10*10(再卷积一次,核的大小是3)out=F.relu(out)out=out.view(in_size,-1) #batch*20*10*10 -> batch*2000(out的第二维是-1,说明是自动推算,本例中第二维是20*10*10)out=self.fc1(out)        #batch*2000 -> batch*500out=F.relu(out)         out=self.fc2(out)        #batch*500 -> batch*10out=F.log_softmax(out,dim=1) #计算log(softmax(x)),用log是为了防止数过大。return out#我们实例化一个网络,实例化后使用.to方法将网络移动到GPU
model=ConvNet().to(DEVICE)#优化器我们也直接选择简单暴力的Adam
optimizer=optim.Adam(model.parameters())#定义一个训练函数
def train(model,device,train_loader,optimizer,epoch):model.train() #启用BatchNormalization和Dropout,将BatchNormalization和Dropout置为Truefor batch_idex,(data,target) in enumerate(train_loader): ##将迭代器的数据组成一个索引系列,并输出索引和值,batch_diex是序号,后者是数据data,target=data.to(device),target.to(device)  #在gpu上跑optimizer.zero_grad() #梯度清零output=model(data)  #将数据放入模型loss=F.nll_loss(output,target)  #计算损失函数loss.backward()    #计算梯度optimizer.step()if (batch_idex+1)%30 == 0:  #每训练30个打印一次print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,batch_idex*len(data),len(train_loader.dataset),100. * batch_idex/len(train_loader.dataset),loss.item()))#定义一个测试函数
def test(model,device,test_loader):model.eval()  #不启用BatchNormalization和Dropout,将BatchNormalization和Dropout置为Falsetest_loss=0correct=0with torch.no_grad():for data,target in test_loader:data,target=data.to(device),target.to(device)  #在gpu上跑output=model(data)test_loss+=F.nll_loss(output,target,reduction='sum').item()  #将一批的损失相加pred=output.max(1,keepdim=True)[1] #找到概率最大的下标correct+=pred.eq(target.view_as(pred)).sum().item()test_loss/=len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss,correct,len(test_loader.dataset),100. * correct/len(test_loader.dataset)))#下面开始训练,这里就体现出封装起来的好处了,只要写两行就可以了
for epoch in range(1,EPOCHS+1):train(model,DEVICE,train_loader,optimizer,epoch)test(model,DEVICE,test_loader)

深度学习基础实战使用MNIST数据集对图片分类相关推荐

  1. 【深度学习】实战之MNIST

    既是实战,也是入门~ MNIST介绍 MNIST是深度学习领域的一个经典数据集,内含60000张训练图像与10000张预测图像,每张图片为28像素*28像素的灰度图像,并被划分到10个类别中(0-9) ...

  2. 深度学习4:使用MNIST数据集(tensorflow)

    本文将介绍MNIST数据集的数据格式和使用方法,使用到的是tensorflow中封装的类,包含代码. MNIST数据集来源于这里, 如果希望下载原始格式的数据集,可以从这里下载.而本文中讲解的是已经使 ...

  3. Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

    文章目录 一.概述 二.代码编写 1. 数据处理 2. 准备配置文件 3. 自定义DataSet和DataLoader 4. 构建模型 5. 训练模型 6. 编写预测模块 三.效果展示 四.源码地址 ...

  4. PyTorch深度学习项目实战100例数据集

    前言

  5. [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98%+

    [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98.8%+ 我们在博文,使用CNN做Kaggle比赛手写数字识别准确率99%+,在此基础之 ...

  6. 第3章(3.11~3.16节)模型细节/Kaggle实战【深度学习基础】--动手学深度学习【Tensorflow2.0版本】

    项目地址:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0 UC 伯克利李沐的<动手学深度学习>开源书一经推出便广受好评.很多开 ...

  7. 【PyTorch深度学习项目实战100例目录】项目详解 + 数据集 + 完整源码

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

  8. 深度学习基础之卷积神经网络

    摘要 受Hubel和Wiesel对猫视觉皮层电生理研究启发,有人提出卷积神经网络(CNN),Yann Lecun 最早将CNN用于手写数字识别并一直保持了其在该问题的霸主地位.近年来卷积神经网络在多个 ...

  9. TensorFlow 2.0深度学习案例实战

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 基于TensorFlow 2.0正式版, 理论与实战结合,非常适合入门学习! 这是一本面向人工 ...

最新文章

  1. ETH Zurich提出新型网络「ROAD-Net」,解决语义分割域适配问题
  2. 多个sphinx配置文件合并
  3. JAVA复习(date)
  4. 计算机会计综合作业,20年7月东财《通用财务软件X》综合作业(100分)
  5. JS框架设计读书笔记之-函数
  6. SQOOP 基础及安装
  7. ant centos环境下 编译没有将配置文件加载_Linux 下的动态库、静态库与环境变量...
  8. 硬盘分区变为RAW文件系统后的解决办法
  9. linux环境下给文件加密/解密的方法
  10. 运动计步app开发的功能分析
  11. python考试报名官网安徽_2019年3月安徽宿州学院全国计算机等级考试报名通知
  12. html 中各种鼠标手势
  13. 微软收购雅虎遇新难题 或遭中国反垄断法阻碍
  14. 微信订阅消息(后端)教程
  15. 文件的元数据信息的含义及查看和修改
  16. js对文字批注_HTML 页面添加批注 - JavaScript - ITeye
  17. android studio try catch自动生成,Android Studio:Try-catch异常崩溃了应用程序
  18. android截视频软件,裁剪切视频app
  19. c语言建立26个字母的顺序表,线性表的操作建立一个含26个英文字母的数据元素的线性表并输出该表 爱问知识人...
  20. 大数相乘 (模板)

热门文章

  1. rhq监控软件_RHQ指标的WildFly子系统
  2. TestNG中的参数化– DataProvider和TestNG XML(带有示例)
  3. AWS Lambda事件源映射:使您的触发器混乱无序
  4. Java –从列表中删除所有空值
  5. input发送a.jax_JAX-RS 2.0:自定义内容处理
  6. 用普罗米修斯和格拉法纳乐器来刺豪猪
  7. tomcat 正常关闭_Tomcat的带有守护程序和关闭钩子的正常关闭
  8. Nutshell中的Java 8语言功能-第1部分
  9. Java命令行界面(第22部分):argparser
  10. 为@Cacheable设置TTL – Spring