mnist数据集下载及使用
# mnist数据集 在百度云盘里
# 链接:https://pan.baidu.com/s/1ca2rL2-0_JLtnH1YQ3otvA
# 提取码:uq3d
# pytorch自带数据集的使用
import torchvision
from torchvision.datasets import MNISTmnist = MNIST(root="./data",train=True,download=False)
print(mnist[0])
mnist[0][0].show()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
mnist数据集下载及使用相关推荐
- MNIST数据集下载及可视化
MNIST数据集介绍 MNIST数据集官网:http://yann.lecun.com/exdb/mnist/ MNIST数据库是非常经典的一个数据集,就像你学编程起初写一个"Hello W ...
- 深度学习入门-基于Python的理论入门与实现源代码加mnist数据集下载推荐
深度学习入门-基于Python的理论入门与实现源代码加mnist数据集下载推荐 书籍封面 1-图灵网站下载 书里也说了,可以图灵网站下载https://www.ituring.com.cn/book/ ...
- Python 手写数字识别 MNIST数据集下载失败
目录 一.MNIST数据集下载失败 1 失败的解决办法(经验教训): 2 亲测有效的解决方法: 一.MNIST数据集下载失败 场景复现:想要pytorch+MINIST数据集来实现手写数字识别,首先就 ...
- MNIST数据集下载+idx3-ubyte解析【超详细+上手简单】
前言 训练模型的时候经常会使用MNIST数据集来训练模型,那么如何获取到MNIST数据集呢?博主经过实践后,总结了经验,希望能帮助到屏幕前的你使用MNIST数据集. 目录 前言 1 下载MNIST数据 ...
- 关于mnist数据集下载的相关问题
文章目录 问题描述:在Tensorflow 2.0.1版本中下载mnist数据集 原因分析: 解决方案: 问题描述:在Tensorflow 2.0.1版本中下载mnist数据集 from tensor ...
- 手写数字识别MNIST数据集下载百度网盘链接快速下载
介绍 MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 下载 官方链接:http://ya ...
- 关于TensorFlow的MNIST数据集下载脚本input_data.py的坑
今天用github上的代码入门tensorflow但是发现似乎要下载数据集,但是这个我弄了一会才明白是怎么下的,所以把经验写在下面:(ubuntu14.04环境) 用github上的input_dat ...
- mnist 数据集 下载 训练 测试 pytorch
1.下载 可以使用 #train_set = mnist.MNIST('./data', train=True, download=True) 但是速度慢一般无法下载,官网下载也较慢 提供官网下载的压 ...
- pytorch Fashion MNIST 数据集下载慢怎么办
import torch # 导入pytorch from torchvision import datasets, transforms ## 导入数据集与数据预处理的方法 import matpl ...
- TensorFlow Mnist数据集下载问题
安装好TensorFlow后,按教程输入如下命令时,会出现不能下载数据的问题. from tensorflow.examples.tutorials.mnist import input_data m ...
最新文章
- Openpose+Tensorflow 这样实现人体姿态估计 | 代码干货
- 为什么一般用自增列作为主键?
- GMM高斯混合模型学习笔记(EM算法求解)
- python enumerate_python中enumerate的用法实例解析
- 【error】scripts/basic/fixdep: Syntax error: ( unexpected
- kali安装docker(有效详细的教程)
- xshell怎么让程序后台运行_使程序在Linux下后台运行
- noip模拟赛 对刚
- Toxophily(hdu2298三分+二分)
- 【Shiro第二篇】SpringBoot + Shiro实现用户身份认证功能
- vim复制,粘贴,删除,撤销,替换,光标移动等用法
- ffmpeg 将swf文件转 mp4
- Python学习之Python入门知识(一)
- 根据hash值找到bt种子的磁力下载链
- Amazon AWS 中国区 G2 服务器 配置运行
- Python程序员关于爬虫的一些常见面试题
- Frequency Estimation
- python和r语言哪个简单_python与r语言哪个简单
- MobaXterm登录密码重置
- SQL求用户的最大连续登陆天数