《机器学习与数据挖掘》实验五 编程实现误差逆传播算法(BP算法)
前言:
摘要:本文对机器学习实验五 标准BP算法的代码进行实现,如果不了解的BP算法的话,可以自行上网搜索BP算法的详解.
实验题目:编程实现误差逆传播算法(BP算法)
实验目的:掌握误差逆传播算法(BP算法)的工作流程
实验环境(硬件和软件)Anaconda/Jupyter notebook/Pycharm
实验内容:
编码实现标准BP算法,在西瓜数据集3.0上用这个算法训练一个单隐层网络,并进行测试。
要求:
一、已经给定部分代码,补充完整的代码,需要补充代码的地方已经用红色字体标注,在第(2)部分,包括:
#补充前向传播代码
#补充反向传播代码
#补充参数更新代码
#补充Loss可视化代码
二、将补充完整的第(2)部分的代码提交,并提交实验结果;(也可以自己重写这部分的代码提交)
代码 :
PS:标准BP算法
import pandas as pd import numpy as np from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import StandardScaler import matplotlib.pyplot as pltseed = 2020 import randomnp.random.seed(seed) # Numpy module. random.seed(seed) # Python random module.plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.close('all') # 数据预处理 data = pd.read_csv("PS:这里放入自己数据所在的路径。")def preprocess(data):# 将非数映射数字for title in data.columns:if data[title].dtype == 'object':encoder = LabelEncoder()data[title] = encoder.fit_transform(data[title])# 去均值和方差归一化ss = StandardScaler()X = data.drop('好瓜', axis=1)Y = data['好瓜']X = ss.fit_transform(X)x, y = np.array(X), np.array(Y).reshape(Y.shape[0], 1)return x, y# 定义Sigmoid def sigmoid(x):return 1 / (1 + np.exp(-x))# 求导 def d_sigmoid(x):return x * (1 - x)# 标准BP算法def standard_BP(x, y, dim=10, eta=0.8, max_iter=500):n_samples = 1w1 = np.random.random((x.shape[1], dim))w2 = np.random.random((dim, 1))b1 = np.random.random((n_samples, dim))b2 = np.random.random((n_samples, 1))losslist = []for ite in range(max_iter):loss_per_ite = []for m in range(x.shape[0]):xi, yi = x[m, :], y[m, :]xi, yi = xi.reshape(1, xi.shape[0]), yi.reshape(1, yi.shape[0])##补充前向传播代码u1 = np.dot(xi, w1) + b1out1 = sigmoid(u1)u2 = np.dot(out1, w2) + b2out2 = sigmoid(u2)loss = np.square(yi - out2) / 2loss_per_ite.append(loss)print('iter:%d loss:%.4f' % (ite, loss))##反向传播##补充反向传播代码##补充参数更新代码g = (yi - out2) * d_sigmoid(out2)d_w2 = np.dot(np.transpose(out1), g)d_b2 = -gd_out1 = np.dot(g, np.transpose(w2))e = d_out1 * sigmoid(out1)d_w1 = np.dot(np.transpose(xi), e)d_b1 = -ew1 = w1 + eta * d_w1w2 = w2 + eta * d_w2b1 = b1 + eta * d_b1b2 = b2 + eta * d_b2losslist.append(np.mean(loss_per_ite))##Loss可视化##补充Loss可视化代码plt.figure()plt.plot([i + 1 for i in range(max_iter)], losslist)plt.legend(['standard BP'])plt.xlabel('iteration')plt.ylabel('loss')plt.show()return w1, w2, b1, b2# 测试 def main():data = pd.read_table('watermelon30.txt', delimiter=',')data.drop('编号', axis=1, inplace=True)x, y = preprocess(data)dim = 10# _,_,_,_ = standard_BP(x,y,dim)w1, w2, b1, b2 = standard_BP(x, y, dim)u1 = np.dot(x, w1) + b1out1 = sigmoid(u1)u2 = np.dot(out1, w2) + b2out2 = sigmoid(u2)y_pred = np.round(out2)result = pd.DataFrame(np.hstack((y, y_pred)), columns=['真值', '预测'])result.to_excel('result.xlsx', index=False)# 补充测试代码,根据当前的x,预测其类别; if __name__ == '__main__':main()
总结:
标准BP的算法实现的核心在于对的向前,向后,参数的更新算法的理解。
PS:BP算法中有标准BP算法与累积BP算法的区别,这里给出的是标准的BP算法的一个小示例,如果读者对BP算法有兴趣的话,建议您在网上查阅资料或者查阅书籍。
《机器学习与数据挖掘》实验五 编程实现误差逆传播算法(BP算法)相关推荐
- BP算法误差逆传播参数更新公式推导
BP算法误差逆传播参数更新公式推导
- 广州大学机器学习与数据挖掘实验三
实验三 聚类分析 一. 实验目的 本实验课程是计算机.人工智能.软件工程等专业学生的一门专业课程,通过实验,帮助学生更好地掌握数据挖掘与机器学习相关概念.技术.原理.应用等:通过实验提高学生编写实验报 ...
- 人工智能知识全面讲解:多层神经网络与误差逆传播算法
7.3.1 从单层到多层神经网络 明斯基教授曾表示,单层神经网络无法解决异或问题,但是当增加一个计 算层以后,两层神经网络不仅可以解决异或问题,而且具有非常好的非线性分 类效果.只是两层神经网络的计算 ...
- 广州大学机器学习与数据挖掘实验二
实验二 逻辑回归与朴素贝叶斯分类 一. 实验目的 本实验课程是计算机.人工智能.软件工程等专业学生的一门专业课程,通过实验,帮助学生更好地掌握数据挖掘与机器学习相关概念.技术.原理.应用等:通过实验提 ...
- 河北工业大学数据挖掘实验五 k-means聚类算法
k-均值聚类算法 一.实验目的 二.实验原理 1.k-均值聚类 2.终止条件 三.实验内容和步骤 1.实验内容 2.实验步骤 3.程序框图 4.实验样本 5.实验代码 四.实验结果 五.实验分析 一. ...
- linux实验五编程淮海工学院,实验一-LinuxC编程工具GCC和GDB.doc
实验一-LinuxC编程工具GCC和GDB 淮海工学院计算机工程学院实验报告书 课程名: <Linux程序设计> 题 目: 实验一 Linux C编程工具:GCC和GDB 班 级: 软件1 ...
- 山东大学软件学院数据挖掘实验五(2)的坑
一. 实验目的 掌握数据导入Hive表的方式 理解三种数据导入Hive表的原理 二. 实验内容 1.启动Hadoop和Hive服务并创建数据表 2.将Hive表中的数据导出 三. 实验 ...
- 机器学习笔记(十五)——HMM序列问题和维特比算法
一.引言 这篇blog主要讲序列问题和其解法--维特比算法. 二.HMM中的第二个基本问题 序列问题:给定一个观察序列O=O1O2-OTO=O_1O_2\dots O_T和模型u=(A,B,π)u=( ...
- 误差逆传播算法公式理解及推导
前言:公式理解及推导参考自<机器学习>周志华 P101 BP网络 BP网络一般是指由 误差逆传播(error BackPropagation, BP)算法训练的多层前馈神经网络. 给定训练 ...
最新文章
- 二十岁出头的时候上,你一无所有,你拥有一切
- Ubuntu14.04下切换系统自带的Python和Anaconda 下的Python
- Openresty Redis正确使用连接池(set_keepalive)
- js复制功能的有效方法总结新
- 易中天与单田芳的区别在哪儿
- 通达OA 新旧两种数据库连接方式
- 随机信号的傅里叶分析
- python打乱list_超实用!每 30 秒学会一个 Python 小技巧,GitHub 标星 5300!
- NetCore2.0Web应用之Startup
- selenium+log4j+eclipse相关问题及解决方案
- Date类的getYear(),getMonth过时,现在的获取方法
- copy constructor和copy assignment operator的区别
- Scala学习数组/映射/元组
- python软件安装链接电视_Python爬虫程序:电视剧琅琊榜全集的自动化处理
- Appium架构介绍与环境安装
- 解决scalac Error: bad option -make:transitive
- 30个python的最佳实践,快去试试吧!
- PCB过孔的孔径大小对通流的影响
- linux 笔记本合盖不休眠设置
- 叮咚! 你有一份节日祝福请查收~
热门文章
- Hadoop 2.X的安装与配置
- windows配置好用的RSS
- CAD2017打开图纸点字体替换时没有字体选择框的问题
- ios微信组件跳转_iOSAPP跳转微信小程序
- win10电脑桌面透明便签_DesktopNoteOK桌面便签小工具下载|windows10桌面透明便签插件_最火软件站...
- arduino 源码分层浅析
- 微信小程序跳转微信小程序实现免登录
- 解决qrcode生成的二维码安卓手机长按不识别问题
- 微服务架构深度解析与最佳实践 - 第七部分:全文总结与引用材料
- 屏的像素与传输速率_HDMI线的传输速率是如何定义的