多输入模型 Multiple-Dimension 数学原理分析以及源码源码详解 深度学习 Pytorch笔记 B站刘二大人(6/10)

数学推导

在之前实现的模型普遍都是单输入单输出模型,显然,在现实场景中更多的是多输入多输入模型。在本文中将主要推导,多输入模型中的内部数据传输变化,以及内部矩阵运算过程

使用Mini-batch 模块,多个维度线性层共同使用一组权重w的线性组合,共享权重可以极大减小运算量。转化为矩阵运算的意义是希望通过转化为矩阵实现并行运算,进行gpu的挂载

在本次的实践数据中使用的是糖尿病病人的数据集,通过维度为8的输入数据,即利用8种病人自身的评价指标的数据,对该样本病人是否患有糖尿病进行判断。

假设输入训练数据量为N,则输入数据为一个8*N的矩阵,如上图所示。

在之前的文章中已经进行了强调,要将数据运算的过程视为矩阵的运算,因此需要维度为8的权重向量w与原始数据进行矩阵乘法,N8的矩阵与81的矩阵右乘,则得到N*1的矩阵,加上偏置量b,通过激活函数sigmoid转化为概率,通过概率进行判别。

虽然在原理上,是直接通过8*1的权重矩阵w进行构造,但是实际构造为了提高准确性,通常会将单个线性层拆分为多层,例如在本文的代码实现中就将 8 -> 1 层的网络结构转换为 8 -> 6 -> 4 -> 1 的多层网络结构。

将8维空间的数据转化为1维,通过多个线性层与激活函数的组合,模拟多个空间非线性变换,从而达到不同的设计目的。一般隐层越多,学习能力越强,但是必须考虑泛化能力

数据下载

链接:https://pan.baidu.com/s/1IJpTM1_gd4Tln01A5JYOSA?pwd=ws2r
提取码:ws2r

代码解读与实现

代码细节,loadtxt函数dtype选择 .float32类型 ,原因:绝大部分的显卡仅支持float32位的数据

其次本次代码在进行损失计算中,选用的是average = True,意味着将损失将取平均值,损失值和梯度将较大,优化器的迭代步长可以适当调大

''' coding:utf-8 '''
"""
作者:shiyi
日期:年 09月 03日
通过pytorch模块复现多输入线性模型
"""# prepare datasetimport torch
import numpy as npxy = np.loadtxt('D:\\pytorch_prac\\dataset\\diabetes.csv.gz', delimiter=',', dtype=np.float32)          # 使用float32位的数据,以支持绝大部分GPU的数据格式
x_data = torch.from_numpy(xy[:, :-1])           # 读取输入数据,xy[:,:-1] 意思是读取 除了最后1列 的其余所有行数据
y_data = torch.from_numpy(xy[:, [-1]])          # 读取标签数据,xy[:,[-1]]意思是读取 最后1列 所有行的数据# design model using class
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))       # 注意全部都用x,防止传输出错x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()     # 实例化# Construct loss and optimizer
cirterison = torch.nn.BCELoss(size_average=True)
opimizer = torch.optim.ASGD(model.parameters(), lr=0.05)# Training cycle
for epoch in range(3000):# Forwardy_pred = model(x_data)loss = cirterison(y_pred, y_data)print(epoch, loss.item())# Bcakwardopimizer.zero_grad()loss.backward()# Updateopimizer.step()# Test Model
x_test = torch.Tensor([[4.0, 5.0, 3.0, 4.0, 1.0, 6.0, 7.0, 8.0]])
y_test = model(x_test)print('y_pred =', y_test.item())

与其余的模型不同的是,本次的多输入模型训练的损失始终难以下降到一个令人满意的程度,这与损失构造算法和优化器类型都存在关系,可以尝试更换通过matplotlib库画出图形进行比对和优化

优化代码与结果

(更新中,我真是擅长给自己挖坑呢。。。。)

【多输入模型 Multiple-Dimension 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人 (6/10)】相关推荐

  1. 【 卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10)】

    卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10) 本章主要进行卷积神经网络的相关数学原理和pytorch的对应模块进行推导分析 代码也是通过demo实 ...

  2. 【分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人(8/10)】

    分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人 (8/10) 在进行本章的数学推导前,有必要先粗浅的介绍一下,笔者在广泛查找 ...

  3. 【 梯度下降算法 Gradient-Descend 数学推导与源码详解 深度学习 Pytorch笔记 B站刘二大人(2/10)】

    梯度下降算法 Gradient-Descend 数学推导与源码详解 深度学习 Pytorch笔记 B站刘二大人(2/10) 数学原理分析 在第一节中我们定义并构建了线性模型,即最简单的深度学习模型,但 ...

  4. 【 反向传播算法 Back-Propagation 数学推导以及源码详解 深度学习 Pytorch笔记 B站刘二大人(3/10)】

    反向传播算法 Back-Propagation 数学推导以及源码详解 深度学习 Pytorch笔记 B站刘二大人(3/10) 数学推导 BP算法 BP神经网络可以说机器学习的最基础网络.对于普通的简单 ...

  5. 【 线性模型 Linear-Model 数学原理分析以及源码实现 深度学习 Pytorch笔记 B站刘二大人(1/10)】

    线性模型 Linear-Model 数学原理分析以及源码实现 深度学习 Pytorch笔记 B站刘二大人(1/10) 数学原理分析 线性模型是我们在初级数学问题中所遇到的最普遍也是最多的一类问题 在线 ...

  6. Android 事件分发机制分析及源码详解

    Android 事件分发机制分析及源码详解 文章目录 Android 事件分发机制分析及源码详解 事件的定义 事件分发序列模型 分发序列 分发模型 事件分发对象及相关方法 源码分析 事件分发总结 一般 ...

  7. fdct算法 java_ImageSharp源码详解之JPEG压缩原理(3)DCT变换

    DCT变换可谓是JPEG编码原理里面数学难度最高的一环,我也是因为DCT变换的算法才对JPEG编码感兴趣(真是不自量力).这一章我就把我对DCT的研究心得体会分享出来,希望各位大神也不吝赐教. 1.离 ...

  8. Tensorflow 2.x(keras)源码详解之第九章:模型训练和预测的三种方法(fittf.GradientTapetrain_steptf.data)

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

  9. Mapreduce源码分析(一):FileInputFormat切片机制,源码详解

    FileInputFormat切片机制,源码详解 1.InputFormat:抽象类 只有两个抽象方法 public abstract List<InputSplit> getSplits ...

最新文章

  1. 16 分频 32 分频是啥意思_Verilog中任意分频的实现
  2. 机器学习中的训练集 验证集 测试集的关系
  3. [js] 字符串拼接有哪些方式?哪种性能好?
  4. 星巴克“啡快”宣布接入支付宝、口碑等阿里应用
  5. 中国移动老功臣退休致辞:工作结束了 人生没结束
  6. Android按键灯,指示灯总结【Android源码解析十一】
  7. ibm语音识别输入系统
  8. eigrp 扩散算法_EIGRP扩散更新算法-FC规则
  9. Java反射初探 ——“当类也学会照镜子”
  10. mysql查看enum和set值_mysql中的enum和set类型_MySQL
  11. 权重的计算(熵权法)
  12. 刘毅5000词汇_不熟词汇整理_lesson_14 and part_4
  13. 详细解LeetCode 1284. Minimum Number of Flips to Convert Binary Matrix to Zero Matrix
  14. word文档页眉清除和页码设置
  15. 人生无捷径「一万小时定律·正篇」
  16. read write file
  17. sortWith与sortBy
  18. 奇安信漏扫设备与堡垒机问题解析
  19. 关于uni-app手机nfc开启、读取、写入功能
  20. java与gis开发

热门文章

  1. tkinter绘制组件(18)——菜单
  2. Rust之Sea-orm快速入门指南
  3. React Hooks核心原理与实战
  4. 用计算机术语赞美老师,【用一句话赞美各个学科】_赞美各学科老师的对联
  5. 今天给自己分享下我的心得体会
  6. 计算机毕业设计之android平台的出租打车软件app(源码+系统+mysql数据库+Lw文档)
  7. 改进A星算法+dwa
  8. 保险业未来生态的起点与三条演化路径 | 李有龙生态矩阵
  9. 游匣G15怎么样 游戏评测来了
  10. 使用expect ftp免交互上传文件