多分类问题——识别手写体数字0-9

一.逻辑回归解决多分类问题

1.图片像素为20*20,X的属性数目为400,输出层神经元个数为10,分别代表1-10(把0映射为10)。

通过以下代码先形式化展示数据 ex3data1.mat内容:

load('ex3data1.mat'); % training data stored in arrays X, y
m = size(X, 1); %求出样本总数
% Randomly select 100 data points to display
rand_indices = randperm(m); %函数功能随机打乱这m个数字,输出给rand_indices.
sel = X(rand_indices(1:100), :); %按照打乱后的数列取出100个数字,作为X矩阵的行数。displayData(sel); %通过本函数将选出的X矩阵中100个样本进行图形化

函数displayData()实现解析如下:

function [h, display_array] = displayData(X, example_width)
%DISPLAYDATA Display 2D data in a nice grid

if ~exist('example_width', 'var') || isempty(example_width) example_width = round(sqrt(size(X, 2)));   %四舍五入求出图片的宽度
end

colormap(gray); %将图片定义为灰色系

[m n] = size(X);
example_height = (n / example_width); %求出图片的高度% Compute number of items to display
display_rows = floor(sqrt(m));  %计算出每行每列展示多少个数字图片
display_cols = ceil(m / display_rows);

pad = 1; %图片之间间隔% Setup blank display 创建要展示的图片像素大小,空像素,数字图片之间有1像素间隔
display_array = - ones(pad + display_rows * (example_height + pad), ...pad + display_cols * (example_width + pad));% Copy each example into a patch on the display array  将像素点填充进去
curr_ex = 1;
for j = 1:display_rowsfor i = 1:display_colsif curr_ex > m, break; end% Get the max value of the patchmax_val = max(abs(X(curr_ex, :)));display_array(pad + (j - 1) * (example_height + pad) + (1:example_height), ...pad + (i - 1) * (example_width + pad) + (1:example_width)) = ...reshape(X(curr_ex, :), example_height, example_width) / max_val; %reshape函数进行矩阵维数转换curr_ex = curr_ex + 1;endif curr_ex > m, break; end
end

h = imagesc(display_array, [-1 1]); %将像素点画为图片
axis image off %不显示坐标轴
drawnow; %刷新屏幕
end

2.向量化逻辑回归

向量化代价函数和梯度下降,代码同第三周编程练习相同:http://www.cnblogs.com/LoganGo/p/9009767.html

核心代码如下:

function [J, grad] = lrCostFunction(theta, X, y, lambda)

m = length(y); % number of training examples

J = 0;
grad = zeros(size(theta));%分别计算代价值J和梯度grad
J=1/m*(-(y')*log(sigmoid(X*theta))-(1-y)'*log(1-sigmoid(X*theta)))+lambda/(2*m)*(theta'*theta-theta(1)^2);
%grad = 1/m*X'*(sigmoid(X*theta)-y)+lambda*theta/m;
%grad(1) = grad(1)-lambda*theta(1)/m;
grad=1/m*X'*(sigmoid(X*theta)-y)+lambda/m*([0;theta(2:end)]);
grad = grad(:);

end

3.逻辑回归解决多分类问题

oneVsAll.m函数解析:通过阅读原文中所给的英文解析,足够完成本函数的编写

function [all_theta] = oneVsAll(X, y, num_labels, lambda)

m = size(X, 1);
n = size(X, 2);

all_theta = zeros(num_labels, n + 1); %为训练1-10个便签,所以需要矩阵为10*n+1

X = [ones(m, 1) X];
%运用了fmincg()函数求参数,与函数fminunc()相比,处理属性过多时更高效!
options = optimset('GradObj', 'on', 'MaxIter', 50);
for c=1:num_labels,all_theta(c,:)=fmincg(@(t)(lrCostFunction(t, X, (y==c), lambda)), all_theta(c,:)', options)';
endend

预测函数predictOneVsAll()函数编写:

function p = predictOneVsAll(all_theta, X)m = size(X, 1);
num_labels = size(all_theta, 1);

p = zeros(size(X, 1), 1);
X = [ones(m, 1) X];
index=0;
pre=zeros(num_labels,1); %存储每个样本对应数字1-10的预测值for c=1:m,for d=1:num_labels,pre(d)=sigmoid(X(c,:)*(all_theta(d,:)'));
  end[maxnum index]=max(pre);p(c)=index;   %找到该样本最大的预测值所对应的数字,作为实际预测值
end
end

二.神经网络解决多分类问题

使用已经训练好的参数θ1θ2来做预测,predict.m如下:

function p = predict(Theta1, Theta2, X)

m = size(X, 1);
num_labels = size(Theta2, 1);
X=[ones(m,1) X]; %为a1添加为1的偏置

p = zeros(size(X, 1), 1);

for i=1:m, %分别对m个样本做预测a2=sigmoid(Theta1*X(i,:)'); %计算a2a2=[1;a2];                  %为a2添加为1的偏置a3=sigmoid(Theta2*a2);      %计算a3[manum index]=max(a3);      %求出哪个数字的预测值最大p(i)=index;                 %得出预测值
end
end

转载于:https://www.cnblogs.com/LoganGo/p/9057793.html

Coursera-AndrewNg(吴恩达)机器学习笔记——第四周编程作业(多分类与神经网络)...相关推荐

  1. python第六周实验_机器学习 | 吴恩达机器学习第六周编程作业(Python版)

    实验指导书    下载密码:ovyt 本篇博客主要讲解,吴恩达机器学习第六周的编程作业,作业内容主要是实现一个正则化的线性回归算法,涉及本周讲的模型选择问题,绘制学习曲线判断高偏差/高方差问题.原始实 ...

  2. 机器学习 | 吴恩达机器学习第六周编程作业(Python版)

    实验指导书    下载密码:ovyt 本篇博客主要讲解,吴恩达机器学习第六周的编程作业,作业内容主要是实现一个正则化的线性回归算法,涉及本周讲的模型选择问题,绘制学习曲线判断高偏差/高方差问题.原始实 ...

  3. [吴恩达机器学习笔记]12支持向量机3SVM大间距分类的数学解释

    12.支持向量机 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考资料 斯坦福大学 2014 机器学习教程中文笔记 by 黄海广 12.3 大间距分类背后的数学原理- Mathematic ...

  4. 吴恩达机器学习笔记55-异常检测算法的特征选择(Choosing What Features to Use of Anomaly Detection)

    吴恩达机器学习笔记55-异常检测算法的特征选择(Choosing What Features to Use of Anomaly Detection) 对于异常检测算法,使用特征是至关重要的,下面谈谈 ...

  5. 吴恩达机器学习笔记:(四)矩阵、多元梯度下降

    吴恩达机器学习笔记 矩阵基础知识 矩阵逆运算 矩阵的转置 实践乘法 多元梯度下降 特征缩放 学习率α 矩阵基础知识 矩阵逆运算 矩阵的转置 实践乘法 多元梯度下降 特征缩放 学习率α 学习率的选择:

  6. 吴恩达机器学习笔记:(一)机器学习方法简介

    吴恩达机器学习笔记 Supervised Learning(监督学习) Unsupervised Learning(无监督学习) clustering 聚类算法 market segments 市场细 ...

  7. 吴恩达机器学习笔记第一周

    第一周 吴恩达机器学习笔记第一周 一. 引言(Introduction) 1.1 欢迎 1.2 机器学习是什么? 1.3 监督学习 1.4 无监督学习 二.单变量线性回归(Linear Regress ...

  8. 吴恩达机器学习笔记week8——神经网络 Neutral network

    吴恩达机器学习笔记week8--神经网络 Neutral network 8-1.非线性假设 Non-linear hypotheses 8-2.神经元与大脑 Neurons and the brai ...

  9. 吴恩达机器学习笔记整理(Week6-Week11)

    1. Week 6 1.1 应用机器学习的建议(Advice for Applying Machine Learning) 1.1.1 决定下一步做什么 到目前为止,我们已经介绍了许多不同的学习算法, ...

最新文章

  1. hdu2037今年暑假不AC
  2. shell中条件判断if中的-z到-d的意思
  3. PHP从零开始--字段修饰符数据操作SQL语言
  4. 河北科技大学——数据结构课后习题
  5. kvm上添加万兆网卡_烂泥:为KVM虚拟机添加网卡
  6. jQuery知识点学习整理
  7. 简历编辑导出工具(类似wps简历助手)
  8. 元器件保护必备知识——静电防护
  9. 数据挖掘-基于随机森林模型的企业偷漏税纳税人识别
  10. Matlab绘制简单动画
  11. [pytorch] monai Vit 网络 图文分析
  12. Abstract Factory模式(抽象工厂模式)
  13. 【大数据分析】未开先火|北京环球影城网络传播热度洞察
  14. 关于 JavaScript 中 null 的一切
  15. 古诗+代码 = 绝配
  16. ava查询mysql的数据_【技术综述】AVA-第一个大规模的美学质量评估数据库
  17. [R]提高R语言速度
  18. EMQX 入门教程——导读
  19. 室内清扫机器人部分资料收集汇总
  20. 研究生论文查重原则是什么?

热门文章

  1. 信息学奥赛一本通(C++)在线评测系统——基础(一)C++语言——1078:求分数序列和
  2. 【STM32】FreeRTOS资源(持续更新)
  3. 【Linux网络编程】因特网的IP协议是不可靠无连接的,那为什么当初不直接把它设计为可靠的?
  4. mysql导入wordpress_WordPress搬家,导入mysql出错的解决方法 - 老牛博客
  5. python向mysql中添加数据_通过python操控MYSQL添加数据,并将数据添加到EXCEL中-阿里云开发者社区...
  6. 数据结构-----跳表
  7. iTerm2 保存日志
  8. 牛客IOI周赛19-普及组 B.小y的序列
  9. CF-1209 F. Koala and Notebook(建图BFS)
  10. Luogu-P4768 (Kruskal重构树+最短路)