手写bpnn算法实现iris多分类

东北大学工业智能专业机器学习实验的一个作业,手写bpnn,我实在是写不出来,在网上参考了github大神的代码,再自己修改一下得到了要的代码。
原代码地址:https://github.com/RiptideBo/simpy/blob/master/deep_learning/models/bpnn.py

import numpy as np
import matplotlib.pyplot as pltdef sigmoid(x):return 1 / (1 + np.exp(-1 * x))class layerbuild:   #建造层def __init__(self, units, activation=None, learning_rate=None, is_input_layer=False):""":param units: 每层的神经元数量,第一层必须和x的特征数相同,最后一层要和进行onehot编码后的y的特征数相同"""self.units = unitsself.weight = Noneself.bias = Noneself.activation = activationif learning_rate is None:learning_rate = 0.3self.learn_rate = learning_rateself.is_input_layer = is_input_layerdef initializer(self, back_units):self.weight = np.asmatrix(np.random.normal(0, 0.5, (self.units, back_units)))self.bias = np.asmatrix(np.random.normal(0, 0.5, self.units)).T#随机化权重和偏置if self.activation is None:self.activation = sigmoiddef cal_gradient(self):if self.activation == sigmoid:gradient_mat = np.dot(self.output, (1 - self.output).T)gradient_activation = np.diag(np.diag(gradient_mat))else:gradient_activation = 1return gradient_activationdef forward_propagation(self, xdata):self.xdata = xdataif self.is_input_layer:self.wx_plus_b = xdataself.output = xdatareturn xdataelse:self.wx_plus_b = np.dot(self.weight, self.xdata) - self.biasself.output = self.activation(self.wx_plus_b)return self.outputdef backpropagation(self, gradient):gradient_activation = self.cal_gradient()  # i * i 维gradient = np.asmatrix(np.dot(gradient.T, gradient_activation))self._gradient_weight = np.asmatrix(self.xdata)self._gradient_bias = -1self._gradient_x = self.weightself.gradient_weight = np.dot(gradient.T, self._gradient_weight.T)self.gradient_bias = gradient * self._gradient_biasself.gradient = np.dot(gradient, self._gradient_x).Tself.weight = self.weight - self.learn_rate * self.gradient_weightself.bias = self.bias - self.learn_rate * self.gradient_bias.T# 更新权重和偏置return self.gradientclass BPNN:def __init__(self):self.layers = []self.train_mse = []def add_layer(self, layer):self.layers.append(layer)def build(self):for i, layer in enumerate(self.layers[:]):if i < 1:layer.is_input_layer = Trueelse:layer.initializer(self.layers[i - 1].units)def train(self, xdata, ydata, max_train_round, accuracy):self.max_train_round = max_train_roundself.accuracy = accuracyx_shape = np.shape(xdata)k=0#k用来记录训练轮数for round_i in range(max_train_round):all_loss = 0k+=1for row in range(x_shape[0]):_xdata = np.asmatrix(xdata[row, :]).T_ydata = np.asmatrix(ydata[row, :]).Tfor layer in self.layers:_xdata = layer.forward_propagation(_xdata)loss, gradient = self.cal_loss(_ydata, _xdata)all_loss = all_loss + lossfor layer in self.layers[:0:-1]:gradient = layer.backpropagation(gradient)mse = all_loss / x_shape[0]self.train_mse.append(mse)if mse < self.accuracy:print("训练集上达到预设精度,所用训练轮数为:{}轮".format(k))#print(self.train_mse)xx=np.arange(1,k+1)xx=xx.reshape(-1,1)yy=self.train_mseyy=np.array(self.train_mse)yy=yy.reshape(-1,1)plt.plot(xx,yy)plt.xlabel("train_round")plt.ylabel("error")return print("分类准确率为",1-mse)if k==self.max_train_round:print("训练集上未达到预设精度,所用训练轮数为:{}轮".format(k))xx=np.arange(1,k+1)xx=xx.reshape(-1,1)yy=np.array(self.train_mse)yy=yy.reshape(-1,1)plt.plot(xx,yy)plt.xlabel("train_round")plt.ylabel("error")plt.title("loss visualization")return print("分类准确率为",1-mse)def cal_loss(self, ydata, ydata_):self.loss = -((1-ydata.T)*np.log(1-ydata_) + ydata.T * np.log(ydata_))self.loss_gradient = 2 * (ydata_ - ydata)return self.loss, self.loss_gradientdef predict(self,xdata):y_pred=[]x_shape = np.shape(xdata)for row in range(x_shape[0]):_xdata = np.asmatrix(xdata[row, :]).Tfor layer in self.layers:_xdata = layer.forward_propagation(_xdata)#print(_xdata)_xdata=list(_xdata)y_pred.append(_xdata.index(max(_xdata)))y_pred=np.array(y_pred)return y_predimport numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris_1 = load_iris()
#加载鸢尾花数据集
x = iris_1['data']
y = iris_1['target']
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=2)
import pandas as pd
y_train=y_train.reshape(-1,1)y_train=pd.DataFrame(y_train)
y_train.columns=['type']
y_train1 = pd.get_dummies(y_train.type, prefix='type')
y_train=np.array(y_train1)#这四行代码是将y转化为onehot编码
#使用方法:先实例化,然后构建层
model = BPNN()
for i in (4, 10, 20, 3):model.add_layer(layerbuild(i))
#这里是构建层,注意,第一层的数字必须和x的特征数一样,最后一层是输出,和y的特征数一样。
#x有四个特征,所以第一个是4,y有三类,所以最后的数字是3
#中间两个是隐层,数字任选,你可以调参观察中间的数字变化对于训练轮数有何影响。
model.build()
model.train(xdata=x_train, ydata=y_train, max_train_round=500, accuracy=0.05)
print("测试集上的准确度为:",accuracy_score(model.predict(x_test),y_test))

这个可以实现多分类预测,如果你要预测别的数据集的话就把加载数据的部分改一下,然后记得把y转化为二维向量,就是必须要用reshape,再把y用onehot编码转换一下,这个看代码照葫芦画瓢应该能会吧。
最后,如果你也是工业智能的学生恰好刷到了这篇文章,要参考这里的代码的话,记得改一下变量名hhh

手写bpnn算法实现iris多分类相关推荐

  1. TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类

    TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类 目录 设计思路 实现代码 设计思路 更新-- 实现代码 # -*- coding:utf-8 -*- import ten ...

  2. 算法------手写LRU算法

    算法------手写LRU算法 LRU是Redis中常用的内存淘汰算法. 意思是:当缓存容量满的时候,淘汰最近很少使用的数据. 具体实现逻辑: 把缓存放到双向链表当中,最近使用过.或新放入的缓存,放到 ...

  3. 用python手写KNN算法+kd树及其BBF优化(原理与实现)(下篇)

    用python手写KNN算法+kd树及其BBF优化(原理与实现)(下篇) 接上一篇用python手写KNN算法+kd树及其BBF优化(原理与实现)(上篇) 我们使用training2和test2两个数 ...

  4. 如何写一个简单的手写识别算法?

     可以精准快速的识别出自定义的简单图形. 类似于下面这种? Magic Touch - A Free Game by Nitrome Magic Touch: Wizard for Hire on ...

  5. Python手写线性回归算法

    作者 | 苏南下 来源 | 机器会学习ML(ID:AI_Learning007) 摘要:通俗易懂介绍线性回归算法,并 Python 手写实现. 之前我们介绍了:kNN 算法,主要用于解决分类问题,也可 ...

  6. python手写kmeans算法

    kmean聚类是最基础和常见的算法,工程上使用比较常见,spark, sklearn都有实现,本文手写实现kmeans #!/usr/bin/python import sys import rand ...

  7. Keras【Deep Learning With Python】keras框架下的MNIST数据集训练及自己手写数字照片的识别(分类神经网络)

    文章目录 前言 mnist_model.py predict.py 前言 深度学习领域的"hello,world"可能就是这个超级出名的MNIST手写数字数据集的训练(想多了,要是 ...

  8. 机器学习 手写KNN算法预测城市空气质量

    文章目录 一.KNN算法简介 二.KNN算法实现思路 三.KNN算法预测城市空气质量 1. 获取数据 2. 生成测试集和训练集 3. 实现KNN算法 一.KNN算法简介 KNN(K-Nearest N ...

  9. 使用python手写FFT算法

    FFT(Fast Fourier Transform) 是 DFT(Discrete Fourier Transform)的快读实现,它在机理上没有改变DFT的算法,只是在实现上采用的巧妙的实现. 使 ...

  10. 我手写了个SLAM算法!

    1.前言 前一段时间看过我文章的都知道,我打算写一个SLAM源码阅读的文章,然后,我就去读了Gmapping的源码,感受良多,不足的地方是源码太乱了,阅读起来真的不香.于是就有了这篇文章,在我仔细阅读 ...

最新文章

  1. 零基础入门学习Python(5)Python的数据类型
  2. 字节跳动面试官问我看过哪些源码,然后就没有然后了
  3. 服务器 不支持gbk,解决JS请求服务器gbk文件乱码的问题
  4. Linux下ctrl+c,ctrl+z,ctrl+d的区别
  5. android textview获取背景颜色,Android TextView背景颜色与背景图片设置
  6. Workflow之Activity
  7. API接口设计之RESTful软件架构风格
  8. MegaCli常见命令
  9. Androrid Studio Debug Warning:debug info can be unavailable
  10. Linux之常用操作命令总结一
  11. 三层链路冗余-单宿主网络(拓扑图及思路)
  12. 缓解眼睛疲劳,护眼调节色温软件推荐
  13. javascript event bubbling and capturing (再谈一谈js的事件冒泡和事件补获,看到这篇文章加深了理解)...
  14. 腾讯云+CentOS 7.2+python:搭建微信公众号后台入门教程
  15. eclipse报错 错误: 找不到或无法加载主类
  16. 倍福--授权文件拷贝
  17. 属牛人性格特点及脾气如何呢?
  18. 三星typec转接耳机没反应_typec转3.5mm转接线,你买对了吗?
  19. NAACL2022-Prompt相关论文对Prompt的看法
  20. FFmpeg 基础库(一)视频格式

热门文章

  1. 智能网联汽车仿真测试软件,智能网联汽车测试评价及检测认证
  2. 普通进销存管理系统设计2
  3. 项目管理十大知识领域之项目整合管理
  4. time+dd测试硬盘读写速度
  5. PFC_颗粒流软件_喷射混凝土模拟
  6. 【060】助力一箭四星,翼辉系统再续辉煌
  7. Pandas库的基本使用方法
  8. 怎么彻底卸载cad2017_彻底卸载cad2010的方法步骤
  9. AT89C2051烧写器的制做与调试
  10. 粒子群算法(PSO) C