吴恩达机器学习 神经网络 作业1(用已经求好的权重进行手写数字分类) Python实现 代码详细解释
整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python
里面可以下载进行代码实现的数据集
题目介绍:
In the previous part of this exercise, you implemented multi-class logistic regression to recognize handwritten digits. However, logistic regression cannot form more complex hypotheses as it is only a linear classifier.
In this part of the exercise, you will implement a neural network to recognize handwritten digits using the same training set as before. The neural network will be able to represent complex models that form non-linear hypotheses. For this week, you will be using parameters from a neural network that we have already trained. Your goal is to implement the feedforward propagation algorithm to use our weights for prediction. In next week’s ex- ercise, you will write the backpropagation algorithm for learning the neural network parameters.
翻译:
在之前的练习中,你实现了用多类逻辑回归去识别手写的数字。但是逻辑回归不能形成更复杂的假设。因为它仅仅是一个线性分类器。
在这节的练习里,你将会实现用一个神经网络来进行手写数字识别,数据集将会使用和上一节一样的数据。神经网络能够呈现更为复杂的非线性假设的模型。这一周,你们将会使用我们利用神经网络已经训练好的权重。你的目标就是实现前向传播算法进行预测。下周,你需要自己写一个反向传播算法来用神经网络进行学习。
题目解析:
就是把之前的利用逻辑回归算法里求出的theta,换成他们用神经网络求出的theta,然后进行一次预测。值得注意的是,已经训练好的theta有两组,一个是theta1,一个是theta2。
代码详细解释:
大部分代码基本上和上一节作业没有变化。
首先就是删掉了oneVsAll()函数,当然也顺便删掉了costFunction()和gradient()函数。我们不需要求theta,因为theta已经给出了。
1.然后直接看主函数。注意!!!!!!!!!!!!!!这里载入的是ex3data1.mat文件。也就是吴恩达作业中的数据,跟我上一节导入的数据不一样。上一节导入的数据是别人修改过的。区别是:ex3data1.mat文件中,0对应的y值是10!而别人修改过的0对应的值是0。
区别可大了。ex3weights.mat权重是根据0的y值为10训练的。如果你还采用上一节作业我用的数据,预测准确率大概减少百分之十。
def logisticRegression_oneVsAll():data=loadMatFile('ex3data1.mat') #1.导入文件。在.mat文件中存放的是两个矩阵,X是图片矩阵(5000x400),每一行都是一个数字图片的矩阵X=data['X']y=data['y']m=len(X)#我们先选100个数字看看。rand_indices=[np.random.randint(0,m) for x in range(100)] #2.显示100个数字:这一步是利用列表表达式选取100个随机的数字。show_data(X[rand_indices, :]) # 显示100个数字X=np.hstack((np.ones((len(y),1)),X)) #先将X中补上一列1。# X = np.insert(X, 0, values=np.ones(X.shape[0]), axis=1)data=loadWeight('ex3weights.mat') #载入已经求好的权重。theta1=data['Theta1']theta2=data['Theta2']theta1=np.transpose(theta1) #两个权重都需要进行转置才可以运算theta2=np.transpose(theta2)theta1=np.insert(theta1,0,1,axis=1) #这里需要注意下,theta1必须添加一个bias偏置predict(theta1,theta2,X,y) #4.调用predict函数。
2.直接看predict函数。这个函数也是改动比较大。体现了神经网络的分层。X是输入层的数据,X*theta1,并且进行激励函数sigmod的运算,求出的结果就是第二层的数据layer2,也就是a(2)层,然后我们利用a(2)层的数据求出第三层的数据z2,z2进行sigmoid运算求出hx,也就是结果。
#预测函数
def predict(theta1,theta2,X,y):z=np.dot(X,theta1) #此时是5000x26layer2=sigmoid(z) #将theta代入到假设函数中去z2=np.dot(layer2,theta2) #此时是5000x10hx=sigmoid(z2)m=X.shape[0]'''返回h中每一行最大值所在的列号- np.max(h, axis=1)返回h中每一行的最大值(是某个数字的最大概率)- 最后where找到的最大概率所在的列号(列号即是对应的数字)'''p=np.where(hx[0,:]==np.max(hx,axis=1)[0]) #我们需要知道每一行中的最大值,因为这个最大值对应的列数就是我们预测的数字。for i in range(1,m): #打个比方,第一行数据对应的实际值是0,那么如果预测准确,这一行中的最大值应该是在第0列。temp=np.where(hx[i,:]==np.max(hx,axis=1)[i])p=np.vstack((p,temp)) #我们将每一行中的最大值对应的列都添加到p数组中,此时p数组就是存放的每个数字的预测值。p=p+1 #这里需要将p加1,因为在这个数据集中,0对应的y值是10。所以会导致hx发生偏移。什么意思呢?我们是利用已经求好的theta1和theta2进行计算。求出的#结果hx实际上是按照[1,2,3,4,5,6,7,8,9,0]排列的。而他们所对应的索引是[0,1,2,3,4,5,6,7,8,9]。也就是说我们通过循环求出的索引实际上是偏小的。得+1print('在当前数据集上,训练的准确度为%f%%'%np.mean(np.float64(p==y)*100)) #我们将预测值p与实际值y,进行比较。如果相同,则为True,不同则为False,通过np.float计算,True为1,Flase为0.return p
实验结果:
这个结果应该是过拟合的。。。
在当前数据集上,训练的准确度为97.520000%
全部代码:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import scipy.io as spio
from scipy import optimizenp.set_printoptions(suppress=True, threshold=np.nan) #去除科学计数法,不然看起来太难受
matplotlib.rcParams['font.family']='Arial Unicode MS' #mac环境下防止中文乱码'''
1.导入文件loadMatFile(),注意,跟之前不同,这次导入的文件是.mat文件
2.利用show_data()函数,我们先显示100个数字看看。
3.直接调用predict函数,并且代入通过神经网络运算求好的theta1和theta2
'''def logisticRegression_oneVsAll():data=loadMatFile('ex3data1.mat') #1.导入文件。在.mat文件中存放的是两个矩阵,X是图片矩阵(5000x400),每一行都是一个数字图片的矩阵X=data['X']y=data['y']m=len(X)#我们先选100个数字看看。rand_indices=[np.random.randint(0,m) for x in range(100)] #2.显示100个数字:这一步是利用列表表达式选取100个随机的数字。show_data(X[rand_indices, :]) # 显示100个数字X=np.hstack((np.ones((len(y),1)),X)) #先将X中补上一列1。# X = np.insert(X, 0, values=np.ones(X.shape[0]), axis=1)data=loadWeight('ex3weights.mat') #载入已经求好的权重。theta1=data['Theta1']theta2=data['Theta2']theta1=np.transpose(theta1) #两个权重都需要进行转置才可以运算theta2=np.transpose(theta2)theta1=np.insert(theta1,0,1,axis=1) #这里需要注意下,theta1必须添加一个bias偏置#这段代码可以不用看,这是我参考别人的代码# a1=X #5000x401# z2=a1@theta1 #5000x26## a2=sigmoid(z2) #5000x26# z3=a2@theta2 #5000x10## a3=sigmoid(z3)## y_pred = np.argmax(a3, axis=1)+1## y=y.flatten() #这里一定要把y转变为1维数组。因为y_pred就是一维数组## accuracy = np.mean(y_pred == y)# print('accuracy = {0}%'.format(accuracy * 100))predict(theta1,theta2,X,y) #4.调用predict函数。#导入mat文件
def loadMatFile(path):return spio.loadmat(path) #这里我们需要借助scipy.io的loadmat方法来导入.mat文件#导入weight
def loadWeight(path):return spio.loadmat(path)# 显示随机的100个数字
'''
显示100个数(若是一个一个绘制将会非常慢,可以将要画的数字整理好,放到一个矩阵中,显示这个矩阵即可)- 初始化一个二维数组- 将每行的数据调整成图像的矩阵,放进二维数组- 显示即可
'''
def show_data(imgs):pad=1 #因为我们显示的是100张图片矩阵的集合,所以每张图片我们可以设置一个分割线,对图片进行划分。pad指的是分割线宽度。show_imgs=-np.ones((pad+10*(20+pad),pad+10*(20+pad))) #初始化一个211x211的矩阵,因为100张图片都是20x20,加上分割线,总共的大小就是211x211。#这里需要了解下,如果初始化的矩阵值为-1,则分割线的颜色会是黑色。row=0 #因为我们要显示100个数字,所以我们需要从图片数组的第0行遍历到第99行,这个row是用来控制遍历的行数for i in range(10): #双层循环,100张图片放进去。for j in range(10):show_imgs[pad+i*(20+pad):pad+i*(20+pad)+20,pad+j*(20+pad):pad+j*(20+pad)+20]=( #这段代码比较复杂。等号左边是从show_imgs这个大矩阵中给图片挑选位置。需要注意图片与图片之间都需要留位置给分割线imgs[row,:].reshape(20,20,order='F')) #因为imgs中的每个数字是一行400个像素数据,我们需要将其改造为20x20的矩阵。order=F,是指列优先对原数组进行reshape。因为python默认的是以行优先,但是matlab是列优先。如果不加这个的话,所有的数字都是横着显示的row+=1plt.imshow(show_imgs,cmap='gray') # 显示灰度图像,plt.imshow()函数负责对图像进行处理,并显示其格式,但是不能显示。其后跟着plt.show()才能显示出来。plt.axis('off') #把显示的轴去掉plt.show()#Sigmoid函数
def sigmoid(z):hx=np.ones((len(z),1)) #初始化一列数组,里面用于存放经过S函数变换后得值。hx=1.0/(1.0+np.exp(-z))return hx#预测函数
def predict(theta1,theta2,X,y):z=np.dot(X,theta1) #此时是5000x26layer2=sigmoid(z) #将theta代入到假设函数中去z2=np.dot(layer2,theta2) #此时是5000x10hx=sigmoid(z2)m=X.shape[0]'''返回h中每一行最大值所在的列号- np.max(h, axis=1)返回h中每一行的最大值(是某个数字的最大概率)- 最后where找到的最大概率所在的列号(列号即是对应的数字)'''p=np.where(hx[0,:]==np.max(hx,axis=1)[0]) #我们需要知道每一行中的最大值,因为这个最大值对应的列数就是我们预测的数字。for i in range(1,m): #打个比方,第一行数据对应的实际值是0,那么如果预测准确,这一行中的最大值应该是在第0列。temp=np.where(hx[i,:]==np.max(hx,axis=1)[i])p=np.vstack((p,temp)) #我们将每一行中的最大值对应的列都添加到p数组中,此时p数组就是存放的每个数字的预测值。p=p+1 #这里需要将p加1,因为在这个数据集中,0对应的y值是10。所以会导致hx发生偏移。什么意思呢?我们是利用已经求好的theta1和theta2进行计算。求出的#结果hx实际上是按照[1,2,3,4,5,6,7,8,9,0]排列的。而他们所对应的索引是[0,1,2,3,4,5,6,7,8,9]。也就是说我们通过循环求出的索引实际上是偏小的。得+1print('在当前数据集上,训练的准确度为%f%%'%np.mean(np.float64(p==y)*100)) #我们将预测值p与实际值y,进行比较。如果相同,则为True,不同则为False,通过np.float计算,True为1,Flase为0.return p#调用 logisticRegression_oneVsAll函数
logisticRegression_oneVsAll()
吴恩达机器学习 神经网络 作业1(用已经求好的权重进行手写数字分类) Python实现 代码详细解释相关推荐
- 吴恩达机器学习 逻辑回归 作业3(手写数字分类) Python实现 代码详细解释
整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python 里面可以下载进行代码实现的数据集 题目介绍: In t ...
- 吴恩达机器学习神经网络作业(python实现)
1. 多分类逻辑回归 自动识别手写数字 import numpy as np import pandas as pd import matplotlib.pyplot as plt from scip ...
- 4. 吴恩达机器学习课程-作业4-神经网络学习
fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 4. ...
- 1. 吴恩达机器学习课程-作业1-线性回归
fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 1. ...
- 3. 吴恩达机器学习课程-作业3-多分类和神经网络
fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 3. ...
- 吴恩达ex3_吴恩达机器学习 EX3 作业 第一部分多分类逻辑回归 手写数字
1 多分类逻辑回归 逻辑回归主要用于分类,也可用于one-vs-all分类.如本练习中的数字分类,输入一个训练样本,输出结果可能为0-9共10个数字中的一个数字.一对多分类训练过程使用"一对 ...
- 8. 吴恩达机器学习课程-作业8-异常检测和推荐系统
fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 8. ...
- 7. 吴恩达机器学习课程-作业7-Kmeans and PCA
fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 7. ...
- 6. 吴恩达机器学习课程-作业6-SVM
fork了别人的项目,自己重新填写,我的代码如下 https://gitee.com/fakerlove/machine-learning/tree/master/code 代码原链接 文章目录 6. ...
最新文章
- mysql linux centos 安装_Linux centos 下在线安装mysql
- 旋转角度_办公娱乐新神器!这款稳固的创意支架,360°旋转随便换角度
- 说好的敬畏每一行代码呢?Antd代码彩蛋炸翻一圈人
- java.text.SimpleDateFormat多线程下的问题
- 百度网盘的速度又又又又又又被黑了...侮辱性极强...
- Eigen 模板库的简介
- 通讯框架 t-io 学习——websocket 部分源码解析
- JSP入门 el表达式
- apache php mysql 开发_Wndows下Apache+php+Mysql环境的搭建及其涉及的知识(转)
- PAT (Basic Level) 1039 到底买不买(模拟)
- web窗体的内置对象
- window8下安装RabbitMQ
- javascript经典实例_一道前端经常忽视的JavaScript面试题
- linux c 文件拷贝函数,Linux C函数库参考手册
- EH使用IPMI基础操作
- SpringBoot工程中,如果不继承spring-boot-starter-parent ,还可以怎么做到的版本管理?
- 程序设计导引及在线实践——练习记录
- EJB - 环境设置
- 曽有望登顶互联网的它,留下“遗产”消失不见
- java网课|面向对象的思想
热门文章
- 编译器为C++ 空类自动生成的成员函数
- 华容道与数据结构 (续 3)
- eclipse——Error exists in required project Proceed with launch?
- C#基础知识---匿名方法使用
- ftp协议及vsftpd的基本应用
- Flex in a Week系列视频教程中文版发布
- 设置DBGridEH自适应列宽的最好方法
- 解决remix在线编译器连接本地私有链环境不成功的问题
- CSocket,CAsyncSocket多线程退出时的一些注意事项(解决关闭WinSoket崩溃的问题)
- SQL2008安装后激活方式以及提示评估期已过解决方法