Torch搭网络学习笔记(一)逻辑回归
逻辑回归 regression logistic
逻辑回归的基本思想是寻求一个逻辑函数,将输入的特征(x1,x2)映射到(0-1)内,然后按照是否大于0.5进行二分类。
为了能够进行数学描述,现设定映射关系核逻辑函数;假设该映射关系为线性关系,则有:
即
此时f(x)的值可以是任何值,需要通过逻辑函数进行(0-1)再映射。
逻辑回归中逻辑函数为h(z),即sigmoid函数:
将h(z)在z域上画出,图像为:
可见,在f(x)>0时,特征X将被归于1类,f(x)<0时,特征X将被归于0类。
损失函数:cost损失函数应该越小越好。所以在真实标签为1时,h(z)越接近于1越好,则有:
越接近于0;在真实标签为0时,h(z)越接近于0越好,则有越接近于0;所以最终的损失函数定义为:
上式可以改写为:
有了损失函数,就可利用梯度下降法进行求导优化了。
torch
torch是专门用于搭建神经网络的框架工具,其中封装了nn.Module父类网络结构,可变参数张量Variable,nn.BCEWithLogitsLoss逻辑回归损失函数,.optim.SGD梯度下降函数等一系列网络函数,可以方便快捷的搭建一个神经网络。
1、Variable
就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。
具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生变化,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是由tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。tensor 是一个多维矩阵。
2、torch.Tensor([1,2])
生成张量,[1.0,2.0],网络中数据必须是张量,而torch.tensor则是函数,将某类型值转为tensor.
3、torch.manual_seed(2)
设定CPU固定随机数并设定维数,torch.cuda.manual_seed(number
)则为 GPU设定固定随机数,设定后,torch.rand(2)执行时,每次都会生成相同的随机数,便于观察模型的改进和优化。
举个栗子:
现在对数字的大小进行2分类,设定模式类别为大于2的数字分类到1,小于等于2的则分类到0.
假设训练样本一共有4个,样本及标签如下:
,
则利用torch搭建自己的网络并训练的代码如下:
import torch
from torch.autograd import Variabletorch.manual_seed(2)
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0], [4.0]]))
y_data = Variable(torch.Tensor([[0.0], [0.0], [1.0], [1.0]]))#定义网络模型
#先建立一个基类Module,都是从父类torch.nn.Module继承过来,Pytorch写网络的固定写法
class Model(torch.nn.Module):def __init__(self):##固定写法,定义构造函数,参数为selfsuper(Model, self).__init__() #继承父类,初始父类 or nn.Module.__init__(self)####定义类中要用到的层和函数self.linear = torch.nn.Linear(1, 1) #线性映射层,torch自带封装,输入维度和输出维度都为1,即y=linear(x)==y=wx+b,会自动按照####x和y的维数确定W和B的维数。def forward(self, x):###定义网络的前向传递过程图y_pred = self.linear(x)return y_predmodel = Model() #网络实例化#定义loss和优化方法
criterion = torch.nn.BCEWithLogitsLoss() #损失函数,封装好的逻辑损失函数即cost函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) #进行优化梯度下降
#optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)
#Pytorch类方法正则化方法,添加一个weight_decay参数进行正则化#befor training
hour_var = Variable(torch.Tensor([[2.5]]))
y_pred = model(hour_var)####有x输入时,会调用forward函数,直接预测一个2.5的分类
print("predict (before training)given", 4, 'is', float(model(hour_var).data[0][0]>0.5))
########
#####开始训练
epochs = 40##迭代次数
for epoch in range(epochs):#计算grads和costy_pred = model(x_data) #x_data输入数据进入模型中####y为预测结果loss = criterion(y_pred, y_data)#####求损失print('epoch = ', epoch+1, loss.data[0])#反向传播三件套,固定使用。定义了loss是某种损失后,直接loss.backword即可更新线性映射中的w和b,直到分类正确。optimizer.zero_grad() #梯度清零loss.backward() #反向传播optimizer.step() #优化迭代#After training
hour_var = Variable(torch.Tensor([[4.0]]))
y_pred = model(hour_var)
print("predict (after training)given", 4, 'is', float(model(hour_var).data[0][0]>0.5))
运行结果如下:
predict (before training)given 4 is 0.0
predict (after training)given 4 is 1.0
可见,经过训练可以实现正确2分类.
Torch搭网络学习笔记(一)逻辑回归相关推荐
- 吴恩达《机器学习》学习笔记七——逻辑回归(二分类)代码
吴恩达<机器学习>学习笔记七--逻辑回归(二分类)代码 一.无正则项的逻辑回归 1.问题描述 2.导入模块 3.准备数据 4.假设函数 5.代价函数 6.梯度下降 7.拟合参数 8.用训练 ...
- 吴恩达《机器学习》学习笔记五——逻辑回归
吴恩达<机器学习>学习笔记五--逻辑回归 一. 分类(classification) 1.定义 2.阈值 二. 逻辑(logistic)回归假设函数 1.假设的表达式 2.假设表达式的意义 ...
- 吴恩达《机器学习》学习笔记八——逻辑回归(多分类)代码
吴恩达<机器学习>笔记八--逻辑回归(多分类)代码 导入模块及加载数据 sigmoid函数与假设函数 代价函数 梯度下降 一对多分类 预测验证 课程链接:https://www.bilib ...
- CS229学习笔记(3)逻辑回归(Logistic Regression)
1.分类问题 你要预测的变量yyy是离散的值,我们将学习一种叫做逻辑回归 (Logistic Regression) 的算法,这是目前最流行使用最广泛的一种学习算法. 从二元的分类问题开始讨论. 我们 ...
- 医咖会免费SPSS教程学习笔记—二元逻辑回归
1.假设检验 2.如何判断连续自变量与因变量的logit转换值之间存在线性关系 首先,创建连续自变量的自然对数值.方法在上一条博文. 其次,请依次点击:分析-回归-二元逻辑-拖入因变量,拖入自变量到协 ...
- 机器学习笔记-基于逻辑回归的分类预测
天池学习笔记:AI训练营机器学习-阿里云天池 基于逻辑回归的分类预测 1 逻辑回归的介绍和应用 1.1 逻辑回归的介绍 逻辑回归(Logistic regression,简称LR)虽然其中带有&quo ...
- SDN软件定义网络 学习笔记(4)--数据平面
SDN软件定义网络 学习笔记(4)--数据平面 1. 简介 2. SDN数据平面架构 2.1 传统网络交换设备架构 2.2 SDN交换设备架构 2.3 数据平面架构图 3. SDN芯片与交换机 3.1 ...
- 华为网络学习笔记(一) 网络通信协议
华为网络学习笔记(一) 一.网络通信协议 通讯协议:通讯协议又称通信规程,是指通信双方对数据传送控制的一种约定.约定中包括对数据格式,同步方式,传送速度,传送步骤,检纠错方式以及控制字符定义等问题做出 ...
- SDN软件定义网络 学习笔记(1)--基本概念
SDN软件定义网络 学习笔记(1)--基本概念 1. 定义 2. 提出背景 3. 体系结构 1. 定义 软件定义网络(Software Defined Network,SDN),顾名思义,SDN 与传 ...
- python中socket模块常用吗_python网络学习笔记——socket模块使用记录
此文章记录了笔者学习python网络中socket模块的笔记. 建议初次学习socket的读者先读一遍socket模块主要函数的介绍. socket模块的介绍可以参考笔者的前一篇关于socket官方文 ...
最新文章
- AI工程师面试知识点:神经网络相关
- 利用nginx和mongrel、unicorn 对puppet进行端口负载均衡
- ajax 加载partial view ,并且 附加validate验证
- MANIFEST.MF的用途(转载)
- Ubuntu18.04无法进入图形界面桌面的问题及解决
- 242种颜色样式、中英文名称及十六进制的值
- Flutter 接入iOS苹果内购支付踩坑过程
- win7 定时开关机命令
- mysql为什么要用b+树
- 2016年GitHub上史上最全的Android开源项目分类汇总
- Ta-lib学习笔记02--K线模式识别
- 网络营销招生方案及河南大学生高校名单
- opencv检测相交点_在网络摄像头feed opencv中检测2条线之间的交点
- luogu P1979 华容道
- mysql中depart_mysql实训
- comsol官方案例学习——轴对称瞬态传热
- mysql骚操作_关于MySQL的一些骚操作——提升正确性,抠点性能
- 【迁移学习】猫狗数据分类案例(TensorFlow2)
- c语言应用论文英文,c语言中英文翻译资料 本科毕业论文(设计).doc
- 如何正确关机,重启,以及常用的快捷键
热门文章
- 编译Linux内核4.4实现可读NTFS
- 象棋(Xiangqi, ACM/ICPC Fuzhou 2011, UVa1589)C++超详细解题
- 许小年:企业家精神的衰落与重振
- 八个研发物联网产品的重要问题
- php 遍历文件夹并压成zip_将文件夹压缩成zip文件的php代码
- UT000010 Session is Invalid
- 实时PPP多系统组合与单系统解算ZTD和Clock差异
- 腾讯汤道生:开放中台能力助力产业升级
- MySQL必知必会——实践习题
- 用只读打开服务器上的文档,打开WebDAV文档在MS Office中以IT只读方式打开WebDAV服务器...