PyTorch深度学习实践07
07multiple dimension input.py
比如下图这个预测一个人在一年之后得糖尿病的概率的例子,这个时候我们的输入将会有很多的指标。你可以把它看成是我们体检的各种值。最后一排的外代表了他是否会得糖尿病。
# -*- codeing = utf-8 -*-
# @Time :2021/4/19 10:25
# @Author:sueong
# @File:07multiple dimension input.py
# @Software:PyCharmimport os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"import numpy as np
import torch
import matplotlib.pyplot as plt
#1 prepare the dataset
xy=np.loadtxt('diabetes.csv',delimiter=',',dtype=np.float32)#文件名,以‘,’作为分割符,常用32位浮点数x_data=torch.from_numpy(xy[:,:-1])#第一个":"切片所有行,第二个“:-1”是指从第一列开始,最后一列不要
y_data=torch.from_numpy(xy[:,[-1]])#[-1]这样拿出来的是一个矩阵#2design model using class
#init 三个linear8-6-4-1 sigmoid #forward 三次sigmoid 上一个输入做下一个输出
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1=torch.nn.Linear(8,6)# 输入数据x的特征是8维,x有8个特征self.linear2=torch.nn.Linear(6,4)self.linear3 =torch.nn.Linear(4,1)self.sigmoid=torch.nn.Sigmoid()# 将其看作是网络的一层,而不是简单的函数使用def forward(self,x):#构建计算图x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel=Model()# 参数说明
# 第一层的参数:
layer1_weight = model.linear1.weight.data
layer1_bias = model.linear1.bias.data
print("layer1_weight", layer1_weight)
print("layer1_weight.shape", layer1_weight.shape)
print("layer1_bias", layer1_bias)
print("layer1_bias.shape", layer1_bias.shape)
'''
layer1_weight tensor([[ 0.1053, 0.1716, 0.1981, 0.1613, -0.1268, 0.1843, -0.3029, -0.1547],[-0.2075, -0.2407, -0.1529, -0.2438, 0.3339, -0.3276, 0.0095, -0.1153],[ 0.1969, 0.0073, -0.1312, -0.1668, 0.2570, -0.2317, 0.2036, 0.0433],[ 0.1871, 0.3029, 0.2014, 0.2805, 0.0691, -0.0206, 0.3492, -0.2535],[-0.0884, 0.2787, -0.0073, -0.1533, -0.0399, -0.1590, 0.2161, 0.3270],[-0.2526, -0.1705, -0.0183, 0.2450, -0.1937, -0.1331, -0.0771, 0.0410]])
layer1_weight.shape torch.Size([6, 8])
layer1_bias tensor([ 1.6963e-02, 2.0659e-01, -3.1528e-01, -2.0308e-01, -1.4773e-04,-3.5623e-02])
layer1_bias.shape torch.Size([6])'''#3 construct loss and optimizer
#criterion=torch.nn.BCELoss(size_average=False)
criterion=torch.nn.BCELoss(reduction='mean')#用size_average=False 会得到9999 26300.0这样很奇怪的额数字
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer=torch.optim.SGD(model.parameters(),lr=0.1)epoch_list=[]
loss_list=[]# 4training cycle forward, backward, updatefor epoch in range(1000):y_pre=model(x_data)loss=criterion(y_pre,y_data)print(epoch,loss.item())epoch_list.append(epoch)loss_list.append(loss.item())optimizer.zero_grad()loss.backward()optimizer.step()plt.plot(epoch_list,loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
训练10000 趋于收敛 w在0.4左右
训练1000 趋于收敛 w在0.6左右
PyTorch深度学习实践07相关推荐
- PyTorch深度学习实践
根据学习情况随时更新. 2020.08.14更新完成. 参考课程-刘二大人<PyTorch深度学习实践> 文章目录 (一)课程概述 (二)线性模型 (三)梯度下降算法 (四)反向传播 (五 ...
- 【Pytorch深度学习实践】B站up刘二大人之SoftmaxClassifier-代码理解与实现(8/9)
这是刘二大人系列课程笔记的倒数第二个博客了,介绍的是多分类器的原理和代码实现,下一个笔记就是basicCNN和advancedCNN了: 写在前面: 这节课的内容,主要是两个部分的修改: 一是数据集: ...
- 【Pytorch深度学习实践】B站up刘二大人课程笔记——目录与索引(已完结)
从有代码的课程开始讨论 [Pytorch深度学习实践]B站up刘二大人之LinearModel -代码理解与实现(1/9) [Pytorch深度学习实践]B站up刘二大人之 Gradient Desc ...
- 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)
文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...
- PyTorch 深度学习实践 第13讲
PyTorch 深度学习实践 第13讲 引言 代码 结果 引言 近期学习了B站 刘二大人的PyTorch深度学习实践,传送门PyTorch 深度学习实践--循环神经网络(高级篇),感觉受益匪浅,发现网 ...
- 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归
刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归 P6 逻辑斯蒂回归 1.torchversion 提供的数据集 2.基本概念 3.代码实现 P6 逻辑斯蒂回归 1.torchversi ...
- 《PyTorch深度学习实践》 课堂笔记 Lesson7 神经网络多维特征输入的原理推导与实现
文章目录 1.为什么使用多维的特征输入 2. 多维特征向量输入推导 3.实现过程 3.1源代码 3.2训练结果 写在最后 1.为什么使用多维的特征输入 对于现实世界来说,影响一个事物发展的因素有很多种 ...
- PyTorch深度学习实践概论笔记9-SoftMax分类器
上一讲PyTorch深度学习实践概论笔记8-加载数据集中,主要介绍了Dataset 和 DataLoader是加载数据的两个工具类.这一讲介绍多分类问题如何解决,一般会用到SoftMax分类器. 0 ...
- 【PyTorch深度学习实践】P9 kaggle otto商品分类作业(含注释)
<PyTorch深度学习实践>-刘二大人 Otto Group Product Classification作业 将商品进行十分类,输入为93个特征10个类别的商品数据集,输出为预测数据集 ...
最新文章
- Android如何使用读写cookie的方法
- linux和windows启动,Linux和Windows双系统的启动
- java i o是什么流_Java I/O流的总结
- linux系统下定时备份,在Linux系统中简单地实现定时备份的方法 -电脑资料
- 【原创】CLEVO P157SM外接鼠标键盘失灵解决:更换硅脂(附带最新跑分数据)
- Django_modelform组件
- robotframework-接口测试详解(上传文件)
- 瀑布模型,快速原型模型,增量模型,螺旋模型以及敏捷开发模型的相关概念
- python 统计计数
- 60-硅谷课堂6-硅谷课堂-公众号消息和微信授权-- 笔记
- mysql分区表去重复_MySQL分区表管理
- Java基础冒泡排序——高低输出十个学生的成绩
- AI人工智能自动化测试
- 分子力场简介 来自wiki百科
- linux切割文件命令,Linux系统下切割文件的split命令用法教程
- svn 文件前前面的标识符
- android内存检测方法,Android_Android系统检测程序内存占用各种方法,1.检查系统总内存
复制代码 - phpStudy...
- 买了SKS的W530
- 35岁转行数据分析师可以吗?
- 航空插头网线转接2.0排针线序图
热门文章
- 线上oom 自动kill 程序
- Activiti 7.1.4 发布,业务流程管理与工作流系统
- mac安装mysql遇到的坑
- sql 中CURSOR 的使用
- golang协程进行同步方法
- linux中shell变量$#,$@,$0,$1,$2的含义解释(转)
- Github 上 10 个值得学习的 Springboot 开源项目
- linux网站爬取,Kali下httrack 爬取网站页面
- java 绝对路径_java 获取绝对路径
- MySQL 高级 游标介绍