之前的逻辑回归文章:从 0 开始机器学习 - 逻辑回归原理与实战!跟大家分享了逻辑回归的基础知识和分类一个简单数据集的方法。

今天登龙再跟大家分享下如何使用逻辑回归来分类手写的 [0 - 9] 这 10 个字符,数据集如下:

下面我就带着大家一步一步写出关键代码,完整的代码在我的 Github 仓库中:

https://github.com/DLonng/AI-Notes/tree/master/MachineLearning/ex3-neural-network/logistic_reg

一、加载手写字符数据

1.1 读取数据集

raw_X, raw_y = load_data('ex3data1.mat')# 5000 x 400
print(raw_X.shape)# 5000
print(raw_y.shape)

数据集有 5000 个样本,每个样本是一个 20 x 20  = 400 像素的手写字符图像:

这个识别手写字符问题属于有监督学习,所以我们有训练集的真实标签 y,维度是 5000,表示训练集中 5000 个样本的真实数字:

1.2 添加全 1 向量

老规矩,在训练样本的第一列前添加一列全 1 的向量(为了与 相乘进行向量化表示):

# 添加第一列全 1 向量
X = np.insert(raw_X, 0, values = np.ones(raw_X.shape[0]), axis = 1)
# 5000 行 401 列
X.shape

添加一列后,样本变为 5000 行 401 列

1.3 向量化标签

把原标签(5000 行 1 列)变为(5000 行 10 列),相当于把每个真实标签用 10 个位置的向量替换:

# 把原标签中的每一类用一个行向量表示
y_matrix = []# k = 1 ... 10
# 当 raw_y == k 时把对应位置的值设置为 1,否则为 0
for k in range(1, 11):y_matrix.append((raw_y == k).astype(int))

改变后向量标签的每一行代表一个标签,只不过用 10 个位置来表示,比如数字 1 对应第一个位置为 1,数字 2 对应第二个位置为 1 ,以此类推,不过注意数字 0 对应第 10 个位置为 1

而每一列代表原始标签值中所有相同的字符,比如第一列表示所有数字 1 的真实标签值,第二列表示所有数字 2 的真实标签值,以此类推,第 10 列表示数字 0 的真实值:

因为我们加载的是 .mat 类型的 Matlab 数据文件,而 Matlab 中索引是从 1 开始的,因此原数据集中用第 10 列表示数字 0。

但是为了方便 Python 处理,我们这里把第 10 列表示的数字 0 移动到第一列,使得列数按照数字顺序 [0 - 9] 排列:

# 因为 Matlab 下标从 1 开始,所以 raw_y 中用 10 表示标签 0
# 这里把标签 0 的行向量移动到第一行
y_matrix = [y_matrix[-1]] + y_matrix[:-1]

这是原实验配的图片,原理是一样的,可以对比理解下(这里没有移动第 10 列哦):

为何要这样做呢?主要是为了完成后面一次预测多个数字的任务。

二、训练模型

逻辑回归和正则化的原理之前都讲过了,没看过的同学可以复习下:

  • 从 0 开始机器学习 - 逻辑回归原理与实战!

  • 从 0 开始机器学习 - 正则化技术原理与编程!

这里我就直接放关键的函数,然后稍加解释下。

2.1 逻辑回归假设函数

假设函数使用常用的 sigmoid 函数:

def sigmoid(z):return 1 / (1 + np.exp(-z))

2.2 逻辑回归代价函数

def cost(theta, X, y):return np.mean(-y * np.log(sigmoid(X @ theta)) - (1 - y) * np.log(1 - sigmoid(X @ theta)))

2.3 逻辑回归正则化代价函数

def regularized_cost(theta, X, y, l=1):theta_j1_to_n = theta[1:]# 正则化代价regularized_term = (l / (2 * len(X))) * np.power(theta_j1_to_n, 2).sum()return cost(theta, X, y) + regularized_term

2.4 梯度计算

def gradient(theta, X, y):return (1 / len(X)) * X.T @ (sigmoid(X @ theta) - y)

2.5 正则化梯度

在原梯度后面加上正则化梯度即可:

def regularized_gradient(theta, X, y, l=1):theta_j1_to_n = theta[1:]# 正则化梯度regularized_theta = (l / len(X)) * theta_j1_to_n# 不对 theta_0 正则化regularized_term = np.concatenate([np.array([0]), regularized_theta])return gradient(theta, X, y) + regularized_term

2.6 逻辑回归训练函数

使用 scipy.optimize 来优化:

"""逻辑回归函数args:X: 特征矩阵, (m, n + 1),第一列为全 1 向量y: 标签矩阵, (m, )l: 正则化系数return: 训练的参数向量
"""
def logistic_regression(X, y, l = 1):# 保存训练的参数向量,维度为特征矩阵的列数,即特征数 + 1theta = np.zeros(X.shape[1])# 使用正则化代价和梯度训练res = opt.minimize(fun = regularized_cost,x0 = theta,args = (X, y, l),method = 'TNC',jac = regularized_gradient,options = {'disp': True})# 得到最终训练参数final_theta = res.xreturn final_theta

三、训练模型

我们先来训练模型,使得能识别单个数字 0,y[0] (5000 行 1 列)代表所有真实标签值为 0 的样本,参考之前讲的向量化标签:

theta_0 = logistic_regression(X, y[0])

预测的结果 theta_0(401 行 1 列) 是手写字符 0 对应的参数向量。

四、预测训练集数字 0

我们用训练的 theta_0 参数来预测下训练集中所有的字符图像为 0 的准确度:

def predict(x, theta):prob = sigmoid(x @ theta)return (prob >= 0.5).astype(int)
# 字符 0 的预测值,也是 5000 行 1 列
y_pred = predict(X, theta_0)

y_pred 是 5000 行 1 列的向量,元素只有 0 和 1,1 表示样本预测值为数字 0,0 表示预测值不是数字 0。

我们再把预测值和真实值进行比较,计算下误差的平均值作为输出精度:

# 打印预测数字 1 的精度
print('Accuracy = {}'.format(np.mean(y[0] == y_pred)))Accuracy = 0.9974

显示该模型识别训练集中手写数字 0 的图像正确率约为 99.74%!这只是分类一个数字,下面再来一次把 10 个数字都进行分类。

五、分类 10 个数字

上面只训练并预测了一个字符 0,我们可以使用 for 循环来训练全部的 10 个字符,每个字符的训练方法都和上面单个数字相同:

# 训练 0 - 9 这 10 个类别的 theta_[0 -> 9] 参数向量
theta_k = np.array([logistic_regression(X, y[k]) for k in range(10)])

theta_k 是 10 个数字对应的参数向量(每行代表一个参数向量):

# 10 行 401 列
print(theta_k.shape)

对特征矩阵进行预测,注意这里对 theta_k 进行了转置,是为了进行矩阵的乘法运算:

# X(5000, 401), theta_k.T(401, 10)
prob_matrix = sigmoid(X @ theta_k.T)# prob_matrix(5000, 10)
prob_matrix

打印下预测的矩阵(5000 行 10 列):

将每行中的最大一列的索引放入 y_pred 中,用来表示预测的数字:

y_pred = np.argmax(prob_matrix, axis = 1)# (5000, 1)
print(y_pred.shape)y_pred

此时的 y_pred 变为 5000 行 1 列,每一行就是模型预测的数字识别结果,再把真实标签中的 10 替换为 0:

# 用 0 代替 10
y_answer[y_answer == 10] = 0

打印出训练集中每个手写数字的预测精度:

print(classification_report(y_answer, y_pred))

可以看到每个字符在训练集上的预测效果都能达到 90% 以上,说明这个模型在训练集上的预测效果比较不错。

OK,今天就分享这些,希望大家多多实践!完整可运行代码链接:

https://github.com/DLonng/AI-Notes/tree/master/MachineLearning/ex3-neural-network/logistic_reg

学会了记得回来给我一个 Star 哦 (^▽^)!


投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

【从 0 开始机器学习】逻辑回归识别手写字符!相关推荐

  1. [转载] Pytorch入门实战-----逻辑回归识别手写数据集

    参考链接: 在PyTorch中使用Logistic逻辑回归识别手写数字 定义的网络比较简单,可以自行修改,调一下参数,识别率就会上去了. import torch import torch.nn as ...

  2. python识别手写文字_如何快速使用Python神经网络识别手写字符?(文末福利)

    原标题:如何快速使用Python神经网络识别手写字符?(文末福利) 点击标题下[异步社区]可快速关注 在本文中,我们将进一步探讨一些使用Python神经网络识别手写字符非常有趣的想法.如果只是想了解神 ...

  3. 如何识别手写文字python_如何快速使用Python神经网络识别手写字符?(文末福利)...

    ​点击标题下[异步社区]可快速关注 在本文中,我们将进一步探讨一些使用Python神经网络识别手写字符非常有趣的想法.如果只是想了解神经网络的基本知识,那不必阅读本文,可以先阅读<Python神 ...

  4. Python 神经网络是这样识别手写字符哒?

    当谷歌的 AlphaGo 战胜了人类顶级棋手,人工智能开始更多进入大众视野.而谷歌 AI 教父认为:"AlphaGo 有直觉神经网络已接近大脑". 千百年来,人类试图了解智能的机制 ...

  5. Python神经网络是这样识别手写字符哒?

    当谷歌的AlphaGo战胜了人类顶级棋手,人工智能开始更多进入大众视野.而谷歌AI教父认为:"AlphaGo有直觉神经网络已接近大脑". 千百年来,人类试图了解智能的机制,并将它复 ...

  6. 神经网络python识别词语_Python 神经网络是这样识别手写字符哒?

    当谷歌的 AlphaGo 战胜了人类顶级棋手,人工智能开始更多进入大众视野.而谷歌 AI 教父认为:"AlphaGo 有直觉神经网络已接近大脑". 千百年来,人类试图了解智能的机制 ...

  7. python数据分析实战案例logistic_Python机器学习随笔之logistic回归识别手写数字

    编者注:本文用logistic回归来识别多分类问题的手写数字,是之前logisitic回归二分类问题的延续,该篇文章关于其思想以及编程原理见本人之前文章,在这里只注重识别及其编程过程. 01数据准备 ...

  8. python手写字母识别_机器学习--kNN算法识别手写字母

    本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...

  9. 识别手写字体app_我如何构建手写识别器并将其运送到App Store

    识别手写字体app 从构建卷积神经网络到将OCR部署到iOS (From constructing a Convolutional Neural Network to deploying an OCR ...

  10. 吴恩达机器学习 逻辑回归 作业3(手写数字分类) Python实现 代码详细解释

    整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python 里面可以下载进行代码实现的数据集 题目介绍: In t ...

最新文章

  1. Docker Basic
  2. sql语句的编程手册(2)
  3. ListView相关
  4. CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别
  5. 快速幂(Fast_Power)
  6. 微服务调用组件Feign:简介以及搭建环境
  7. 现在就启用 HTTPS,免费的!
  8. cad pu插件下载lisp_CAD自动编号lisp插件下载
  9. Java回顾之Spring基础
  10. Java — 【报错】Parameter index out of range (1 number of parameters, which is 0).
  11. 我的面试标准:能干活、基础要好、有潜力!
  12. 带通 带阻滤波器 幅频响应_方程推导:二阶有源带通滤波器设计!(内附教程+原理图+视频+代码下载)...
  13. 如何实现网页的自动登录
  14. arm模拟器手机版_ARM模拟器——SkyEye的使用
  15. 一文了解什么是嵌入式?
  16. JavaWeb-Cookie、Session
  17. Excel表格密码保护解除
  18. 生产制造企业仓库管理不到位?ERP系统帮你解决
  19. Java 生成随机汉字名称
  20. 阿里巴巴中报绩优 要帮中小企业产业升级

热门文章

  1. 20172329 2017-2018-2《程序设计与数据结构》课程总结
  2. WPF Invoke与BeginInvoke的区别
  3. bootstrap3-iframe-modal子页面在父页面显示模态框
  4. Oracle odi 数据表导出到文件
  5. android button text属性中英文大小写问题
  6. DB - 常用SQL积累
  7. 【linux】make出现遗漏分隔符
  8. Java程序向MySql数据库中插入的中文数据变成了问号
  9. 小程序navigator点击有时候会闪一下
  10. 我的2017:从工作再到学生