文章目录

  • 1.感知机代码实现

1.感知机代码实现

# 随机梯度下降
import time
import numpy as np
# 显示进度条
from tqdm import tqdm# mnist数据集将数据集做一个首尾拼接,28*28=784
def loaddata(filename):"""加载mnist数据集:param filename: 要加载的数据集路径:return: list形式的数据集及标签"""print('start to read data')# 存放数据dataArr = []labelArr = []# 打开文件fr = open(filename, 'r')# 将文件进行按行读取for line in tqdm(fr.readlines()):# 将读取后的每一行按切割符切割,并返回字段列表curline = line.strip().split(',')# 存放标记# Mnsit有0-9是个标记,由于是二分类任务,所以将>=5的作为1,<5为-1if int(curline[0]) >= 5:labelArr.append(1)else:labelArr.append(-1)# 归一化-遍历每一行除了第一个元素外的所有元素,并归一化# 列表推导式dataArr.append([int(num) / 255 for num in curline[1:]])# 返回data和labelreturn dataArr, labelArrdef perceptron(dataArr, labelArr, iter=50):"""感知机训练过程:param dataArr: 训练集的数据:param labelArr: 训练集的标签:param iter: 迭代次数:return: 训练好的w和b"""print('strat to train')# 将数据转换成矩阵形式,矩阵可以并行计算,方便运算dataMat = np.mat(dataArr)# 将标签转换成矩阵后,在进行转置# 转置是因为运算中需要单独取label中的某一个元素labelMat = np.mat(labelArr).T# 获取数据矩阵的shape,m*nm, n = np.shape(dataMat)# 创建初始权重w,初始值全为0。w = np.zeros((1, n))# 初始化偏置b为0b = 0# 初始化步长η,即学习率,控制梯度下降速率η = 0.001# 迭代计算for k in range(iter):# 采用随机梯度下降-计算一个样本,就针对该样本做一个梯度下降for i in range(m):# 获取当前样本的向量及标签xi = dataMat[i]yi = labelMat[i]# 判断是否为误分类样本,如果为误分类样本,则进行梯度更新# 误分类样本特征:-1 * yi * (w * xi.T + b) >= 0if -1 * yi * (w * xi.T + b) >= 0:w = w + η * yi * xib = b + η * yi# 打印训练进度print('Round %d : %d training' % (k, iter))return w, bdef model_test(dataArr, labelArr, w, b):"""测试准确率:param dataArr:测试数据集:param labelArr: 测试标签:param w: 权重:param b: 偏置:return: 测试集的准确率"""print('start to test')dataMat = np.mat(dataArr)labelMat = np.mat(labelArr).Tm, n = np.shape(dataMat)errornum = 0for i in range(m):xi = dataMat[i]yi = labelMat[i]result = -yi * (w * xi.T + b)# 如果result >= 0,说明被错误分类if result >= 0:errornum += 1# 正确率 = 1 - 错误率accurRate = 1 - (errornum / m)return accurRateif __name__ == '__main__':# 获取当前时间start = time.time()# 加载数据集trainData, trainLabel = loaddata('data/mnist_train.csv')testData, testLabel = loaddata('data/mnist_test.csv')# 获取训练后的权重w, b = perceptron(trainData, trainLabel)# 进行测试获得准确率accurRate = model_test(testData, testLabel, w, b)# 获取当前时间,作为结束时间end = time.time()# 显示正确率print('accuracy rate is:', accurRate)# 显示用时时长print('time span:', end - start)
start to read data
100%|██████████| 60000/60000 [00:11<00:00, 5247.08it/s]
start to read data
100%|██████████| 10000/10000 [00:01<00:00, 5326.08it/s]
strat to train
Round 0 : 50 training
Round 1 : 50 training
Round 2 : 50 training
Round 3 : 50 training
Round 4 : 50 training
Round 5 : 50 training
Round 6 : 50 training
Round 7 : 50 training
Round 8 : 50 training
Round 9 : 50 training
Round 10 : 50 training
Round 11 : 50 training
Round 12 : 50 training
Round 13 : 50 training
Round 14 : 50 training
Round 15 : 50 training
Round 16 : 50 training
Round 17 : 50 training
Round 18 : 50 training
Round 19 : 50 training
Round 20 : 50 training
Round 21 : 50 training
Round 22 : 50 training
Round 23 : 50 training
Round 24 : 50 training
Round 25 : 50 training
Round 26 : 50 training
Round 27 : 50 training
Round 28 : 50 training
Round 29 : 50 training
Round 30 : 50 training
Round 31 : 50 training
Round 32 : 50 training
Round 33 : 50 training
Round 34 : 50 training
Round 35 : 50 training
Round 36 : 50 training
Round 37 : 50 training
Round 38 : 50 training
Round 39 : 50 training
Round 40 : 50 training
Round 41 : 50 training
Round 42 : 50 training
Round 43 : 50 training
Round 44 : 50 training
Round 45 : 50 training
Round 46 : 50 training
Round 47 : 50 training
Round 48 : 50 training
Round 49 : 50 training
start to test
accuracy rate is: 0.8141
time span: 111.59057807922363

统计学习方法读书笔记5-感知机代码实现相关推荐

  1. 统计学习方法读书笔记(六)-逻辑斯蒂回归与最大熵模型(迭代尺度法(IIS))

    全部笔记的汇总贴:统计学习方法读书笔记汇总贴 逻辑斯谛回归 (logistic regression )是统计学习中的经典分类方法.最大熵是概率模型学习的一个准则,将其推广到分类问题得到最大熵模型(m ...

  2. 统计学习方法读书笔记(九)-EM算法及其推广

    全部笔记的汇总贴:统计学习方法读书笔记汇总贴 EM算法用于含有隐变量(hidden variable)的概率模型参数的极大似然估计,或极大后验概率估计.EM算法的每次迭代由两步组成:E步,求期望(ex ...

  3. 统计学习方法 读书笔记(五)

    读书笔记仅供个人学习使用 本文主要参考书籍为<统计学习方法>(李航)第二版 参考 Sunning_001的博客 决策树 决策树的定义 if-then 的理解 条件概率分布的理解 决策树学习 ...

  4. 统计学习方法读书笔记15-逻辑斯蒂回归习题

    文章目录 1.课后习题 2.视频课后习题 1.课后习题 import numpy as np import time import matplotlib.pyplot as plt from mpl_ ...

  5. 首发:李航老师的《统计学习方法》第二版的代码实现(Github标星过万!)

    李航老师的<统计学习方法>第二版的代码实现更新完毕,本文提供下载.(黄海广) 李航老师编写的<统计学习方法>全面系统地介绍了统计学习的主要方法,特别是监督学习方法,包括感知机. ...

  6. 逻辑斯蒂回归_逻辑斯蒂回归详细解析 | 统计学习方法学习笔记 | 数据分析 | 机器学习...

    本文包括: 重要概念 逻辑斯蒂回归和线性回归 二项逻辑斯谛回归模型 逻辑斯蒂回顾与几率 模型参数估计 多项逻辑斯谛回归 其它有关数据分析,机器学习的文章及社群 1.重要概念: 在正式介绍逻辑斯蒂回归模 ...

  7. 【读书笔记】《代码不朽》

    [<代码不朽>读书笔记] 第一章:简介 "简单"原则: 对每次提交负责. 第二章:函数要短 [要求]<=15行 [做法] 提取方法 替换为方法对象 第三章:逻辑分 ...

  8. 李航《统计学习方法》笔记

    虽然书名是统计学习,但是却是机器学习领域中和重要的一本参考书.当前的机器学习中机器指计算机,但是所运用的方法和知识是基于数据(对象)的统计和概率知识,建立一个模型,从而对未来的数据进行预测和分析(目的 ...

  9. 机器学习:《统计学习方法》笔记(一)—— 隐马尔可夫模型

    参考:<统计学习方法>--李航:隐马尔可夫模型--码农场 摘要 介绍隐马尔可夫模型的基本概念.概率计算.学习方法.预测方法等内容. 正文 1. 基本概念 隐马尔可夫模型是关于时序的模型,描 ...

  10. 统计学习方法 学习笔记(1)统计学习方法及监督学习理论

    统计学习方法及监督学习理论 1.1.统计学习 1.1.1.统计学习的特点 1.1.2.统计学习的对象 1.1.3.统计学习的目的 1.1.4.统计学习的方法 1.1.5.统计学习的研究 1.1.6.统 ...

最新文章

  1. “智慧城市”背后的安全隐患
  2. c语言函数指针的理解与使用(学习)
  3. 用户体验——减少用户的操作!
  4. dbunit java_Java – 让DbUnit使用Hibernate事务
  5. 腾讯 监控系统服务器数据采集,实战低成本服务器搭建千万级数据采集系统
  6. java客户端运行hadoop_JAVA客户端连接部署在docker上的hdfs
  7. 日期范围 java_JavaJoDA时间-实现日期范围迭代器
  8. $(document).ready和window.onload的区别
  9. 智能优化算法:引力搜索算法-附代码
  10. Unity3D 下载与安装
  11. 银行转账和分布式事务(转)
  12. matlab GUI画图实例——手动输入函数画图
  13. 电子信息技术专业名词中英文对照(三)
  14. 高低温试验箱的11点使用注意事项说明
  15. 质因数分解-P1069 [NOIP2009 普及组] 细胞分裂
  16. Zigbee 应用层协议自定义
  17. Latex——论文翻译
  18. 关于ETL过程如何保证数据量的准确性和数据的正确性的讨论
  19. 12306火车余票查询
  20. 按键精灵_字符点阵制作

热门文章

  1. 关于WEB标准的理解
  2. 使用opensl 的BufferQueueAudioPlayer对wav文件的播放
  3. shell脚本基础练习题
  4. iOS学习笔记之正则表达式
  5. 收藏一个好看的单选多选样式
  6. 浅入浅出数据结构(23)——图的概念、存储方式与拓扑排序
  7. linux 编译安装nginx,配置自启动脚本
  8. info - 阅读 info 文档
  9. 数据结构:二维ST表
  10. 软件度量都该度个啥?