版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

全连接神经网络是深度学习的基础,理解它就可以掌握深度学习的核心概念:前向传播、反向误差传递、权重、学习率等。这里先用python创建模型,用minist作为数据集进行训练。

定义3层神经网络:输入层节点28*28(对应minist图片像素数)、隐藏层节点300、输出层节点10(对应0-9个数字)。

网络的激活函数采用sigmoid,网络权重的初始化采用正态分布。

完整代码如下:

1 #-*- coding:utf-8 -*-

2

3 u"""全连接神经网络训练学习MINIST"""

4

5 __author__ = 'zhengbiqing 460356155@qq.com'

6

7

8 importnumpy9 importscipy.special10 importscipy.misc11 from PIL importImage12 importmatplotlib.pyplot13 importpylab14 importdatetime15 from random importshuffle16

17

18 #是否训练网络

19 LEARN =True20

21 #是否保存网络

22 SAVE_PARA =False23

24 #网络节点数

25 INPUT = 784

26 HIDDEN = 300

27 OUTPUT = 10

28

29 #学习率和训练次数

30 LR = 0.05

31 EPOCH = 10

32

33 #训练数据集文件

34 TRAIN_FILE = 'mnist_train.csv'

35 TEST_FILE = 'mnist_test.csv'

36

37 #网络保存文件名

38 WEIGHT_IH = "minist_fc_wih.npy"

39 WEIGHT_HO = "minist_fc_who.npy"

40

41

42 #神经网络定义

43 classNeuralNetwork:44 def __init__(self, inport_nodes, hidden_nodes, output_nodes, learnning_rate):45 #神经网络输入层、隐藏层、输出层节点数

46 self.inodes =inport_nodes47 self.hnodes =hidden_nodes48 self.onodes =output_nodes49

50 #神经网络训练学习率

51 self.learnning_rate =learnning_rate52

53 #用均值为0,标准方差为连接数的-0.5次方的正态分布初始化权重

54 #权重矩阵行列分别为hidden * input、 output * hidden,和ih、ho相反

55 self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))56 self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))57

58 #sigmoid函数为激活函数

59 self.active_fun = lambdax: scipy.special.expit(x)60

61 #设置神经网络权重,在加载已训练的权重时调用

62 defset_weight(self, wih, who):63 self.wih =wih64 self.who =who65

66 #前向传播,根据输入得到输出

67 defget_outputs(self, input_list):68 #把list转换为N * 1的矩阵,ndmin=2二维,T转制

69 inputs = numpy.array(input_list, ndmin=2).T70

71 #隐藏层输入 = W dot X,矩阵乘法

72 hidden_inputs =numpy.dot(self.wih, inputs)73 hidden_outputs =self.active_fun(hidden_inputs)74

75 final_inputs =numpy.dot(self.who, hidden_outputs)76 final_outputs =self.active_fun(final_inputs)77

78 returninputs, hidden_outputs, final_outputs79

80 #网络训练,误差计算,误差反向分配更新网络权重

81 deftrain(self, input_list, target_list):82 inputs, hidden_outputs, final_outputs =self.get_outputs(input_list)83

84 targets = numpy.array(target_list, ndmin=2).T85

86 #误差计算

87 output_errors = targets -final_outputs88 hidden_errors =numpy.dot(self.who.T, output_errors)89

90 #连接权重更新

91 self.who += numpy.dot(self.learnning_rate * output_errors * final_outputs * (1 -final_outputs), hidden_outputs.T)92 self.wih += numpy.dot(self.learnning_rate * hidden_errors * hidden_outputs * (1 -hidden_outputs), inputs.T)93

94

95 #图像像素值变换

96 defvals2input(vals):97 #[0,255]的图像像素值转换为i[0.01,1],以便sigmoid函数作非线性变换

98 return (numpy.asfarray(vals) / 255.0 * 0.99) + 0.01

99

100

101 '''

102 训练网络103 train:是否训练网络,如果不训练则直接加载已训练得到的网络权重104 epoch:训练次数105 save:是否保存训练结果,即网络权重106 '''

107 defnet_train(train, epochs, save):108 iftrain:109 with open(TRAIN_FILE, 'r') as train_file:110 train_list =train_file.readlines()111

112 for epoch inrange(epochs):113 #打乱训练数据

114 shuffle(train_list)115

116 for data intrain_list:117 all_vals = data.split(',')118 #图像数据为0~255,转换到0.01~1区间,以便激活函数更有效

119 inputs = vals2input(all_vals[1:])120

121 #标签,正确的为0.99,其他为0.01

122 targets = numpy.zeros(OUTPUT) + 0.01

123 targets[int(all_vals[0])] = 0.99

124

125 net.train(inputs, targets)126

127 #每个epoch结束后用测试集检查识别准确度

128 net_test(epoch)129 print('')130

131 ifsave:132 #保存连接权重

133 numpy.save(WEIGHT_IH, net.wih)134 numpy.save(WEIGHT_HO, net.who)135 else:136 #不训练直接加载已保存的权重

137 wih =numpy.load(WEIGHT_IH)138 who =numpy.load(WEIGHT_HO)139 net.set_weight(wih, who)140

141

142 '''

143 用测试集检查准确率144 '''

145 defnet_test(epoch):146 with open(TEST_FILE, 'r') as test_file:147 test_list =test_file.readlines()148

149 ok =0150 errlist = [0] * 10

151

152 for data intest_list:153 all_vals = data.split(',')154 inputs = vals2input(all_vals[1:])155 _, _, net_out =net.get_outputs(inputs)156

157 max =numpy.argmax(net_out)158 if max ==int(all_vals[0]):159 ok += 1

160 else:161 #识别错误统计,每个数字识别错误计数

162 #print('target:', all_vals[0], 'net_out:', max)

163 errlist[int(all_vals[0])] += 1

164

165 print('EPOCH: {epoch} score: {score}'.format(epoch=epoch, score = ok / len(test_list) * 100))166 print('error list:', errlist, 'total:', sum(errlist))167

168

169 #变换图片的尺寸,保存变换后的图片

170 defresize_img(filein, fileout, width, height, type):171 img =Image.open(filein)172 out =img.resize((width, height), Image.ANTIALIAS)173 out.save(fileout, type)174

175

176 #用训练得到的网络识别一个图片文件

177 defimg_test(img_file):178 file_name_list = img_file.split('.')179 file_name, file_type = file_name_list[0], file_name_list[1]180 out_file = file_name + 'out' + '.' +file_type181 resize_img(img_file, out_file, 28, 28, file_type)182

183 img_array = scipy.misc.imread(out_file, flatten=True)184 img_data = 255.0 - img_array.reshape(784)185 img_data = (img_data / 255.0 * 0.99) + 0.01

186

187 _, _, net_out =net.get_outputs(img_data)188 max =numpy.argmax(net_out)189 print('pic recognized as:', max)190

191

192 #显示数据集某个索引对应的图片

193 defimg_show(train, index):194 file = TRAIN_FILE if train elseTEST_FILE195 with open(file, 'r') as test_file:196 test_list =test_file.readlines()197

198 all_values = test_list[index].split(',')199 print('number is:', all_values[0])200

201 image_array = numpy.asfarray(all_values[1:]).reshape((28, 28))202 matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')203 pylab.show()204

205

206 start_time =datetime.datetime.now()207

208 net =NeuralNetwork(INPUT, HIDDEN, OUTPUT, LR)209 net_train(LEARN, EPOCH, SAVE_PARA)210

211 if notLEARN:212 net_test(0)213 else:214 print('MINIST FC Train:', INPUT, HIDDEN, OUTPUT, 'LR:', LR, 'EPOCH:', EPOCH)215 print('train spend time:', datetime.datetime.now() -start_time)216

217 #用画图软件创建图片文件,由得到的网络进行识别

218 #img_test('t9.png')

219

220 #显示minist中的某个图片

221 #img_show(True, 1)

784-300-10简单的全连接神经网络训练结果准确率基本在97.7%左右,运行结果如下:

EPOCH: 0 score: 95.96000000000001

error list:  [13, 21, 31, 28, 51, 61, 33, 66, 44, 56]  total:  404

EPOCH: 1 score: 96.77

error list:  [15, 19, 27, 63, 37, 37, 21, 40, 18, 46]  total:  323

EPOCH: 2 score: 97.25

error list:  [9, 17, 26, 26, 24, 56, 21, 41, 22, 33]  total:  275

EPOCH: 3 score: 97.82

error list:  [9, 16, 21, 18, 20, 18, 22, 21, 31, 42]  total:  218

EPOCH: 4 score: 97.54

error list:  [12, 23, 17, 25, 15, 34, 19, 25, 22, 54]  total:  246

EPOCH: 5 score: 97.78999999999999

error list:  [10, 16, 20, 23, 21, 32, 18, 31, 26, 24]  total:  221

EPOCH: 6 score: 97.6

error list:  [9, 13, 26, 34, 27, 26, 20, 28, 22, 35]  total:  240

EPOCH: 7 score: 97.74000000000001

error list:  [12, 8, 26, 29, 27, 26, 25, 20, 27, 26]  total:  226

EPOCH: 8 score: 97.77

error list:  [7, 10, 27, 16, 29, 28, 23, 29, 26, 28]  total:  223

EPOCH: 9 score: 97.99

error list:  [11, 10, 32, 17, 18, 24, 14, 22, 21, 32]  total:  201

MINIST FC Train: 784 300 10 LR: 0.05 EPOCH: 10

train spend time:  0:05:54.137925

Process finished with exit code 0

图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...相关推荐

  1. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  2. Tensorflow【实战Google深度学习框架】全连接神经网络以及可视化

    文章目录 1 可视化 神经网络的二元分类效果 2 全连接神经网络 3 TensorFlow搭建一个全连接神经网络 3.1 读取MNIST数据 3.2 建立占位符 3.3 建立模型 3.4 正确率 3. ...

  3. 深度学习系列:全连接神经网络和BP算法

    前言 注:以后我的文章会写在个人博客网站上,本站文章也已被搬运.本文地址: https://xiaodongfan.com/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E ...

  4. 深度学习笔记其五:卷积神经网络和PYTORCH

    深度学习笔记其五:卷积神经网络和PYTORCH 1. 从全连接层到卷积 1.1 不变性 1.2 多层感知机的限制 1.2.1 平移不变性 1.2.2 局部性 1.3 卷积 1.4 "沃尔多在 ...

  5. python——tensorflow使用和两层全连接神经网络搭建

    一.Tensorflow使用 1.Tensorflow简介 TensorFlow是一个软件库,封装了建构神经网络的API,类似于MatLab的神经网络工具箱,是Google于2015年推出的深度学习框 ...

  6. 【深度学习】(3) 全连接层、激活函数

    各位同学好,今天和大家分享一下tensorflow2.0深度学习中的相关操作.内容有: (1) 全连接层创建: tf.keras.Sequential(),tf.keras.layers.Dense( ...

  7. 深度学习笔记02——全连接层

    1. Fully connect 每一个activation function就是一个神经元.全连接层就是将每个神经元的输出都作为下一层所有神经元的输入. deep learning 就是有很多的hi ...

  8. 深度学习中多层全连接层的作用

    全连接层参数特多(可占整个网络参数80%左右) 那么全连接层对模型影响参数就是三个: 1,全接解层的总层数(长度) 2,单个全连接层的神经元数(宽度) 3,激活函数 首先我们要明白激活函数的作用是: ...

  9. 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

最新文章

  1. 什么是ObjCTypes?
  2. vue监听h5页面返回健(微信和支付宝浏览器亲测):
  3. 基于用户投票的排名算法(六):贝叶斯平均
  4. 逻辑运算符''取某值
  5. 京东把 Elasticsearch 用的真牛逼!
  6. 学习Nutch不错的系列文章
  7. 编译php时的configure,PHP编译configure时常见错误
  8. 使用Memory DC
  9. Luogu1007 独木桥
  10. cvCalcBackProject() 直方图反向投影匹配
  11. 计算机错误 引用无效名称,有关无效的引用的疑难解答
  12. Windows中I/O完成端口机制详解
  13. 微信小程序 人脸识别登陆模块
  14. MySQL如何删除一行数据
  15. 2021年网络空间安全学院预推免面试经验总结
  16. 创新业务中真需求和伪需求的思考
  17. leetcode:460. LFU最不常用缓存
  18. origin如何绘制双y轴曲线_origin怎么画双y轴 看完恍然大悟
  19. 如何搭建一个自己的FTP服务器
  20. 地方政府留言板文本数据

热门文章

  1. 基于卷积神经网络的人脸认证(判断两个人脸是否是一个人)
  2. Ubuntu 12.04 下编译Android 4.0.3
  3. Ubuntu10.10 配置ssh服务器及samba服务器
  4. 【S操作】老铁留步,干货来了!小总结云存储云办公云笔记工具——我的云工具选择,供您参考...
  5. nohup 命令(设置后台进程): appending output to ‘nohup.out’ 问题
  6. Java中的进程与线程
  7. linux-----shell高级编程----sed应用
  8. CMDB经验分享之 – 剖析CMDB的设计过程
  9. 操作系统的安装与启动基本原理
  10. 移动互联网服务客户端开发技巧 ( Webview及正则)