import timeimport torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torchvision import datasets, transforms"""版本:    python==3.7    pytorch==1.6"""# 超参LR = 0.001  ## 学习效率NUM_EPOCH = 2  ## 训练次数BATCH_SIZE = 50  ## 每批次训练 50 条数据device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  ## 使用 GPU 加速## 数据集准备train_dataset = datasets.MNIST(    root='./dataset',  ## 数据储存位置    train=True,  ## True 为训练集    transform=transforms.ToTensor(),    download=False  ## 是否下载 MNIST 数据集(我已经下好了),没下载的写 True)train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)test_dataset = datasets.MNIST(    root='./dataset',    train=False,  ## False为测试集    transform=transforms.ToTensor(),    download=False)test_data = torch.unsqueeze(test_dataset.data, dim=1).type(torch.FloatTensor)[:8000].to(device) / 255.  ## 测试集数据与训练集数据维度保持一致test_label = test_dataset.targets.to(device)[:8000]val_data = torch.unsqueeze(test_dataset.data, dim=1).type(torch.FloatTensor)[8000:].to(device) / 255.  val_label = test_dataset.targets.to(device)[8000:]## 搭建 CNN 神经网络class CNN(nn.Module):    def __init__(self):        super(CNN, self).__init__()        self.conv1 = nn.Sequential(            nn.Conv2d(              ## 卷积层 1                in_channels=1,      ## 输入图片维度为 1 * 28 * 28                out_channels=16,    ## 输出图片维度为 16 * 28 * 28                kernel_size=5,                 stride=1,                padding=2            ),            nn.ReLU(),            nn.MaxPool2d(kernel_size=2) ## 池化层  经池化操作后图片维度变为 16 * 14 * 14        )        self.conv2 = nn.Sequential(            nn.Conv2d(                in_channels=16,                out_channels=32,    # 32 * 14 * 14                kernel_size=5,                stride=1,                padding=2            ),            nn.ReLU(),            nn.MaxPool2d(kernel_size=2)# 32 * 7 * 7        )        self.out = nn.Linear(32*7*7, 10)  ## 全连接层    def forward(self, x):        x = self.conv1(x)        x = self.conv2(x)        x = x.view(x.size()[0], -1)        out = self.out(x)        return outcnn = CNN().to(device)  ## 将模型移植到 GPU 上loss_func = nn.CrossEntropyLoss()  ## 损失函数optimizer = optim.Adam(cnn.parameters(), lr=LR)  ## 优化器 ## 输出信息格式def print_info(num, str, line=True, start=True, end=True):    if line:        str = "-"*5 + " " + str + " " + "-"*5    blank = " " * ((num - len(str)) // 2)    if start:        print("\n" + "*" * num)    print("\n%s%s%s\n" % (blank, str, blank))    if end:        print("*" * num)## 开始训练start_0 = time.time()num = 69start_str = "Start training"print_info(num, start_str, end=False)for epoch in range(NUM_EPOCH):    for step, (data, label) in enumerate(train_loader):        data = data.to(device)  ## 将数据移植到 GPU上        label = label.to(device)  ## 将数据移植到 GPU上        out = cnn(data)        loss = loss_func(out, label)        optimizer.zero_grad()        loss.backward()        optimizer.step()        if step % 100 == 0:            with torch.no_grad():                test_out = cnn(test_data)                pred_y = torch.max(test_out, 1)[1]                accurancy = torch.sum(pred_y==test_label) / float(test_label.size()[0])                print("Epoch: [%s/%s] |, step: [%s/%s] |, loss: %.4f |, accuracy: %.2f%%" % (epoch+1, NUM_EPOCH, step, train_loader.__len__(), loss.data.item(), 100*accurancy.data.item()))### 烦死了,懒得写注释了, 不知道的去 scdn 或者 Google 或者 百度,不行再找我time_str = "Training time: %.2fs" % (time.time() - start_0)print_info(num, time_str, line=False, start=False, end=False)end_str = "The end of training"print_info(num, end_str, start=False)## 验证with torch.no_grad():    val_out = cnn(val_data)    pred_y = torch.max(val_out, 1)[1]    accuracy = torch.sum(pred_y == val_label) / float(val_label.size()[0])    start_str = "Start verification"    print_info(num, start_str, start=False, end=False)    val_str = "Validating set accurancy: %.2f%%" % (100*accurancy.data.item())    print_info(num, val_str, line=False, start=False, end=False)    end_str = "The end of the validation"    print_info(num, end_str, start=False)## savestart_str = "Start saving the model"print_info(num, start_str, start=False, end=False)start1 = time.time()torch.save(cnn, './models/cnn.pkl')save_str = "Save the elapsed time of model completely: %.2fs" % (time.time() - start1)print_info(num, save_str, line=False, start=False, end=False)start2 = time.time()torch.save(cnn.state_dict(), './models/cnn_params.pkl')save_str = "The time it takes to save model parameters: %.2fs" % (time.time() - start2)print_info(num, save_str, line=False, start=False, end=False)end_str = "The end of saving the model"print_info(num, end_str, start=False)

输出:

懒得写了,就这样。。。。。。

cnn stride and padding_Pytorch实现神经网络CNN案例相关推荐

  1. TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略

    TF之CNN:Tensorflow构建卷积神经网络CNN的简介.使用方法.应用之详细攻略 目录 TensorFlow 中的卷积有关函数入门 1.tf.nn.conv2d函数 案例应用 1.TF之CNN ...

  2. cnn stride and padding_卷积神经网络(CNN) 第 4 课(上)

    点此亲启 致各位 之前感觉深度学习高深莫测,paper中的各种名词让人望而生畏,高端落地应用络绎不绝,说实话有些恐惧,但是当你一点一点地开始接触它,慢慢了解到每个名词的含义,心里一句"昂~也 ...

  3. MATLAB卷积神经网络cnn,Matlab编程之——卷积神经网络CNN代码解析

    deepLearnToolbox-master是一个深度学习matlab包,里面含有很多机器学习算法,如卷积神经网络CNN,深度信念网络DBN,自动编码AutoEncoder(堆栈SAE,卷积CAE) ...

  4. 卷积神经网络CNN图解

    背景 之前在网上搜索了好多好多关于CNN的文章,由于网络上的文章很多断章取义或者描述不清晰,看了很多youtobe上面的教学视频还是没有弄懂,最后经过痛苦漫长的煎熬之后对于神经网络和卷积有了粗浅的了解 ...

  5. 【卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10)】

    卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10) 在上一章已经完成了卷积神经网络的结构分析,并通过各个模块理解 ...

  6. DL之CNN:计算机视觉之卷积神经网络算法的简介(经典架构/论文)、CNN优化技术、调参学习实践、CNN经典结构及其演化、案例应用之详细攻略

    DL之CNN:计算机视觉之卷积神经网络算法的简介(经典架构/论文).CNN优化技术.调参学习实践.CNN经典结构.案例应用之详细攻略 目录 卷积神经网络算法的简介 0.Biologically Ins ...

  7. 花书+吴恩达深度学习(十四)卷积神经网络 CNN 之经典案例(LetNet-5, AlexNet, VGG-16, ResNet, Inception Network)

    目录 0. 前言 1. LeNet-5 2. AlexNet 3. VGG-16 4. ResNet 残差网络 5. Inception Network 如果这篇文章对你有一点小小的帮助,请给个关注, ...

  8. 【神经网络】(3) 卷积神经网络(CNN),案例:动物三分类,附python完整代码

    各位同学好,今天和大家分享一下TensorFlow2.0深度学习中卷积神经网络的案例.现在有猫.狗.熊猫图片一千张,构建卷积神经网络实现图像的分类预测. 1. 数据加载 将训练测试数据划分好后放在同一 ...

  9. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理(1)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

最新文章

  1. hibernate中 query 的list方法 用法
  2. 【转】android如何查看cpu的占用率和内存泄漏
  3. html 让表格在右侧显示不出来,css中怎么解决表格边框不显示的问题?
  4. YII2操作mongodb笔记(转)
  5. 操作系统【一】进程同步和信号量
  6. 【文末送书】调参太费力?自动化机器学习来帮你!
  7. recommend a cool calendar
  8. smtplib python教程_Python基于smtplib实现异步发送邮件服务
  9. Tomcat6.0 管理器配置
  10. 好风凭借力,送我上青云!
  11. Java内存模型以及happens-before规则
  12. 软件工程导论---软件测试(集成测试、单元测试、验收测试、系统测试)
  13. Ubuntu 安装磁盘分区及启动项添加
  14. linux下phylip软件构建NJ树,MEGA软件——系统发育树构建方法(图文讲解)
  15. 为什么是“止于至善”?
  16. 搜索引擎代码资源[转载]
  17. python写邮箱系统_Python django实现简单的邮件系统发送邮件功能
  18. 11、安全网络架构和保护网络组件
  19. EverNote开源协议-Android
  20. Dcat-Admin自定义Excel数据导出

热门文章

  1. UI设计干货素材|动效导航,漂亮的悬停动效
  2. php格式化输出字_PHP 输出格式化字符串
  3. QT保留小数点后位数
  4. Linux Socket C语言网络编程:Pthread Socket [code from GitHub, for study]
  5. Django:应用程序的两种架构:C/S架构,B/S架构,(TCP, URL)HTTP,HTTP request, HTTP response
  6. C++类的定义和创建
  7. OpneCV之图像的平移、翻转、旋转、缩放、裁剪(笔记04)
  8. kafka修改分区数_ELK|kafka增加分区或调整副本数
  9. shell_exec() php 执行shell脚本
  10. 循环计数_倒计数器:CountDownLatch | 循环栅栏:CyclicBarrier