07multiple dimension input.py

比如下图这个预测一个人在一年之后得糖尿病的概率的例子,这个时候我们的输入将会有很多的指标。你可以把它看成是我们体检的各种值。最后一排的外代表了他是否会得糖尿病。

# -*- codeing = utf-8 -*-
# @Time :2021/4/19 10:25
# @Author:sueong
# @File:07multiple dimension input.py
# @Software:PyCharmimport os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"import numpy as np
import torch
import matplotlib.pyplot as plt
#1 prepare the dataset
xy=np.loadtxt('diabetes.csv',delimiter=',',dtype=np.float32)#文件名,以‘,’作为分割符,常用32位浮点数x_data=torch.from_numpy(xy[:,:-1])#第一个":"切片所有行,第二个“:-1”是指从第一列开始,最后一列不要
y_data=torch.from_numpy(xy[:,[-1]])#[-1]这样拿出来的是一个矩阵#2design model using class
#init 三个linear8-6-4-1 sigmoid #forward 三次sigmoid 上一个输入做下一个输出
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1=torch.nn.Linear(8,6)# 输入数据x的特征是8维,x有8个特征self.linear2=torch.nn.Linear(6,4)self.linear3 =torch.nn.Linear(4,1)self.sigmoid=torch.nn.Sigmoid()# 将其看作是网络的一层,而不是简单的函数使用def forward(self,x):#构建计算图x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel=Model()# 参数说明
# 第一层的参数:
layer1_weight = model.linear1.weight.data
layer1_bias = model.linear1.bias.data
print("layer1_weight", layer1_weight)
print("layer1_weight.shape", layer1_weight.shape)
print("layer1_bias", layer1_bias)
print("layer1_bias.shape", layer1_bias.shape)
'''
layer1_weight tensor([[ 0.1053,  0.1716,  0.1981,  0.1613, -0.1268,  0.1843, -0.3029, -0.1547],[-0.2075, -0.2407, -0.1529, -0.2438,  0.3339, -0.3276,  0.0095, -0.1153],[ 0.1969,  0.0073, -0.1312, -0.1668,  0.2570, -0.2317,  0.2036,  0.0433],[ 0.1871,  0.3029,  0.2014,  0.2805,  0.0691, -0.0206,  0.3492, -0.2535],[-0.0884,  0.2787, -0.0073, -0.1533, -0.0399, -0.1590,  0.2161,  0.3270],[-0.2526, -0.1705, -0.0183,  0.2450, -0.1937, -0.1331, -0.0771,  0.0410]])
layer1_weight.shape torch.Size([6, 8])
layer1_bias tensor([ 1.6963e-02,  2.0659e-01, -3.1528e-01, -2.0308e-01, -1.4773e-04,-3.5623e-02])
layer1_bias.shape torch.Size([6])'''#3 construct loss and optimizer
#criterion=torch.nn.BCELoss(size_average=False)
criterion=torch.nn.BCELoss(reduction='mean')#用size_average=False 会得到9999 26300.0这样很奇怪的额数字
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer=torch.optim.SGD(model.parameters(),lr=0.1)epoch_list=[]
loss_list=[]# 4training cycle forward, backward, updatefor epoch in range(1000):y_pre=model(x_data)loss=criterion(y_pre,y_data)print(epoch,loss.item())epoch_list.append(epoch)loss_list.append(loss.item())optimizer.zero_grad()loss.backward()optimizer.step()plt.plot(epoch_list,loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

训练10000 趋于收敛 w在0.4左右
训练1000 趋于收敛 w在0.6左右

PyTorch深度学习实践07相关推荐

  1. PyTorch深度学习实践

    根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...

  2. 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)

    这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...

  3. 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)

    从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...

  4. 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)

    文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...

  5. PyTorch 深度学习实践 第13讲

    PyTorch 深度学习实践 第13讲 引言 代码 结果 引言 近期学习了B站 刘二大人的PyTorch深度学习实践,传送门PyTorch 深度学习实践--循环神经网络(高级篇),感觉受益匪浅,发现网 ...

  6. 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归

    刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...

  7. 《PyTorch深度学习实践》 课堂笔记 Lesson7 神经网络多维特征输入的原理推导与实现

    文章目录 1.为什么使用多维的特征输入 2. 多维特征向量输入推导 3.实现过程 3.1源代码 3.2训练结果 写在最后 1.为什么使用多维的特征输入 对于现实世界来说,影响一个事物发展的因素有很多种 ...

  8. PyTorch深度学习实践概论笔记9-SoftMax分类器

    上一讲PyTorch深度学习实践概论笔记8-加载数据集中,主要介绍了Dataset 和 DataLoader是加载数据的两个工具类.这一讲介绍多分类问题如何解决,一般会用到SoftMax分类器. 0 ...

  9. 【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)

    <PyTorch深度学习实践>-刘二大人 Otto Group Product Classification作业 将商品进行十分类,输入为93个特征10个类别的商品数据集,输出为预测数据集 ...

最新文章

  1. Android如何使用读写cookie的方法
  2. linux和windows启动,Linux和Windows双系统的启动
  3. java i o是什么流_Java I/O流的总结
  4. linux系统下定时备份,在Linux系统中简单地实现定时备份的方法 -电脑资料
  5. 【原创】CLEVO P157SM外接鼠标键盘失灵解决:更换硅脂(附带最新跑分数据)
  6. Django_modelform组件
  7. robotframework-接口测试详解(上传文件)
  8. 瀑布模型,快速原型模型,增量模型,螺旋模型以及敏捷开发模型的相关概念
  9. python 统计计数
  10. 60-硅谷课堂6-硅谷课堂-公众号消息和微信授权-- 笔记
  11. mysql分区表去重复_MySQL分区表管理
  12. Java基础冒泡排序——高低输出十个学生的成绩
  13. AI人工智能自动化测试
  14. 分子力场简介 来自wiki百科
  15. linux切割文件命令,Linux系统下切割文件的split命令用法教程
  16. svn 文件前前面的标识符
  17. android内存检测方法,Android_Android系统检测程序内存占用各种方法,1.检查系统总内存 复制代码 - phpStudy...
  18. 买了SKS的W530
  19. 35岁转行数据分析师可以吗?
  20. 航空插头网线转接2.0排针线序图

热门文章

  1. 线上oom 自动kill 程序
  2. Activiti 7.1.4 发布,业务流程管理与工作流系统
  3. mac安装mysql遇到的坑
  4. sql 中CURSOR 的使用
  5. golang协程进行同步方法
  6. linux中shell变量$#,$@,$0,$1,$2的含义解释(转)
  7. Github 上 10 个值得学习的 Springboot 开源项目
  8. linux网站爬取,Kali下httrack 爬取网站页面
  9. java 绝对路径_java 获取绝对路径
  10. MySQL 高级 游标介绍