)全连接神经网络(FC)
全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC。

FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连接。例如下面这个网络结构就是典型的全连接:

神经网络的第一层为输入层,最后一层为输出层,中间所有的层都为隐藏层。在计算神经网络层数的时候,一般不把输入层算做在内,所以上面这个神经网络为2层。其中输入层有3个神经元,隐层有4个神经元,输出层有2个神经元。

更多关于神经网络的基本介绍,请参考此链接:https://www.zybuluo.com/hanbingtao/note/476663

2)三层FC实现MNIST手写数字分类
我们定义了三个不层次的神经网络模型:简单的FC,加激活函数的FC,加激活函数和批标准化的FC。

# !/usr/bin/python
# coding: utf8
# @Time    : 2018-08-04 19:09
# @Author  : Liam
# @Email   : luyu.real@qq.com
# @Software: PyCharm
#                        .::::.
#                      .::::::::.
#                     :::::::::::
#                  ..:::::::::::'
#               '::::::::::::'
#                 .::::::::::
#            '::::::::::::::..
#                 ..::::::::::::.
#               ``::::::::::::::::
#                ::::``:::::::::'        .:::.
#               ::::'   ':::::'       .::::::::.
#             .::::'      ::::     .:::::::'::::.
#            .:::'       :::::  .:::::::::' ':::::.
#           .::'        :::::.:::::::::'      ':::::.
#          .::'         ::::::::::::::'         ``::::.
#      ...:::           ::::::::::::'              ``::.
#     ```` ':.          ':::::::::'                  ::::..
#                        '.:::::'                    ':'````..
#                     美女保佑 永无BUG
from torch import nn

class simpleNet(nn.Module):
    """
    定义了一个简单的三层全连接神经网络,每一层都是线性的
    """
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(simpleNet, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

class Activation_Net(nn.Module):
    """
    在上面的simpleNet的基础上,在每层的输出部分添加了激活函数
    """
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Activation_Net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))
        """
        这里的Sequential()函数的功能是将网络的层组合到一起。
        """

def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

class Batch_Net(nn.Module):
    """
    在上面的Activation_Net的基础上,增加了一个加快收敛速度的方法——批标准化
    """
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Batch_Net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
接下来导入一些必要的包,除了之前常用的torch, nn, Variable之外,还导入了DataLoader用于加载数据,使用了torchvision进行图片的预处理,然后,导入了上面所定义的三个模型。

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 之前所定义的神经网络模型
import net
然后定义一些超参数。

# 定义一些超参数
batch_size = 64
learning_rate = 0.02
num_epoches = 20
在torchvision中提供了transforms用于帮我们对图片进行预处理和标准化。其中我们需要用到的有两个:ToTensor()和Normalize()。前者用于将图片转换成Tensor格式的数据,并且进行了标准化处理。后者用均值和标准偏差对张量图像进行归一化:给定均值: (M1,...,Mn) 和标准差: (S1,..,Sn) 用于 n 个通道, 该变换将标准化输入 torch.*Tensor 的每一个通道。其处理公式为:

而Compose函数可以将上述两个操作合并到一起执行。

# 数据预处理。transforms.ToTensor()将图片转换成PyTorch中处理的对象Tensor,并且进行标准化(数据在0~1之间)
# transforms.Normalize()做归一化。它进行了减均值,再除以标准差。两个参数分别是均值和标准差
# transforms.Compose()函数则是将各种预处理的操作组合到了一起
data_tf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])])
PyTorch提供了一些常用数据集,所以我们定义数据集下载器如下:

# 数据集的下载器
train_dataset = datasets.MNIST(
    root='./data', train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
然后选择相应的神经网路模型来进行训练和测试,我们这里定义的神经网络输入层为28*28,因为我们处理过的图片像素为28*28,两个隐层分别为300和100,输出层为10,因为我们识别0~9十个数字,需要分为十类。损失函数和优化器这里采用了交叉熵和梯度下降。

# 选择模型
model = net.simpleNet(28 * 28, 300, 100, 10)
# model = net.Activation_Net(28 * 28, 300, 100, 10)
# model = net.Batch_Net(28 * 28, 300, 100, 10)
if torch.cuda.is_available():
    model = model.cuda()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
下面是训练阶段,和之前大同小异。

# 训练模型
epoch = 0
for data in train_loader:
    img, label = data
    img = img.view(img.size(0), -1)
    if torch.cuda.is_available():
        img = img.cuda()
        label = label.cuda()
    else:
        img = Variable(img)
        label = Variable(label)
    out = model(img)
    loss = criterion(out, label)
    print_loss = loss.data.item()

optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch+=1
    if epoch%50 == 0:
        print('epoch: {}, loss: {:.4}'.format(epoch, loss.data.item()))
然后测试一下我们的模型:

# 模型评估
model.eval()
eval_loss = 0
eval_acc = 0
for data in test_loader:
    img, label = data
    img = img.view(img.size(0), -1)
    if torch.cuda.is_available():
        img = img.cuda()
        label = label.cuda()

out = model(img)
    loss = criterion(out, label)
    eval_loss += loss.data.item()*label.size(0)
    _, pred = torch.max(out, 1)
    num_correct = (pred == label).sum()
    eval_acc += num_correct.item()
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
    eval_loss / (len(test_dataset)),
    eval_acc / (len(test_dataset))
))
之前我们定义了三个不同层次的神经网络,选择不同的网络进行训练和测试,都可以看到类似于下面这样的程序执行结果:

完整代码请移步我的GitHub

--------------------- 
作者:Liam Coder 
来源:CSDN 
原文:https://blog.csdn.net/out_of_memory_error/article/details/81414986 
版权声明:本文为博主原创文章,转载请附上博文链接!

PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类相关推荐

  1. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  2. 全连接神经网络实现MNIST手写数字识别

    在对全连接神经网络的基本知识(全连接神经网络详解)学习之后,通过MNIST手写数字识别这个小项目来学习如何实现全连接神经网络. MNIST数据集 对于深度学习的任何项目来说,数据集是其中最为关键的部分 ...

  3. 深蓝学院第二章:基于全连接神经网络(FCNN)的手写数字识别

    如何用全连接神经网络去做手写识别??? 使用的是jupyter notebook这个插件进行代码演示.(首先先装一个Anaconda的虚拟环境,然后自己构建一个自己的虚拟环境,然后在虚拟环境中安装ju ...

  4. PyTorch入门一:卷积神经网络实现MNIST手写数字识别

    先给出几个入门PyTorch的好的资料: PyTorch官方教程(中文版):http://pytorch123.com <动手学深度学习>PyTorch版:https://github.c ...

  5. 深度学习3—用三层全连接神经网络训练MNIST手写数字字符集

    上一篇文章:深度学习2-任意结点数的三层全连接神经网络 距离上篇文章过去了快四个月了,真是时光飞逝,之前因为要考博所以耽误了更新,谁知道考完博后之前落下的接近半个学期的工作是如此之多,以至于弄到现在才 ...

  6. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  7. Pytorch入门——MNIST手写数字识别代码

    MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...

  8. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  9. 使用PYTORCH复现ALEXNET实现MNIST手写数字识别

    网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军,下面是Alexnet的网络结构: 网络结构较为简单,共有五个卷积层和三个全连接层,原 ...

最新文章

  1. H264 NALU 使用PS封装 RTP发送
  2. Google Container Engine进军生产环境,容器技术势不可挡
  3. sql server T-SQL 基础
  4. ssh服务常见问题及其解决办法
  5. 【springboot+easypoi】一行代码搞定excel导入导出
  6. shell 12 21 filename重定向的含义和区别
  7. 【kafka】Group coordinator xx is unavailable or invalid, will attempt rediscovery
  8. 洛谷P2534 [AHOI2012]铁盘整理
  9. shell(九)几个字符转换命令
  10. Oracle时间日期操作
  11. ubuntu搭建nfs网络文件系统
  12. [机器学习][三维重建] 凸包算法——Graham扫描
  13. Rplidar A2M8屏蔽角度
  14. Java5的倍数_关于java:将数字四舍五入到最接近的5的倍数
  15. AndroidStudio与Eclipse快捷键
  16. iOS 的 (签名验签)Code Signing 体系
  17. 服务器快速操作pc文件,如何将普通pc做服务器
  18. 大学计算机作业互评评语简短,同学作业互评评语
  19. 【梯度消失和梯度爆炸问题详解】
  20. MATLAB: 2018a百度云资源、迅雷资源、安装步骤

热门文章

  1. C# 四舍五入round函数使用的代码
  2. 扩展源_Ubuntu14版本下无法使用php7.2版本的bcmath扩展
  3. python中math.ceil是什么意思_python中的数字取整(ceil,floor,round)概念和用法
  4. 输入20本书的书名,作者,出版社,出版日期,单价,按书名排序输出
  5. localStorage、sessionStorage、Cookie的区别及用法
  6. windows核心编程之进程间共享数据
  7. 浅谈(线性)卷积公式为什么要翻转
  8. 查询整个数据库中某个特定值所在的表和字段的方法
  9. C语言字符串操作函数
  10. mysql搭建主从的目的_mysql搭建主从