PyTorch学习—5.torch实现逻辑回归
文章目录
逻辑回归是线性的二分类模型。逻辑回归是分析自变量x与因变量v(概率)之间关系的方法。模型表达式为:
y=f(WX+b)f(x)=1e−x+1y=f(WX+b)\\f(x)=\frac{1}{e^{-x}+1}y=f(WX+b)f(x)=e−x+11
f(x)为Sigmoid函数f(x)为Sigmoid函数f(x)为Sigmoid函数
下面展示torch实现逻辑回归的过程。
# -*- coding: utf-8 -*-
"""
# @file name : lesson-05-Logsitic-Regression.py
# @author : tingsongyu
# @date : 2019-09-03 10:08:00
# @brief : 逻辑回归模型训练
"""
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)# ============================ step 1/5 生成数据 ============================
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias # 类别0 数据 shape=(100, 2)
y0 = torch.zeros(sample_nums) # 类别0 标签 shape=(100)
x1 = torch.normal(-mean_value * n_data, 1) + bias # 类别1 数据 shape=(100, 2)
y1 = torch.ones(sample_nums) # 类别1 标签 shape=(100)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)# ============================ step 2/5 选择模型 ============================
# 利用torch中的nn构建逻辑回归模型
class LR(nn.Module):def __init__(self):super(LR, self).__init__()self.features = nn.Linear(2, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.features(x)x = self.sigmoid(x)return xlr_net = LR() # 实例化逻辑回归模型# ============================ step 3/5 选择损失函数 ============================
# 使用二分类交叉熵损失函数
loss_fn = nn.BCELoss()# ============================ step 4/5 选择优化器 ============================
# 优化器选择SGD(随机梯度下降)
lr = 0.01 # 学习率
optimizer = torch.optim.SGD(lr_net.parameters(), lr=lr, momentum=0.9)# ============================ step 5/5 模型训练 ============================
for iteration in range(1000):# 前向传播y_pred = lr_net(train_x)# 计算 lossloss = loss_fn(y_pred.squeeze(), train_y)# 反向传播loss.backward()# 更新参数optimizer.step()# 清空梯度optimizer.zero_grad()# 每迭代训练20次绘图if iteration % 20 == 0:# 以0.5为阈值进行分类mask = y_pred.ge(0.5).float().squeeze()# 计算正确预测的样本个数correct = (mask == train_y).sum()# 计算分类准确率acc = correct.item() / train_y.size(0)# 绘制训练数据plt.scatter(x0.data.numpy()[:, 0], x0.data.numpy()[:, 1], c='r', label='class 0')plt.scatter(x1.data.numpy()[:, 0], x1.data.numpy()[:, 1], c='b', label='class 1')# 绘制逻辑回归模型w0, w1 = lr_net.features.weight[0]w0, w1 = float(w0.item()), float(w1.item())plot_b = float(lr_net.features.bias[0].item())plot_x = np.arange(-6, 6, 0.1)plot_y = (-w0 * plot_x - plot_b) / w1plt.xlim(-5, 7)plt.ylim(-7, 7)plt.plot(plot_x, plot_y)plt.text(-5, 5, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})plt.title("Iteration: {}\nw0:{:.2f} w1:{:.2f} b: {:.2f} accuracy:{:.2%}".format(iteration, w0, w1, plot_b, acc))plt.legend()plt.show()plt.pause(0.5)if acc > 0.99:break
如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!
PyTorch学习—5.torch实现逻辑回归相关推荐
- 机器学习学习吴恩达逻辑回归_机器学习基础:逻辑回归
机器学习学习吴恩达逻辑回归 In the previous stories, I had given an explanation of the program for implementation ...
- 破解数据匮乏现状:纵向联邦学习场景下的逻辑回归(LR)
摘要:主要介绍了华为云可信智能计算服务(TICS)采用的纵向联邦逻辑回归(LR)方案. 本文分享自华为云社区<纵向联邦学习场景下的逻辑回归(LR)>,作者: 汽水要加冰. 海量训练数据是人 ...
- 吴恩达深度学习 —— 2.14 向量化逻辑回归的梯度输出
这一节将学习如果向量化计算m个训练数据的梯度,强调一下,是同时计算. 前面已经说过,在逻辑回归中,有dz(1)=a(1)−y(1)dz^{(1)}=a^{(1)}-y^{(1)}dz(1)=a(1)− ...
- 【推荐算法 学习与复现】-- 逻辑回归算法族 -- LR
协同过滤仅仅使用有限的用户行为信息,逻辑回归算法模型大多引入用户行为.用户特征.物品特征和上下文特征等,从CF逐步过渡到综合不同特征的机器学习模型. (1)逻辑回归模型 将用户特征(年龄.性别等).用 ...
- 神经网络和深度学习(5)-- 逻辑回归
神经网络和深度学习 上一篇 主目录 下一篇 文章结构 1.逻辑回归 [前言] 逻辑回归学习算法,该算法适用于二分类问题,本节将主要 介绍逻辑回归的 Hypothesis Function(假设函数) ...
- 深度学习(入门)——逻辑回归模型(Logistics Regression)
从逻辑回归开始,是因为这是一个简单的,可以理解成为一个简单的一层的神经网络,后续将逐步深入,了解更复杂的神经网络 一,假设函数 逻辑回归算法适用于二分类问题,例如在输入一个猫的图片到模型中,模型会输出 ...
- TensorFlow基础7-机器学习基础知识(逻辑回归,鸢尾花实现多分类)
记录TensorFlow听课笔记 文章目录 记录TensorFlow听课笔记 一,线性回归 二,广义线性回归 三,一元/多元逻辑回归 四,实现一元逻辑回归 五,多分类问题 六,TensorFlow实现 ...
- [深度学习]Python/Theano实现逻辑回归网络的代码分析
2014-07-21 10:28:34 首先PO上主要Python代码(2.7), 这个代码在Deep Learning上可以找到. 1 # allocate symbolic variables f ...
- 【Pytorch学习】torch.mm()torch.matmul()和torch.mul()以及torch.spmm()
目录 1 引言 2 torch.mul(a, b) 3 torch.mm(a, b) 4 torch.matmul() 5 torch.spmm() 参考文献 1 引言 做深度学习过程中免不了使用 ...
- pytorch学习 中 torch.squeeze() 和torch.unsqueeze()的用法
squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...
最新文章
- Python分析离散心率信号(上)
- 细说plsql中的空值表达式
- openstack——horizon篇
- sysfs cannot create duplicate filename问题
- python降序排列说true不存在_Python数据类型串讲(中)
- java的static和private_static关键字什么意思?Java中是否可以覆盖一个private或者是static的方法?...
- 11位大咖带你玩转WebRTC开发(内附PPT资料下载)
- git fatal:HttpRequestException encountered
- idea @Autowired 注入爆红(无法注入)
- C++中内联函数和宏定义的区别
- nginx源码分析:打开监听套接字的流程
- 高通模式9008模式linux,高通芯片如何进入9008模式深度救砖
- IBM server guide 下载地址及列表
- android oem解锁,Android平台OEM解锁分析
- 随机四位数的猜数游戏
- 深入浅出-交接运维工作
- 小C实例也有大梦想——自定义strlen函数
- Python BDD 框架之lettuce
- ensp 移动主机搜索不到AP信道_H3C H5套装评测,AC+AP无缝漫游
- 测试之第四集找bug的专业与素养