逻辑回归  regression logistic

逻辑回归的基本思想是寻求一个逻辑函数,将输入的特征(x1,x2)映射到(0-1)内,然后按照是否大于0.5进行二分类。

为了能够进行数学描述,现设定映射关系核逻辑函数;假设该映射关系为线性关系,则有:

即     

此时f(x)的值可以是任何值,需要通过逻辑函数进行(0-1)再映射。

逻辑回归中逻辑函数为h(z),即sigmoid函数:

将h(z)在z域上画出,图像为:

可见,在f(x)>0时,特征X将被归于1类,f(x)<0时,特征X将被归于0类。

损失函数:cost损失函数应该越小越好。所以在真实标签为1时,h(z)越接近于1越好,则有:

越接近于0;在真实标签为0时,h(z)越接近于0越好,则有越接近于0;所以最终的损失函数定义为:

上式可以改写为:

有了损失函数,就可利用梯度下降法进行求导优化了。

torch

torch是专门用于搭建神经网络的框架工具,其中封装了nn.Module父类网络结构,可变参数张量Variable,nn.BCEWithLogitsLoss逻辑回归损失函数,.optim.SGD梯度下降函数等一系列网络函数,可以方便快捷的搭建一个神经网络。

1、Variable

就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。

具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生变化,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是由tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。tensor 是一个多维矩阵。

2、torch.Tensor([1,2])

生成张量,[1.0,2.0],网络中数据必须是张量,而torch.tensor则是函数,将某类型值转为tensor.

3、torch.manual_seed(2)

设定CPU固定随机数并设定维数,torch.cuda.manual_seed(number)则为 GPU设定固定随机数,设定后,torch.rand(2)执行时,每次都会生成相同的随机数,便于观察模型的改进和优化。

举个栗子:

现在对数字的大小进行2分类,设定模式类别为大于2的数字分类到1,小于等于2的则分类到0.

假设训练样本一共有4个,样本及标签如下:

,     

则利用torch搭建自己的网络并训练的代码如下:

import torch
from torch.autograd import Variabletorch.manual_seed(2)
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0], [4.0]]))
y_data = Variable(torch.Tensor([[0.0], [0.0], [1.0], [1.0]]))#定义网络模型
#先建立一个基类Module,都是从父类torch.nn.Module继承过来,Pytorch写网络的固定写法
class Model(torch.nn.Module):def __init__(self):##固定写法,定义构造函数,参数为selfsuper(Model, self).__init__()  #继承父类,初始父类 or nn.Module.__init__(self)####定义类中要用到的层和函数self.linear = torch.nn.Linear(1, 1)  #线性映射层,torch自带封装,输入维度和输出维度都为1,即y=linear(x)==y=wx+b,会自动按照####x和y的维数确定W和B的维数。def forward(self, x):###定义网络的前向传递过程图y_pred = self.linear(x)return y_predmodel = Model()  #网络实例化#定义loss和优化方法
criterion = torch.nn.BCEWithLogitsLoss()  #损失函数,封装好的逻辑损失函数即cost函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   #进行优化梯度下降
#optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)
#Pytorch类方法正则化方法,添加一个weight_decay参数进行正则化#befor training
hour_var = Variable(torch.Tensor([[2.5]]))
y_pred = model(hour_var)####有x输入时,会调用forward函数,直接预测一个2.5的分类
print("predict (before training)given", 4, 'is', float(model(hour_var).data[0][0]>0.5))
########
#####开始训练
epochs = 40##迭代次数
for epoch in range(epochs):#计算grads和costy_pred = model(x_data)   #x_data输入数据进入模型中####y为预测结果loss = criterion(y_pred, y_data)#####求损失print('epoch = ', epoch+1, loss.data[0])#反向传播三件套,固定使用。定义了loss是某种损失后,直接loss.backword即可更新线性映射中的w和b,直到分类正确。optimizer.zero_grad() #梯度清零loss.backward() #反向传播optimizer.step()  #优化迭代#After training
hour_var = Variable(torch.Tensor([[4.0]]))
y_pred = model(hour_var)
print("predict (after training)given", 4, 'is', float(model(hour_var).data[0][0]>0.5))

运行结果如下:

predict (before training)given 4 is 0.0
predict (after training)given 4 is 1.0

可见,经过训练可以实现正确2分类.

Torch搭网络学习笔记(一)逻辑回归相关推荐

  1. 吴恩达《机器学习》学习笔记七——逻辑回归(二分类)代码

    吴恩达<机器学习>学习笔记七--逻辑回归(二分类)代码 一.无正则项的逻辑回归 1.问题描述 2.导入模块 3.准备数据 4.假设函数 5.代价函数 6.梯度下降 7.拟合参数 8.用训练 ...

  2. 吴恩达《机器学习》学习笔记五——逻辑回归

    吴恩达<机器学习>学习笔记五--逻辑回归 一. 分类(classification) 1.定义 2.阈值 二. 逻辑(logistic)回归假设函数 1.假设的表达式 2.假设表达式的意义 ...

  3. 吴恩达《机器学习》学习笔记八——逻辑回归(多分类)代码

    吴恩达<机器学习>笔记八--逻辑回归(多分类)代码 导入模块及加载数据 sigmoid函数与假设函数 代价函数 梯度下降 一对多分类 预测验证 课程链接:https://www.bilib ...

  4. CS229学习笔记(3)逻辑回归(Logistic Regression)

    1.分类问题 你要预测的变量yyy是离散的值,我们将学习一种叫做逻辑回归 (Logistic Regression) 的算法,这是目前最流行使用最广泛的一种学习算法. 从二元的分类问题开始讨论. 我们 ...

  5. 医咖会免费SPSS教程学习笔记—二元逻辑回归

    1.假设检验 2.如何判断连续自变量与因变量的logit转换值之间存在线性关系 首先,创建连续自变量的自然对数值.方法在上一条博文. 其次,请依次点击:分析-回归-二元逻辑-拖入因变量,拖入自变量到协 ...

  6. 机器学习笔记-基于逻辑回归的分类预测

    天池学习笔记:AI训练营机器学习-阿里云天池 基于逻辑回归的分类预测 1 逻辑回归的介绍和应用 1.1 逻辑回归的介绍 逻辑回归(Logistic regression,简称LR)虽然其中带有&quo ...

  7. SDN软件定义网络 学习笔记(4)--数据平面

    SDN软件定义网络 学习笔记(4)--数据平面 1. 简介 2. SDN数据平面架构 2.1 传统网络交换设备架构 2.2 SDN交换设备架构 2.3 数据平面架构图 3. SDN芯片与交换机 3.1 ...

  8. 华为网络学习笔记(一) 网络通信协议

    华为网络学习笔记(一) 一.网络通信协议 通讯协议:通讯协议又称通信规程,是指通信双方对数据传送控制的一种约定.约定中包括对数据格式,同步方式,传送速度,传送步骤,检纠错方式以及控制字符定义等问题做出 ...

  9. SDN软件定义网络 学习笔记(1)--基本概念

    SDN软件定义网络 学习笔记(1)--基本概念 1. 定义 2. 提出背景 3. 体系结构 1. 定义 软件定义网络(Software Defined Network,SDN),顾名思义,SDN 与传 ...

  10. python中socket模块常用吗_python网络学习笔记——socket模块使用记录

    此文章记录了笔者学习python网络中socket模块的笔记. 建议初次学习socket的读者先读一遍socket模块主要函数的介绍. socket模块的介绍可以参考笔者的前一篇关于socket官方文 ...

最新文章

  1. AI工程师面试知识点:神经网络相关
  2. 利用nginx和mongrel、unicorn 对puppet进行端口负载均衡
  3. ajax 加载partial view ,并且 附加validate验证
  4. MANIFEST.MF的用途(转载)
  5. Ubuntu18.04无法进入图形界面桌面的问题及解决
  6. 242种颜色样式、中英文名称及十六进制的值
  7. Flutter 接入iOS苹果内购支付踩坑过程
  8. win7 定时开关机命令
  9. mysql为什么要用b+树
  10. 2016年GitHub上史上最全的Android开源项目分类汇总
  11. Ta-lib学习笔记02--K线模式识别
  12. 网络营销招生方案及河南大学生高校名单
  13. opencv检测相交点_在网络摄像头feed opencv中检测2条线之间的交点
  14. luogu P1979 华容道
  15. mysql中depart_mysql实训
  16. comsol官方案例学习——轴对称瞬态传热
  17. mysql骚操作_关于MySQL的一些骚操作——提升正确性,抠点性能
  18. 【迁移学习】猫狗数据分类案例(TensorFlow2)
  19. c语言应用论文英文,c语言中英文翻译资料 本科毕业论文(设计).doc
  20. 如何正确关机,重启,以及常用的快捷键

热门文章

  1. 编译Linux内核4.4实现可读NTFS
  2. 象棋(Xiangqi, ACM/ICPC Fuzhou 2011, UVa1589)C++超详细解题
  3. 许小年:企业家精神的衰落与重振
  4. 八个研发物联网产品的重要问题
  5. php 遍历文件夹并压成zip_将文件夹压缩成zip文件的php代码
  6. UT000010 Session is Invalid
  7. 实时PPP多系统组合与单系统解算ZTD和Clock差异
  8. 腾讯汤道生:开放中台能力助力产业升级
  9. MySQL必知必会——实践习题
  10. 用只读打开服务器上的文档,打开WebDAV文档在MS Office中以IT只读方式打开WebDAV服务器...