softmax函数

代码

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置train = True, # 载入训练集transform=transforms.ToTensor(), # 把数据变成tensor类型download = True # 下载)
# 测试集
test_data = datasets.MNIST(root="./",train = False,transform=transforms.ToTensor(),download = True)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels = dataprint(inputs.shape)print(labels.shape)break


# 定义网络结构
class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 初始化self.fc1 = nn.Linear(784,10) # 784个输入10个输出self.softmax = nn.Softmax(dim=1) # 激活函数 dim=1表示对第一个维度进行概率计算def forward(self,x):# torch.Size([64, 1, 28, 28]) -> (64,784)x = x.view(x.size()[0],-1) # 4维变2维 (在全连接层做计算只能2维)x = self.fc1(x) # 传给全连接层继续计算x = self.softmax(x) # 使用softmax激活函数进行计算return x
# 定义模型
model = Net()
# 定义代价函数
mse_loss = nn.MSELoss()
# 定义优化器
optimizer = optim.SGD(model.parameters(),lr=0.5)
# 定义模型训练和测试的方法
def train():for i,data in enumerate(train_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# to onehot 把数据标签变成独热编码labels = labels.reshape(-1,1) # 先把1维变成2维(64)-(64,1)# tensor.scatter(dim,index,src)# dim:对那个维度进行独热编码# index:要将src中对应的值放到tensor那个位置# src:插入index的数值one_hot = torch.zeros(inputs.shape[0],10).scatter(1,labels,1)# 计算loss   mse_loss的两个数据的shape要一致loss = mse_loss(out,one_hot)# 梯度清零optimizer.zero_grad()# 计算梯度loss.backward()# 修改权值optimizer.step()def test():correct = 0for i,data in enumerate(test_loader):# 获得一个批次的数据和标签inputs,labels = data# 获得模型预测结果(64,10)out = model(inputs)# 获得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 预测正确的数量correct += (predicted==labels).sum()print("Test acc:{0}".format(correct.item()/len(test_data)))
# 训练
for epoch in range(10):print("epoch:",epoch)train()test()

PyTorch基础-softmax函数mnist数据集识别-03相关推荐

  1. PyTorch基础-交叉熵函数mnist数据集识别-04

    交叉熵 代码 import numpy as np import torch from torch import nn,optim from torch.autograd import Variabl ...

  2. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  3. 基于一个线性层的softmax回归模型和MNIST数据集识别自己手写数字

    原博文是用cnn识别,因为我是在自己电脑上跑代码,用不了处理器,所以参考Mnist官网上的一个线性层的softmax回归模型的代码,把两篇文章结合起来识别. 最后效果 源代码识别mnist数据集的准确 ...

  4. TensorFlow:实战Google深度学习框架(四)MNIST数据集识别问题

    第5章 MNIST数字识别问题 5.1 MNIST数据处理 5.2 神经网络的训练以及不同模型结果的对比 5.2.1 TensorFlow训练神经网络 5.2.2 使用验证数据集判断模型的效果 5.2 ...

  5. 深度学习基础实战使用MNIST数据集对图片分类

    本文代码完全借鉴pytorch中文手册 '''我们找到数据集,对数据做预处理,定义我们的模型,调整超参数,测试训练,再通过训练结果对超参数进行调整或者对模型进行调整.''' import torch ...

  6. 深度学习入门-误差反向传播法(人工神经网络实现mnist数据集识别)

    文章目录 误差反向传播法 5.1 链式法则与计算图 5.2 计算图代码实践 5.3激活函数层的实现 5.4 简单矩阵求导 5.5 Affine 层的实现 5.6 softmax-with-loss层计 ...

  7. [转载] 卷积神经网络做mnist数据集识别

    参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...

  8. Keras【Deep Learning With Python】MNIST数据集识别优化

    文章目录 前言 1 线性回归预测 2 手写数字识别 3 模型优化 前言 本文分为三部分: a.线性回归 b.手写数字识别 c.手写数字识别模型优化. 1 线性回归预测 import keras Usi ...

  9. 深度学习之基于卷积神经网络实现超大Mnist数据集识别

    在以往的手写数字识别中,数据集一共是70000张图片,模型准确率可以达到99%以上的准确率.而本次实验的手写数字数据集中有120000张图片,而且数据集的预处理方式也是之前没有遇到过的.最终在验证集上 ...

最新文章

  1. 字节跳动AI Lab社招以及实习生内推
  2. 简单的Writer和Reader
  3. Spring 整合 RocketMQ
  4. 45张令程序员泪流满面的趣图
  5. java输入最大10位数,倒数输出(很鸡肋)
  6. 【离散数学】实验 旅行路线规划问题
  7. 中国股市暴涨暴跌全记录
  8. python kivy canvas_python – Kivy:使用canvas为动画设置动画的正确方法是什么?
  9. Dubbo学习总结(3)——Dubbo-Admin管理平台和Zookeeper注册中心的搭建
  10. FoneDog Data Recovery数据恢复教程
  11. Python中的条件判断和循环
  12. latex 改变字体颜色
  13. Windows XP 系统中内置的AT命令
  14. python画围棋棋盘_Python语言程序设计之二--用turtle库画围棋棋盘和正、余弦函数图形...
  15. 产品经理的年终总结可以这样写
  16. depts: deep expansion learning for periodic time series forecasting
  17. 什么是云数据库RDS
  18. ASEMI代理AD823AARZ-RL原装ADI车规级AD823AARZ-RL
  19. opencv4.1无法加载python-cnn模型,编译第三方库libtensorflow_cc.so巨坑
  20. 计算机学院新生篮球赛名字,计算机学院新生篮球赛圆满结束,获奖队伍公布!...

热门文章

  1. mysql 事务sqlserver_MYSQL高级特性 -- 事务处理_sqlserver
  2. java在己有的类创子类怎么创_使用Java创建自己的异常子类
  3. mysql oracle mvcc_PostgreSQL、Oracle/MySQL和SQL Server的MVCC实现原理方式
  4. Git在公司内部的使用规范
  5. smarty模板概念及应用场合
  6. php 本地mysql 代码_基于本地数据库的 IP 地址查询 PHP 源码
  7. linux查看睡眠进程,关于 Linux 进程的睡眠和唤醒 ,来看这篇就够了~
  8. java file类详解_Java File类详解及IO介绍及使用
  9. mysql blob 导出_mysql blob导出文本解密 | 学步园
  10. linux如何导出加密卡私钥,linux – 如何使用gpg中的私钥加密文件