文章目录

  逻辑回归是线性的二分类模型。逻辑回归是分析自变量x与因变量v(概率)之间关系的方法。模型表达式为:
y=f(WX+b)f(x)=1e−x+1y=f(WX+b)\\f(x)=\frac{1}{e^{-x}+1}y=f(WX+b)f(x)=e−x+11​
f(x)为Sigmoid函数f(x)为Sigmoid函数f(x)为Sigmoid函数
下面展示torch实现逻辑回归的过程。

# -*- coding: utf-8 -*-
"""
# @file name  : lesson-05-Logsitic-Regression.py
# @author     : tingsongyu
# @date       : 2019-09-03 10:08:00
# @brief      : 逻辑回归模型训练
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)# ============================ step 1/5 生成数据 ============================
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias      # 类别0 数据 shape=(100, 2)
y0 = torch.zeros(sample_nums)                         # 类别0 标签 shape=(100)
x1 = torch.normal(-mean_value * n_data, 1) + bias     # 类别1 数据 shape=(100, 2)
y1 = torch.ones(sample_nums)                          # 类别1 标签 shape=(100)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)# ============================ step 2/5 选择模型 ============================
# 利用torch中的nn构建逻辑回归模型
class LR(nn.Module):def __init__(self):super(LR, self).__init__()self.features = nn.Linear(2, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.features(x)x = self.sigmoid(x)return xlr_net = LR()   # 实例化逻辑回归模型# ============================ step 3/5 选择损失函数 ============================
# 使用二分类交叉熵损失函数
loss_fn = nn.BCELoss()# ============================ step 4/5 选择优化器   ============================
# 优化器选择SGD(随机梯度下降)
lr = 0.01  # 学习率
optimizer = torch.optim.SGD(lr_net.parameters(), lr=lr, momentum=0.9)# ============================ step 5/5 模型训练 ============================
for iteration in range(1000):# 前向传播y_pred = lr_net(train_x)# 计算 lossloss = loss_fn(y_pred.squeeze(), train_y)# 反向传播loss.backward()# 更新参数optimizer.step()# 清空梯度optimizer.zero_grad()# 每迭代训练20次绘图if iteration % 20 == 0:# 以0.5为阈值进行分类mask = y_pred.ge(0.5).float().squeeze()# 计算正确预测的样本个数correct = (mask == train_y).sum()# 计算分类准确率acc = correct.item() / train_y.size(0)# 绘制训练数据plt.scatter(x0.data.numpy()[:, 0], x0.data.numpy()[:, 1], c='r', label='class 0')plt.scatter(x1.data.numpy()[:, 0], x1.data.numpy()[:, 1], c='b', label='class 1')# 绘制逻辑回归模型w0, w1 = lr_net.features.weight[0]w0, w1 = float(w0.item()), float(w1.item())plot_b = float(lr_net.features.bias[0].item())plot_x = np.arange(-6, 6, 0.1)plot_y = (-w0 * plot_x - plot_b) / w1plt.xlim(-5, 7)plt.ylim(-7, 7)plt.plot(plot_x, plot_y)plt.text(-5, 5, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})plt.title("Iteration: {}\nw0:{:.2f} w1:{:.2f} b: {:.2f} accuracy:{:.2%}".format(iteration, w0, w1, plot_b, acc))plt.legend()plt.show()plt.pause(0.5)if acc > 0.99:break

如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


PyTorch学习—5.torch实现逻辑回归相关推荐

  1. 机器学习学习吴恩达逻辑回归_机器学习基础:逻辑回归

    机器学习学习吴恩达逻辑回归 In the previous stories, I had given an explanation of the program for implementation ...

  2. 破解数据匮乏现状:纵向联邦学习场景下的逻辑回归(LR)

    摘要:主要介绍了华为云可信智能计算服务(TICS)采用的纵向联邦逻辑回归(LR)方案. 本文分享自华为云社区<纵向联邦学习场景下的逻辑回归(LR)>,作者: 汽水要加冰. 海量训练数据是人 ...

  3. 吴恩达深度学习 —— 2.14 向量化逻辑回归的梯度输出

    这一节将学习如果向量化计算m个训练数据的梯度,强调一下,是同时计算. 前面已经说过,在逻辑回归中,有dz(1)=a(1)−y(1)dz^{(1)}=a^{(1)}-y^{(1)}dz(1)=a(1)− ...

  4. 【推荐算法 学习与复现】-- 逻辑回归算法族 -- LR

    协同过滤仅仅使用有限的用户行为信息,逻辑回归算法模型大多引入用户行为.用户特征.物品特征和上下文特征等,从CF逐步过渡到综合不同特征的机器学习模型. (1)逻辑回归模型 将用户特征(年龄.性别等).用 ...

  5. 神经网络和深度学习(5)-- 逻辑回归

    神经网络和深度学习 上一篇 主目录 下一篇 文章结构 1.逻辑回归 [前言] 逻辑回归学习算法,该算法适用于二分类问题,本节将主要 介绍逻辑回归的 Hypothesis Function(假设函数) ...

  6. 深度学习(入门)——逻辑回归模型(Logistics Regression)

    从逻辑回归开始,是因为这是一个简单的,可以理解成为一个简单的一层的神经网络,后续将逐步深入,了解更复杂的神经网络 一,假设函数 逻辑回归算法适用于二分类问题,例如在输入一个猫的图片到模型中,模型会输出 ...

  7. TensorFlow基础7-机器学习基础知识(逻辑回归,鸢尾花实现多分类)

    记录TensorFlow听课笔记 文章目录 记录TensorFlow听课笔记 一,线性回归 二,广义线性回归 三,一元/多元逻辑回归 四,实现一元逻辑回归 五,多分类问题 六,TensorFlow实现 ...

  8. [深度学习]Python/Theano实现逻辑回归网络的代码分析

    2014-07-21 10:28:34 首先PO上主要Python代码(2.7), 这个代码在Deep Learning上可以找到. 1 # allocate symbolic variables f ...

  9. 【Pytorch学习】torch.mm()torch.matmul()和torch.mul()以及torch.spmm()

    目录 1 引言 2 torch.mul(a, b) 3 torch.mm(a, b) 4 torch.matmul() 5 torch.spmm() 参考文献 1 引言   做深度学习过程中免不了使用 ...

  10. pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...

最新文章

  1. Python分析离散心率信号(上)
  2. 细说plsql中的空值表达式
  3. openstack——horizon篇
  4. sysfs cannot create duplicate filename问题
  5. python降序排列说true不存在_Python数据类型串讲(中)
  6. java的static和private_static关键字什么意思?Java中是否可以覆盖一个private或者是static的方法?...
  7. 11位大咖带你玩转WebRTC开发(内附PPT资料下载)
  8. git fatal:HttpRequestException encountered
  9. idea @Autowired 注入爆红(无法注入)
  10. C++中内联函数和宏定义的区别
  11. nginx源码分析:打开监听套接字的流程
  12. 高通模式9008模式linux,高通芯片如何进入9008模式深度救砖
  13. IBM server guide 下载地址及列表
  14. android oem解锁,Android平台OEM解锁分析
  15. 随机四位数的猜数游戏
  16. 深入浅出-交接运维工作
  17. 小C实例也有大梦想——自定义strlen函数
  18. Python BDD 框架之lettuce
  19. ensp 移动主机搜索不到AP信道_H3C H5套装评测,AC+AP无缝漫游
  20. 测试之第四集找bug的专业与素养

热门文章

  1. 虚继承 - C++快速入门29
  2. 第2次作业 -- 熟悉 JUnit 测试
  3. SpringBoot参数传递bean自动填充
  4. 3-13 图片几何变换小结
  5. js判断当前页面是否有父页面,页面部分跳转解决办法,子页面跳转父页面不跳转解决 (原)...
  6. mysql 的命令行操作
  7. bootstrap的三角方向符号实现
  8. 数学趣题——猴子吃桃问题
  9. 机器学习(3)——K-近邻算法改进约会网站的配对效果实例
  10. Tigase XMPP Server