概述

最近正在学习智能信息处理课程,接触到了一些有关深度学习pytorch的简单应用,pytorch作为python中最常见的深度学习任务工具应用也非常广泛。

如果小伙伴们对神经网络部分相关理论知识比较熟悉,但不知道代码具体怎么实现,可以参考本篇文章的代码部分,希望能够对大家有所帮助。

也是作为模板供自己和大家参考,主要是怕忘hhh(手动/doge

数据集

本篇代码使用的数据是sklearn中的鸢尾花数据集,训练集由120组鸢尾花的数据特征及其标签组成,另包含相同类型的数据30组用作测试集。

其中,鸢尾花的数据特征以小数形式保存,标签分为3类(0,1,2),代表其属于不同类别。

另附数据集下载地址:
https://download.csdn.net/download/weixin_52456426/86724498

(上图为部分训练集数据截图)

代码

本代码采用torch提供的框架,设定输入层特征数为4(即特征数量),隐藏层(hidden layer)神经元数为10,输出层维数为3(标签数量),激活函数选用sigmoid(当然可以用ReLU或者tanh等,经实际试验,在本次数据集上效果差不多的)
优化器选用Adam参数优化(也可以用SGD随机梯度下降等);由于是经典回归任务,loss损失函数设定为交叉熵损失函数。

设定训练轮数为2000,每10轮进行训练集上的accuracy(准确率)与loss值的计算,最终在测试集上运行我们的模型。

代码部分如下(关键部分给出注释):

# -*- coding: utf-8 -*-
# 2022/9/27
# Author:Jonathan_K_Wolf
import numpy as np
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F# 导入数据
data_train = pd.read_csv('./iris_training.csv')
# 训练集特征向量集合
train_x = np.array(data_train.iloc[:, 0:4])
# 训练集标签向量集合
train_y = np.array(data_train.iloc[:, 4])
# 转化为torch.Tensor形式
x_train = torch.FloatTensor(train_x)
y_train = torch.LongTensor(train_y)# 定义网络结构
class Net(nn.Module):def __init__(self, n_features, n_hidden, n_output):super(Net, self).__init__()self.hidden = nn.Linear(n_features, n_hidden)self.out = nn.Linear(n_hidden, n_output)# 前向传播过程def forward(self, x):x = F.relu(self.hidden(x))x = self.out(x)return x# 实例化网络
model = Net(n_features=4, n_hidden=10, n_output=3)
# 参数优化器,可以选用Adam或SGD
optimizer = torch.optim.Adam(model.parameters(), lr=0.06)
# 损失函数,分类用交叉熵,回归用均方误差MSELoss
loss_func = torch.nn.CrossEntropyLoss()# 训练阶段
for i in range(1, 3001):output = model(x_train)loss = loss_func(output, y_train)optimizer.zero_grad()loss.backward()optimizer.step()if i % 10 == 0:acc = 0prediction = torch.argmax(output, dim=1)pred_y = prediction.numpy()target_y = y_train.numpy()for item in range(len(pred_y)):if pred_y[item] == target_y[item]:acc += 1acc /= len(pred_y)acc = round(acc, 3)print('accuracy at {} epoch:{}, loss at {} epoch:{}'.format(i, acc, i, loss))# 测试阶段
data_test = pd.read_csv('./iris_test.csv')
test_x = np.array(data_test.iloc[:, 0:4])
test_y = np.array(data_test.iloc[:, 4])# 测试集转换成Tensor格式
x_test = torch.FloatTensor(test_x)
y_test = torch.LongTensor(test_y)# 对测试集进行准确率评估
output_prediction = model(x_test)
output_pred = torch.argmax(output_prediction, dim=1)
output_pred = output_pred.numpy()
acc_test = 0
for item in range(len(test_y)):if output_pred[item] == test_y[item]:acc_test += 1
acc_test /= len(test_y)
acc_test = round(acc_test, 3)
print('test of acc:{}'.format(acc_test))

最终效果如下:

【pytorch】简单BP神经网络用于通用分类任务的代码模板相关推荐

  1. Python基于PyTorch实现BP神经网络ANN分类模型项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 在人工神经网络的发展历史上,感知机(Multilayer Per ...

  2. Python基于PyTorch实现BP神经网络ANN回归模型项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 在人工神经网络的发展历史上,感知机(Multilayer Per ...

  3. 【Matlab】基于多层前馈网络BP神经网络实现多分类预测(Excel可直接替换数据)

    [Matlab]基于多层前馈网络BP神经网络实现多分类预测(Excel可直接替换数据) 1.算法简介 1.1 算法原理 1.2 算法流程 2.测试数据集 3.替换数据 4.混淆矩阵 5.对比结果 6. ...

  4. 【Matlab树叶分类】BP神经网络植物叶片分类【含GUI源码 916期】

    一.代码运行视频(哔哩哔哩) [Matlab树叶分类]BP神经网络植物叶片分类[含GUI源码 916期] 二.matlab版本及参考文献 1 matlab版本 2014a 2 参考文献 [1] 蔡利梅 ...

  5. 【树叶分类】基于BP神经网络植物叶片分类Matlab代码

    1 简介 本文以树叶为实验对象,针对传统分类问题耗时长,效率低的不足,提出了一个基于BP神经网络植物智能分类系统.这个计算机辅助分类系统不仅能够帮助提高植物分类的准确率同时也能缩减工作人员的工作量. ...

  6. 利用pytorch完成BP神经网络的搭建

    使用pytorch完成神经网络的搭建 一.搭建一个最简单的BP神经网络 BP神经网络前向传播: h = w 1 x y = w 2 h h=w1x\\ y=w2h h=w1xy=w2h import ...

  7. BP神经网络进阶-MINIST分类

    BP神经网络进阶 前言 在BP神经网络原理探索一文中,只是介绍了简单的回归,并给出简单的回归代码.这次要涉及到BP神经网络的分类问题,以在博客园中上蹿下跳异常活泼的MINIST数据集分类为练手~ MI ...

  8. 基于pytorch的BP神经网络实现

    对于一个神经网络,我们可以根据神经网络结构从头实现,例如一个BP神经网络,我们需要选择损失函数.激活函数,根据公式推导反向传递的梯度,并使用梯度下降更新参数,而卷积神经网络,还要写卷积.池化等函数,同 ...

  9. 基于Pytorch全连接神经网络实现多分类

    (一)计算机视觉工具包的介绍 为了方便开发者应用,PyTorch专门开发了一个视觉工具包torchvision,主要包含以下三个部分: 1.models models提供了深度学习中各种经典的神经网络 ...

最新文章

  1. SQLite简易入门
  2. valorant服务器维护啥情况,valorant连不上服务器怎么办 valorant连不上服务器解决方法介绍...
  3. omnigraffle 的一些总结
  4. Spire.XLS 教程:从C#的Excel形状中提取文本和图像
  5. linux c read函数返回值,Linuxc - GNU Readline 库及编程简介
  6. ubantu-16+ndk-r14b 编译 ffmpeg-4.0.2+lame_mp3-3.99.5
  7. 深入理解ARM体系架构(S3C6410)---认识S3C6410
  8. 避免将属性的可见属性层次结构用作用户定义的层次结构中的级别
  9. service: no such service mysqld 与 MySQL 的开启、关闭和重启
  10. [转]用了docker是否还有必要使用openstack?
  11. Slardar Sql Mapper Framework for Java( Java 持久层框架一枚~)
  12. 基于Verilog-HDL实现会呼吸的流水灯
  13. python属于汇编语言还是高级语言_python语言属于汇编语言吗?_后端开发
  14. WGS84 与 北京54 坐标系互转
  15. 在计算机中NIC是什么意思?
  16. 周总结2022.1.10-2022.1.16
  17. 手持云台 1.前期准备
  18. 分形几何python代码_Python, Cython绘制美妙绝伦的Mandelbrot集, 曼德博集分形图案
  19. 大学生创业知识(转)
  20. 3.qt-图解Weiler-Atherton任意多边形剪裁算法

热门文章

  1. springboot搭建支付宝手机网站支付
  2. [转] JS实例操作QQ空间自动点赞方法
  3. 股票交易日志4 12.16
  4. 牛客网-推理判断练习
  5. QE动力学矩阵文件的主要内容及单位
  6. 如何让网页显示友好的错误信息页面
  7. 利用非qq号码的QQ邮箱来获取qq号
  8. linux系统盘的概念,了解linux系统硬盘分区概念-SELinux入门-linux网卡配置及参数学习_169IT.COM...
  9. 人民币大写在线转换工具
  10. 选购笔记本电脑型号的查询