文章目录

  • 实验要求
  • 一、生成训练集
  • 二、数据加载器
  • 三、手动构建模型
    • 3.1 logistic回归模型
    • 3.2 损失函数和优化算法
    • 3.3 模型训练
  • 四、训练结果

实验要求

  1. 人工构造训练数据集
  2. 手动实现logistic回归(只借助Tensor和Numpy相关的库)
  3. 从loss以及训练集上的准确率等多个角度对结果进行分析(可借助nn.BCELoss或nn.BCEWithLogitsLoss作为损失函数,从零实现二元交叉熵为选作)

一、生成训练集

# 1、生成训练集...h_k(x)=1/(1+e^(-k^*x)),k为参数,这里设置为[1.3,-1.0]
num_inputs = 2  # 特征数
num_examples = 1000  # 训练数据集样本数
true_k = [1.3, -1.0]
features = tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)
labels = 1 / (1 + torch.exp(-1 * (true_k[0] * features[:, 0] + true_k[1] * features[:, 1])))
# 加入噪声
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
num0 = 0
num1 = 0
for i in range(num_examples):if labels[i] < 0.5:labels[i] = 0num0 += 1else:labels[i] = 1num1 += 1
# print(labels)
labels = labels.view(num_examples, 1)
print(num0, num1)

二、数据加载器

# 2、读取数据
def train_data_loader(batch_size, features, labels):num_exp = len(features)indices = list(range(num_exp))random.shuffle(indices)  # 打乱样本顺序for i in range(0, num_exp, batch_size):j = torch.LongTensor(indices[i:min(i + batch_size, num_exp)])  # 最后一次可能不足batchyield features.index_select(0, j), labels.index_select(0, j)  # 维度,索引# yield关键字:带yield的函数是一个生成器,只有调用函数时函数才会执行# next(函数):每次执行到yield处会return,下一次调用next时从上一次停止的地方继续执行至yield处# 函数.send(num)会将num值送入上一次停止的地方,再执行next

三、手动构建模型

3.1 logistic回归模型

k = torch.normal(0, 1.0, (num_inputs, 1), dtype=torch.float32)  # 训练参数随机初始化
# k = tensor([[0.5], [-0.5]], dtype=torch.float32)
k.requires_grad_(True)def logistic_regression(x, k):return 1 / (1 + torch.exp(-1 * torch.mm(x, k)))

3.2 损失函数和优化算法

这里没有从零开始实现二元交叉熵损失函数,使用了torch.nn.BCELoss函数

# 3.2 损失函数和优化算法
# def bce_loss(y_hat, y):
# return -1 * (y * torch.log10(y_hat) + (1 - y) * torch.log10(1 - y_hat)) 单个样本的损失def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size  # 这里更改param时用的param.data

3.3 模型训练

# 3.3 模型训练
lr = 0.03
num_epochs = 20
batch_size = 10
net = logistic_regression
loss = torch.nn.BCELoss()for epoch in range(num_epochs):  # 一共进行num_epoch个迭代周期# 每个epoch会遍历使用一遍所有的训练集样本for x, y in train_data_loader(batch_size, features, labels):# x,y是一个batch的样本特征和标签y_hat = net(x, k)l = loss(y_hat, y)l.backward()  # 求梯度sgd([k], lr, batch_size)k.grad.data.zero_()  # 梯度清零train_l = loss(net(features, k), labels)  # 计算训练样本的损失print('epoch %d, loss %f' % (epoch + 1, train_l.item()))print('k', k)

四、训练结果

[PyTorch]手动实现logistic回归(只借助Tensor和Numpy相关的库)相关推荐

  1. stata手动绘制logistic回归预测模型校准曲线(Calibration curve)校准曲线(1)

    校准曲线图表示的是预测值和实际值的差距,作为预测模型的重要部分,目前很多函数能绘制校准曲线. 一般分为两种,一种是通过Hosmer-Lemeshow检验,把P值分为10等分,求出每等分的预测值和实际值 ...

  2. 在PyTorch中使用Logistic回归进行10种猴子物种分类

    欢迎关注 "小白玩转Python",发现更多 "有趣" 引言 本文提供了一个使用PyTorch构建一个非常基本的 Logistic模型的简单步骤,并将其应用于猴 ...

  3. 手动绘制logistic回归预测模型校准曲线(Calibration curve)(1)

    校准曲线图表示的是预测值和实际值的差距,作为预测模型的重要部分,目前很多函数能绘制校准曲线. 一般分为两种,一种是通过Hosmer-Lemeshow检验,把P值分为10等分,求出每等分的预测值和实际值 ...

  4. Pytorch手动实现softmax回归

    文章目录 简述 理论基础 回归 softmax 损失函数 读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 预测 整体代码 d2lzh_pytorch ...

  5. 手动以及使用torch.nn实现logistic回归和softmax回归

    其他文章 手动以及使用torch.nn实现logistic回归和softmax回(当前文章) 手动以及使用torch.nn实现前馈神经网络实验 文章目录 任务 一.Pytorch基本操作考察 1.1 ...

  6. 手动绘制R语言Logistic回归模型的外部验证校准曲线(Calibration curve)(2)

    校准曲线图表示的是预测值和实际值的差距,作为预测模型的重要部分,目前很多函数能绘制校准曲线. 一般分为两种,一种是通过Hosmer-Lemeshow检验,把P值分为10等分,求出每等分的预测值和实际值 ...

  7. logistic回归 如何_第七章:利用Python实现Logistic回归分类模型

    免责声明:本文是通过网络收集并结合自身学习等途径合法获取,仅作为学习交流使用,其版权归出版社或者原创作者所有,并不对涉及的版权问题负责.若原创作者或者出版社认为侵权,请联系及时联系,我将立即删除文章, ...

  8. 【机器学习算法】Logistic回归分类算法

    文章目录 一.Logistic回归 1. 分类问题 2. Logistic函数 (1) 阶跃函数(不可导) (2)可导的阶跃函数 二.Logistic回归的算法原理 1. 基本思路 2. 数学解析 ( ...

  9. logistic模型预测人口python_基于logistic回归stats模型的概率预测置信区间

    您可以使用delta method查找预测概率的近似方差.也就是说var(proba) = np.dot(np.dot(gradient.T, cov), gradient) 其中gradient是模 ...

最新文章

  1. RESTful API 设计最佳实践
  2. 关于机器学习中的一些常用方法的补充
  3. Windows7 中配置IIS7的方法(HTTP 错误 404.3 - Not Found)
  4. 【django】配置项目日志【5】
  5. 什么是野指针和内存泄露?如何避免野指针
  6. IAR切BANK--命令连接器文件xcl格式说明
  7. JAVA入门级教学之(猜数字测试)
  8. mysql5.7非源码版msi安装教程
  9. 我是如何在自学编程9个月后找到工作的
  10. HTML hidden 属性
  11. lammps计算的应力的方法
  12. 怎么批量查找关键词-批量查找关键词软件工具
  13. 三星宣布首款 PCIe 5.0 企业级 SSD:PM1743,将于 2022 年推出
  14. 【Ubuntu和本地电脑互传文件】
  15. 13号线ab线规划图_2018广州地铁13号线二期最新消息:计划今年开工 2022年底建成通车(附线路图+站点)...
  16. 原来小米手机电源键除了开关机,还隐藏这3大用法,真是厉害了
  17. STM32F103C8T6的TIM1的CH1、CH2、CH3三路互补PWM实现四路PWM两两输出
  18. 文化袁探索专栏——React Native启动流程
  19. Hadoop入门案例WordCount
  20. 如何在HTML中引用jQuery函数库

热门文章

  1. hdwiki 框架简介
  2. 重复安装GI的时候报错INS-32025
  3. 折纸珠峰c语言程序,c语言折纸超过珠穆拉玛峰
  4. 百度云不限速下载(官方渠道,无风险)
  5. 黑白图片复原为彩色Picture Colorizer(图片着色器)
  6. 误差棒到底是个什么棒?到底棒不棒!
  7. 修改服务器cimc地址,UCSC系列服务器的CIMC设置.PDF
  8. flask框架学习笔记
  9. java怎么绘画坦克_坦克游戏教程一:使用java绘图功能绘制简单坦克
  10. 小程序源码:AI微信小程序源码下载人脸照片AI转换动漫照片全新源码安装简单无需服务器域名-多玩法安装简单