cnn stride and padding_Pytorch实现神经网络CNN案例
import timeimport torchfrom torch import nn, optimfrom torch.utils.data import DataLoaderfrom torchvision import datasets, transforms"""版本: python==3.7 pytorch==1.6"""# 超参LR = 0.001 ## 学习效率NUM_EPOCH = 2 ## 训练次数BATCH_SIZE = 50 ## 每批次训练 50 条数据device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') ## 使用 GPU 加速## 数据集准备train_dataset = datasets.MNIST( root='./dataset', ## 数据储存位置 train=True, ## True 为训练集 transform=transforms.ToTensor(), download=False ## 是否下载 MNIST 数据集(我已经下好了),没下载的写 True)train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)test_dataset = datasets.MNIST( root='./dataset', train=False, ## False为测试集 transform=transforms.ToTensor(), download=False)test_data = torch.unsqueeze(test_dataset.data, dim=1).type(torch.FloatTensor)[:8000].to(device) / 255. ## 测试集数据与训练集数据维度保持一致test_label = test_dataset.targets.to(device)[:8000]val_data = torch.unsqueeze(test_dataset.data, dim=1).type(torch.FloatTensor)[8000:].to(device) / 255. val_label = test_dataset.targets.to(device)[8000:]## 搭建 CNN 神经网络class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d( ## 卷积层 1 in_channels=1, ## 输入图片维度为 1 * 28 * 28 out_channels=16, ## 输出图片维度为 16 * 28 * 28 kernel_size=5, stride=1, padding=2 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ## 池化层 经池化操作后图片维度变为 16 * 14 * 14 ) self.conv2 = nn.Sequential( nn.Conv2d( in_channels=16, out_channels=32, # 32 * 14 * 14 kernel_size=5, stride=1, padding=2 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2)# 32 * 7 * 7 ) self.out = nn.Linear(32*7*7, 10) ## 全连接层 def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size()[0], -1) out = self.out(x) return outcnn = CNN().to(device) ## 将模型移植到 GPU 上loss_func = nn.CrossEntropyLoss() ## 损失函数optimizer = optim.Adam(cnn.parameters(), lr=LR) ## 优化器 ## 输出信息格式def print_info(num, str, line=True, start=True, end=True): if line: str = "-"*5 + " " + str + " " + "-"*5 blank = " " * ((num - len(str)) // 2) if start: print("\n" + "*" * num) print("\n%s%s%s\n" % (blank, str, blank)) if end: print("*" * num)## 开始训练start_0 = time.time()num = 69start_str = "Start training"print_info(num, start_str, end=False)for epoch in range(NUM_EPOCH): for step, (data, label) in enumerate(train_loader): data = data.to(device) ## 将数据移植到 GPU上 label = label.to(device) ## 将数据移植到 GPU上 out = cnn(data) loss = loss_func(out, label) optimizer.zero_grad() loss.backward() optimizer.step() if step % 100 == 0: with torch.no_grad(): test_out = cnn(test_data) pred_y = torch.max(test_out, 1)[1] accurancy = torch.sum(pred_y==test_label) / float(test_label.size()[0]) print("Epoch: [%s/%s] |, step: [%s/%s] |, loss: %.4f |, accuracy: %.2f%%" % (epoch+1, NUM_EPOCH, step, train_loader.__len__(), loss.data.item(), 100*accurancy.data.item()))### 烦死了,懒得写注释了, 不知道的去 scdn 或者 Google 或者 百度,不行再找我time_str = "Training time: %.2fs" % (time.time() - start_0)print_info(num, time_str, line=False, start=False, end=False)end_str = "The end of training"print_info(num, end_str, start=False)## 验证with torch.no_grad(): val_out = cnn(val_data) pred_y = torch.max(val_out, 1)[1] accuracy = torch.sum(pred_y == val_label) / float(val_label.size()[0]) start_str = "Start verification" print_info(num, start_str, start=False, end=False) val_str = "Validating set accurancy: %.2f%%" % (100*accurancy.data.item()) print_info(num, val_str, line=False, start=False, end=False) end_str = "The end of the validation" print_info(num, end_str, start=False)## savestart_str = "Start saving the model"print_info(num, start_str, start=False, end=False)start1 = time.time()torch.save(cnn, './models/cnn.pkl')save_str = "Save the elapsed time of model completely: %.2fs" % (time.time() - start1)print_info(num, save_str, line=False, start=False, end=False)start2 = time.time()torch.save(cnn.state_dict(), './models/cnn_params.pkl')save_str = "The time it takes to save model parameters: %.2fs" % (time.time() - start2)print_info(num, save_str, line=False, start=False, end=False)end_str = "The end of saving the model"print_info(num, end_str, start=False)
输出:
懒得写了,就这样。。。。。。
cnn stride and padding_Pytorch实现神经网络CNN案例相关推荐
- TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略
TF之CNN:Tensorflow构建卷积神经网络CNN的简介.使用方法.应用之详细攻略 目录 TensorFlow 中的卷积有关函数入门 1.tf.nn.conv2d函数 案例应用 1.TF之CNN ...
- cnn stride and padding_卷积神经网络(CNN) 第 4 课(上)
点此亲启 致各位 之前感觉深度学习高深莫测,paper中的各种名词让人望而生畏,高端落地应用络绎不绝,说实话有些恐惧,但是当你一点一点地开始接触它,慢慢了解到每个名词的含义,心里一句"昂~也 ...
- MATLAB卷积神经网络cnn,Matlab编程之——卷积神经网络CNN代码解析
deepLearnToolbox-master是一个深度学习matlab包,里面含有很多机器学习算法,如卷积神经网络CNN,深度信念网络DBN,自动编码AutoEncoder(堆栈SAE,卷积CAE) ...
- 卷积神经网络CNN图解
背景 之前在网上搜索了好多好多关于CNN的文章,由于网络上的文章很多断章取义或者描述不清晰,看了很多youtobe上面的教学视频还是没有弄懂,最后经过痛苦漫长的煎熬之后对于神经网络和卷积有了粗浅的了解 ...
- 【卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10)】
卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10) 在上一章已经完成了卷积神经网络的结构分析,并通过各个模块理解 ...
- DL之CNN:计算机视觉之卷积神经网络算法的简介(经典架构/论文)、CNN优化技术、调参学习实践、CNN经典结构及其演化、案例应用之详细攻略
DL之CNN:计算机视觉之卷积神经网络算法的简介(经典架构/论文).CNN优化技术.调参学习实践.CNN经典结构.案例应用之详细攻略 目录 卷积神经网络算法的简介 0.Biologically Ins ...
- 花书+吴恩达深度学习(十四)卷积神经网络 CNN 之经典案例(LetNet-5, AlexNet, VGG-16, ResNet, Inception Network)
目录 0. 前言 1. LeNet-5 2. AlexNet 3. VGG-16 4. ResNet 残差网络 5. Inception Network 如果这篇文章对你有一点小小的帮助,请给个关注, ...
- 【神经网络】(3) 卷积神经网络(CNN),案例:动物三分类,附python完整代码
各位同学好,今天和大家分享一下TensorFlow2.0深度学习中卷积神经网络的案例.现在有猫.狗.熊猫图片一千张,构建卷积神经网络实现图像的分类预测. 1. 数据加载 将训练测试数据划分好后放在同一 ...
- 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理(1)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
最新文章
- hibernate中 query 的list方法 用法
- 【转】android如何查看cpu的占用率和内存泄漏
- html 让表格在右侧显示不出来,css中怎么解决表格边框不显示的问题?
- YII2操作mongodb笔记(转)
- 操作系统【一】进程同步和信号量
- 【文末送书】调参太费力?自动化机器学习来帮你!
- recommend a cool calendar
- smtplib python教程_Python基于smtplib实现异步发送邮件服务
- Tomcat6.0 管理器配置
- 好风凭借力,送我上青云!
- Java内存模型以及happens-before规则
- 软件工程导论---软件测试(集成测试、单元测试、验收测试、系统测试)
- Ubuntu 安装磁盘分区及启动项添加
- linux下phylip软件构建NJ树,MEGA软件——系统发育树构建方法(图文讲解)
- 为什么是“止于至善”?
- 搜索引擎代码资源[转载]
- python写邮箱系统_Python django实现简单的邮件系统发送邮件功能
- 11、安全网络架构和保护网络组件
- EverNote开源协议-Android
- Dcat-Admin自定义Excel数据导出
热门文章
- UI设计干货素材|动效导航,漂亮的悬停动效
- php格式化输出字_PHP 输出格式化字符串
- QT保留小数点后位数
- Linux Socket C语言网络编程:Pthread Socket [code from GitHub, for study]
- Django:应用程序的两种架构:C/S架构,B/S架构,(TCP, URL)HTTP,HTTP request, HTTP response
- C++类的定义和创建
- OpneCV之图像的平移、翻转、旋转、缩放、裁剪(笔记04)
- kafka修改分区数_ELK|kafka增加分区或调整副本数
- shell_exec() php 执行shell脚本
- 循环计数_倒计数器:CountDownLatch | 循环栅栏:CyclicBarrier