DNN全连接层(线性层)

计算公式:

y = w * x + b
W和b是参与训练的参数
W的维度决定了隐含层输出的维度,一般称为隐单元个数(hidden size)
b是偏差值(本文没考虑)
举例:
输入:x (维度1 x 3)
隐含层1:w(维度3 x 5)
隐含层2: w(维度5 x 2)

个人思想如下:

比如说如上图,我们有输入层是3个,中间层是5个,输出层要求是2个。利用线性代数,输入是【1×3】,那么需要乘【3×5】的权重矩阵得到【1×5】,再由【1×5】乘【5×2】的权重矩阵,最后得到【1×2】的结果。在本代码中没有考虑偏差值(bias),利用pytorch中随机初始化的权重实现模型预测。

import torch
import torch.nn as nn
import numpy as np
"""
用pytorch框架实现单层的全连接网络
不使用偏置bias
"""
class TorchModel(nn.Module):    #nn.module是torch自带的库def __init__(self, input_size, hidden_size, output_size):super(TorchModel, self).__init__()self.layer1 = nn.Linear(input_size, hidden_size, bias=False)#nn.linear是torch的线性层,input_size是输入的维度,hidden_size是这一层的输出的维度self.layer2 = nn.Linear(hidden_size, output_size, bias=False)#这个线性层可以有很多个def forward(self, x):   #开始计算的函数hidden = self.layer1(x)     #传入输入第一层# print("torch hidden", hidden)y_pred = self.layer2(hidden)       #传入输入第二层return y_pred
x = np.array([1, 0, 0])  #网络输入#torch实验
torch_model = TorchModel(len(x), 5, 2)  #这三个数分别代表输入,中间,结果层的维度
#print(torch_model.state_dict())        #可以打印出pytorch随机初始化的权重
torch_model_w1 = torch_model.state_dict()["layer1.weight"].numpy()
#通过取字典方式将权重取出来并把torch的权重转化为numpy的
torch_model_w2 = torch_model.state_dict()["layer2.weight"].numpy()
#print(torch_model_w1, "torch w1 权重")
#这里你会发现随机初始化的权重矩阵是5×3,所以当自定义模型时需要转置,但是在pytorch中会自动转置相乘
#print(torch_model_w2, "torch w2 权重")
torch_x = torch.FloatTensor([x])    #numpy的输入转化为torch
y_pred = torch_model.forward(torch_x)
print("torch模型预测结果:", y_pred)

以上是pytorch模型实现DNN的简单方法。

自定义模型手工实现:

(注意因为自定义模型需要得到模型中的权重,而上面代码利用的是pytorch的随机自定义模型,为了能让两者对比答案是否相同,自定义模型中的权重需要继承pytorch的随机权重)

"""
手动实现简单的神经网络
用自定义框架实现单层的全连接网络
不使用偏置bias
"""
#自定义模型
class DiyModel:def __init__(self, weight1, weight2):self.weight1 = weight1      #收到在torch随机的权重self.weight2 = weight2def forward(self, x):hidden = np.dot(x, self.weight1.T)  #将输入与第一层权重的转置相乘y_pred = np.dot(hidden, self.weight2.T)return y_preddiy_model = DiyModel(torch_model_w1, torch_model_w2)
y_pred_diy = diy_model.forward(np.array([x]))
print("diy模型预测结果:", y_pred_diy)

如需运行须将自定义模型放入pytorch的代码下面继承输入和随机权重,通过最后结果能发现两者相同。

结果如下:


可以发现两者代码结果相同~

Pytorch中DNN入门思想及实现相关推荐

  1. Pytorch中CNN入门思想及实现

    CNN卷积神经网络 基础概念: 以卷积操作为基础的网络结构,每个卷积核可以看成一个特征提取器. 思想: 每次观察数据的一部分,如图,在整个矩阵中只观察黄色部分3×3的矩阵,将这[3×3]矩阵·(点乘) ...

  2. Pytorch中RNN入门思想及实现

    RNN循环神经网络 整体思想: 将整个序列划分成多个时间步,将每一个时间步的信息依次输入模型,同时将模型输出的结果传给下一个时间步,也就是说后面的结果受前面输入的影响. RNN的实现公式: 个人思路: ...

  3. PyTorch 60 分钟入门教程中的一些疑惑

    PyTorch 60 分钟入门教程中的一些疑惑 自动微分 参考: PyTorch 60 分钟入门教程. 自动微分 y.data.norm()指的是y的范数,举一个例子 假设x是[1.,2.,3.],则 ...

  4. 简单介绍pytorch中分布式训练DDP使用 (结合实例,快速入门)

    文章目录 DDP原理 pytorch中DDP使用 相关的概念 使用流程 如何启动 torch.distributed.launch spawn调用方式 针对实例voxceleb_trainer多卡介绍 ...

  5. 「PyTorch深度学习入门」4. 使用张量表示真实世界的数据(中)

    来源 | Deep Learning with PyTorch 作者 |  Stevens, et al. 译者 | 杜小瑞 校对 | gongyouliu 编辑 | auroral-L 全文共784 ...

  6. Lesson 16.5 在Pytorch中实现卷积网络(上):卷积核、输入通道与特征图在PyTorch中实现卷积网络(中):步长与填充

    卷积神经网络是使用卷积层的一组神经网络.在一个成熟的CNN中,往往会涉及到卷积层.池化层.线性层(全连接层)以及各类激活函数.因此,在构筑卷积网络时,需从整体全部层的需求来进行考虑. 1 二维卷积层n ...

  7. PyTorch自然语言处理入门与实战 | 文末赠书

    文末赠书 注:本文选自人民邮电出版社出版的<PyTorch自然语言处理入门与实战>一书,略有改动.经出版社授权刊登于此. 处理中文与英文的一个显著区别是中文的词之间缺乏明确的分隔符.分词是 ...

  8. 什么是embedding(把物体编码为一个低维稠密向量),pytorch中nn.Embedding原理及使用

    文章目录 使embedding空前流行的word2vec 句子的表达 训练样本 损失函数 输入向量表达和输出向量表达vwv_{w}vw​ 从word2vec到item2vec 讨论环节 pytorch ...

  9. gpu处理信号_在PyTorch中使用DistributedDataParallel进行多GPU分布式模型训练

    先进的深度学习模型参数正以指数级速度增长:去年的GPT-2有大约7.5亿个参数,今年的GPT-3有1750亿个参数.虽然GPT是一个比较极端的例子但是各种SOTA模型正在推动越来越大的模型进入生产应用 ...

最新文章

  1. android 跳转到应用市场
  2. OpenCV 错误:无法打开摄像头(打开摄像头卡机)
  3. c# rar解压大小_C#利用WinRAR实现压缩和解压缩
  4. 15个Java多线程面试题
  5. 9203 0427 随堂小结
  6. vb.net word 自定义工具栏_word重点标记新玩法:应用绘图工具手写笔进行划线涂抹...
  7. 最全的iOS真机调试教程(证书生成等)
  8. 邮件合并保存为一个个单独的文档_你还在为考计算机二级烦恼吗? 基本操作步骤分享...
  9. DRF的解析器和渲染器
  10. 海康SDK开发2—SpringBoot+海康SDK
  11. OPPO a1刷机包下载_OPPOA1密码忘记了?来这里搞定
  12. matlab平均脸,BFM使用 - 获取平均脸模型的68个特征点坐标
  13. 【组件篇】ionic3开源组件
  14. 中兴跳楼程序员妻子:他们就这样把我老公逼死了
  15. 微信SDK中含有的支付功能怎么去掉?
  16. 1481: 考试排名(一)(结构体专题)
  17. Diva-Tp项测试详解
  18. PicGO+阿里云OSS或PicGO+Github+Jsdelivr搭建图床(图解)
  19. 成了!刚刚登顶全球首富的他,花440亿美元将推特买下 | 美通社头条
  20. 网络安全与网站安全及计算机安全:如何使用Kali Linux进行MS08-067安全演练

热门文章

  1. Linux的ext4文件系统学习笔记
  2. shell编程题(二)
  3. 1G.小a的排列(C++)
  4. Python小数据池,代码块
  5. go map数据结构
  6. Js实现div随鼠标移动的方法
  7. 如何用js获取浏览器URL中查询字符串的参数
  8. Visual Studio 快捷键汇总
  9. MFC学习中遇到的小问题和解决方案
  10. 关于 p3p ie 跨域 问题