CNN卷积神经网络

基础概念:

以卷积操作为基础的网络结构,每个卷积核可以看成一个特征提取器。

思想:

每次观察数据的一部分,如图,在整个矩阵中只观察黄色部分3×3的矩阵,将这【3×3】矩阵·(点乘)权重得到特征矩阵的第一项,然后进行平移进行第二项的计算。依此类推,得到最后的特征矩阵。

利用Pytorch框架实现CNN

import torch
import torch.nn as nn
import numpy as np"""
使用pytorch实现CNN
不考虑偏差值
"""class TorchCNN(nn.Module):def __init__(self, in_channel, out_channel, kernel):super(TorchCNN, self).__init__()self.layer = nn.Conv2d(in_channel, out_channel, kernel, bias=False)def forward(self, x):return self.layer(x)x = np.array([[0.1, 0.2, 0.3, 0.4],[-3, -4, -5, -6],[5.1, 6.2, 7.3, 8.4],[-0.7, -0.8, -0.9, -1]])  #网络输入#torch实验
in_channel = 1      #单通道(NLP中一般用单通道)
out_channel = 3     #多少个卷积核(每一个卷积核代表一个独立的权重)
kernel_size = 2     #2*2的方块(功能就是图中黄色[3×3]矩阵)
torch_model = TorchCNN(in_channel, out_channel, kernel_size)
# print(torch_model.state_dict())
torch_w = torch_model.state_dict()["layer.weight"]
# print(torch_w.numpy().shape)
torch_x = torch.FloatTensor([[x]])
#权重是4维,输入应该也为四维,通过多一个[],将输入由三维变成四维
output = torch_model.forward(torch_x)
output = output.detach().numpy()
print(output, output.shape, "torch模型预测结果\n")

自定义模型代码实现CNN:

采用自定义模型实现CNN,不考虑偏差值,因为要与Pytorch框架结果相对比,需要调取在Pytorch模型中的输入和随机权重。因此如果要运行,须将此代码放在Pytorch框架下运行。

"""
手动实现简单的神经网络
与Pytorch对比实验
"""
#自定义CNN模型
class DiyModel:def __init__(self, input_height, input_width, weights, kernel_size):self.height = input_heightself.width = input_widthself.weights = weightsself.kernel_size = kernel_sizedef forward(self, x):output = []for kernel_weight in self.weights:kernel_weight = kernel_weight.squeeze().numpy()#weight取出来时是[1×2×2],通过squeeze变成[2×2],然后变成numpy取出kernel_output = np.zeros((self.height - kernel_size + 1, self.width - kernel_size + 1)) #全0输出矩阵for i in range(self.height - kernel_size + 1):for j in range(self.width - kernel_size + 1):window = x[i:i+kernel_size, j:j+kernel_size]   #x是原始输入 剩下的是矩阵索引方法kernel_output[i, j] = np.sum(kernel_weight * window)  #np.dot != x*y   x*y是点乘(对应位置相乘)output.append(kernel_output)return np.array(output)diy_model = DiyModel(x.shape[0], x.shape[1], torch_w, kernel_size)
output = diy_model.forward(x)
print(output, "diy模型预测结果")

最终对比结果:


可以清楚看到Pytorch框架下的结果与自定义框架下的结果相同。

Pytorch中CNN入门思想及实现相关推荐

  1. Pytorch中DNN入门思想及实现

    DNN全连接层(线性层) 计算公式: y = w * x + b W和b是参与训练的参数 W的维度决定了隐含层输出的维度,一般称为隐单元个数(hidden size) b是偏差值(本文没考虑) 举例: ...

  2. Pytorch中RNN入门思想及实现

    RNN循环神经网络 整体思想: 将整个序列划分成多个时间步,将每一个时间步的信息依次输入模型,同时将模型输出的结果传给下一个时间步,也就是说后面的结果受前面输入的影响. RNN的实现公式: 个人思路: ...

  3. Pytorch中CNN图像回归问题预测值都一样

    ** Pytorch中CNN图像回归问题预测值都一样 ** 上网也查阅了许多资料,然后对比各种方法都试了一遍,归结为以下几点: 1.出现预测值都一样的情况,一般都是在某一层梯度消失了,然后导致输入到下 ...

  4. PyTorch中CNN网络参数计算和模型文件大小预估

    前言 在深度学习CNN构建过程中,网络的参数量是一个需要考虑的问题.太深的网络或是太大的卷积核.太多的特征图通道数都会导致网络参数量上升.写出的模型文件也会很大.所以提前计算网络参数和预估模型文件大小 ...

  5. CNN入门+猫狗大战(Dogs vs. Cats)+PyTorch入门

    一些修改(修改后的代码) 修改原网络的输出方式.原网络采用的交叉熵torch.nn.CrossEntropyLoss()进行Loss计算,而这个函数内部是已经进行了softmax处理的(参考),所以网 ...

  6. PyTorch 60 分钟入门教程中的一些疑惑

    PyTorch 60 分钟入门教程中的一些疑惑 自动微分 参考: PyTorch 60 分钟入门教程. 自动微分 y.data.norm()指的是y的范数,举一个例子 假设x是[1.,2.,3.],则 ...

  7. 简单介绍pytorch中分布式训练DDP使用 (结合实例,快速入门)

    文章目录 DDP原理 pytorch中DDP使用 相关的概念 使用流程 如何启动 torch.distributed.launch spawn调用方式 针对实例voxceleb_trainer多卡介绍 ...

  8. 「PyTorch深度学习入门」4. 使用张量表示真实世界的数据(中)

    来源 | Deep Learning with PyTorch 作者 |  Stevens, et al. 译者 | 杜小瑞 校对 | gongyouliu 编辑 | auroral-L 全文共784 ...

  9. Pytorch和CNN图像分类

    Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够实现强大的GPU加速, ...

最新文章

  1. 大S变汪太!与汪小菲注册结婚
  2. ionic 弹窗(alert, confirm)
  3. python 组合数库函数_Python数据分析之Numpy库(笔记)
  4. HanLPTokenizer HanLP分词器
  5. 图像处理(二十四)Gradient Domain High Dynamic Range Compression学习笔记
  6. Visual.Basic.2008编程参考手册
  7. Linux命令之dos2unix - 将DOS格式文本文件转换成UNIX格式
  8. 【渝粤教育】国家开放大学2018年春季 0408-22T管理学基础 参考试题
  9. 【Android】-- adb shell 命令探索
  10. c++类详解:访问权限,构造函数,拷贝构造函数,析构函数
  11. Linux学习笔记4 - Linux常用命令
  12. 利用C Free3.5 本身获得自身注册码
  13. 【Python精彩案例】生成动态二维码
  14. 根号3表白html,根号三的那句情话
  15. 是非人生 — 一个菜鸟程序员的5年职场路 第24节
  16. android 各个版本安全特性
  17. 字符数组 字符插入(c语言)
  18. 过滤器与拦截器的区别?
  19. Leetcode刷题 2021.01.22
  20. 银行定期存款利率,输入金额,输入年限,计算本息总额

热门文章

  1. linux 信号signal和sigaction理解
  2. 使用mmap实现大文件的复制:单进程与多进程情况
  3. 【操作系统】进程调度(2a):SJF(短任务优先) 算法 原理与实践
  4. IO多路复用之epoll
  5. Java集合(二):List列表
  6. C++ 网络开发工具
  7. Linux自有服务(2)-Linux从入门到精通第六天(非原创)
  8. 定义jQuery插件
  9. 解决 MyEclipse build workspace 慢,validation javascript 更慢的问题
  10. 高数.........