简单神经网络和卷积神经网络识别手写数字
目录
1 简单CNN
2 Improved CNN
3 卷积神经网络
4 参考博客和视频
5 相关函数阅读keras官方文档
1 简单CNN
实现的一个两层神经网络其隐藏层有15神元输出有10神经元
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt# 加载数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# print(train_labels.shape, test_images.shape)
# print(train_images[0])
# print(train_labels[0])
# plt.imshow(train_images[0])
# plt.show()train_images = train_images.reshape((60000, 28*28)).astype("float")
test_images = test_images.reshape((10000, 28*28)).astype("float")
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)# print(train_images[0])
# print(train_labels[0])network = models.Sequential()
network.add(layers.Dense(units=15, activation='relu', input_shape=(28*28,),))
network.add(layers.Dense(units=10, activation='softmax'))
# 多或二分类用softmax输出的是概率 sigmod# 编译步骤
network.compile(optimizer=RMSprop(lr=0.001), loss= 'categorical_crossentropy', metrics=["accuracy"])
network.fit(train_images, train_labels, epochs=20, batch_size=128, verbose=2)# print(network.summary())
# 测试
y_pre = network.predict(test_images[:5])
print(y_pre, test_labels[:5])
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, "test_accuracy:", test_accuracy)
发生过拟合了,在训练集上准确率高但在测试集上低
(每次运行结果可能不一样)
网络结构
print(network.summary())
线性分类器
Softmax分类器
先exp指数运算,好处将值变为正了,然后归一化
13%是猫 87%为车
交叉熵损失
较低时效果好
2 Improved CNN
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt# 加载数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# print(train_labels.shape, test_images.shape)
# print(train_images[0])
# print(train_labels[0])
# plt.imshow(train_images[0])
# plt.show()train_images = train_images.reshape((60000, 28*28)).astype("float")
test_images = test_images.reshape((10000, 28*28)).astype("float")
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)# print(train_images[0])
# print(train_labels[0])network = models.Sequential()
network.add(layers.Dense(units=128, activation='relu', input_shape=(28*28,),kernel_regularizer=regularizers.l1(0.0001)))
network.add(layers.Dropout(0.01))
network.add(layers.Dense(units=32, activation='relu',kernel_regularizer=regularizers.l1(0.0001)))
network.add(layers.Dropout(0.01))
network.add(layers.Dense(units=10, activation='softmax'))
# 多或二分类用softmax输出的是概率 sigmod# 编译步骤
network.compile(optimizer=RMSprop(lr=0.001), loss= 'categorical_crossentropy', metrics=["accuracy"])
network.fit(train_images, train_labels, epochs=20, batch_size=128, verbose=2)# print(network.summary())
# 测试
# y_pre = network.predict(test_images[:5])
# print(y_pre, test_labels[:5])
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, " test_accuracy:", test_accuracy)
同样发生过拟合了
网络结构
3 卷积神经网络
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import models, layers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import mnist
# 加载数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 搭建LeNet网络
def LeNet():network = models.Sequential()network.add(layers.Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))network.add(layers.AveragePooling2D((2, 2)))network.add(layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))network.add(layers.AveragePooling2D((2, 2)))network.add(layers.Conv2D(filters=120, kernel_size=(3, 3), activation='relu'))network.add(layers.Flatten())network.add(layers.Dense(84, activation='relu'))network.add(layers.Dense(10, activation='softmax'))return network
network = LeNet()
network.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])train_images = train_images.reshape((60000, 28, 28, 1)).astype('float') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)# 训练网络,用fit函数, epochs表示训练多少个回合, batch_size表示每次训练给多大的数据
network.fit(train_images, train_labels, epochs=10, batch_size=128, verbose=2)
test_loss, test_accuracy = network.evaluate(test_images, test_labels)
print("test_loss:", test_loss, " test_accuracy:", test_accuracy)
网络结构
4 参考博客和视频
手把手完成mnist手写数字识别视频
csdn博客
5 相关函数阅读keras官方文档
快速开始序贯(Sequential)模型
简单神经网络和卷积神经网络识别手写数字相关推荐
- 如何构建一个神经网络以使用TensorFlow识别手写数字
介绍 (Introduction) Neural networks are used as a method of deep learning, one of the many subfields o ...
- 508 任务一:卷积神经网络相关概念与pytorch识别手写数字
推荐文章与视频: https://blog.csdn.net/v_JULY_v/article/details/51812459? https://blog.csdn.net/weixin_37763 ...
- BP神经网络理解原理——用Python编程实现识别手写数字(翻译英文文献)
BP神经网络理解原理--用Python编程实现识别手写数字 备注,这里可以用这个方法在csdn中编辑公式: https://www.zybuluo.com/codeep/note/163962 一 ...
- 利用python实现简单的人工神经网络识别手写数字
利用 Python 搭建起了一个简单的神经网络模型,并完成识别手写数字. 1.前置工作 1.1 环境配置 这里使用scikit-learn库内建的手写数字字符集作为本文的数据集.scikit-lear ...
- 四、用简单神经网络识别手写数字(内含代码详解及订正)
本博客主要内容为图书<神经网络与深度学习>和National Taiwan University (NTU)林轩田老师的<Machine Learning>的学习笔记,因此在全 ...
- BP神经网络识别手写数字项目解析及代码
这两天在学习人工神经网络,用传统神经网络结构做了一个识别手写数字的小项目作为练手.点滴收获与思考,想跟大家分享一下,欢迎指教,共同进步. 平常说的BP神经网络指传统的人工神经网络,相比于卷积神经网络( ...
- 华裔女性钱璐璐:用 DNA 开发人工智能神经网络,识别手写数字!
"既然要学人脑的思维方式,为什么不去研究人脑?"霍金斯在<论智能>中说道. 如今,不少生物学研究者正朝着这个方向努力. 不过,请注意:这不是一次传统意义上的生物实验. ...
- Python神经网络识别手写数字-MNIST数据集
Python神经网络识别手写数字-MNIST数据集 一.手写数字集-MNIST 二.数据预处理 输入数据处理 输出数据处理 三.神经网络的结构选择 四.训练网络 测试网络 测试正确率的函数 五.完整的 ...
- 【神经网络与深度学习】第一章 使用神经网络来识别手写数字
人类的视觉系统,是大自然的奇迹之一. 来看看下面一串手写的数字: 大多数人可以毫不费力地认出这些数字是504192.这种轻松是欺骗性的,我们觉得很轻松的一瞬,其实背后过程非常复杂. 在我们大脑的每个半 ...
- 第1章使用神经网络识别手写数字
人类视觉系统是世界奇观之一.考虑以下手写数字序列: 大多数人毫不费力地将这些数字识别为504192.这很容易就是欺骗性的.在我们大脑的每个半球,人类有一个主要的视觉皮层,也被称为V1,包含1.4亿个神 ...
最新文章
- 动态规划(DP),压缩状态,插入字符构成回文字符串
- 将UTC日期时间转换为本地日期时间
- Ubuntu 安装docker CE以及harbor
- 一款免费好用的代码在线比较工具
- 第四次实验 恶意代码技术
- mysql环境搭载后老出错_使用Docker在window10下搭建SWOFT开发环境,mysql连接错误
- 【C++grammar】动态类型转换、typeid与RTTI
- java资源争夺_所有满足类似需求,争夺同类资源的组织和个人统称为( )。...
- 让全球数亿人拍摄到更美的照片,【北京三星研究院】招聘
- MS17-010 “永恒之蓝“ 修复方案
- 怎么设置计算机管理员权限,Windows7管理员权限怎么设置?
- 信息发展树标杆 智慧城市筑屏障
- 百度文库内容复制文字解决方法
- 微信开发工具(小程序)
- 【贪玩巴斯】带你一起攻克英语语法长难句—— 第六章——英语的特殊结构 ——2022年3月19日-20日
- 我的世界服务器怎么做无限的弓,我的世界怎么用命令方块做无限弓?
- 什么是back annotation
- 发那科机器人没有码垛指令_FANUC 机器人码垛编程详细讲解
- 华为的鸿蒙os,鸿蒙OS明天正式发布,十大特性,能拯救暴跌80%的华为手机吗?...
- Tor配置:514 Authentication required
热门文章
- 管理感悟:鼓励正确的山头主义
- 管理感悟:学会推论及验证
- python 删除文件_lt;python笔记gt;点击工具架,删除filechache的文件
- vscode c++ 开发环境搭建(离线、内网)
- Dxg——AD(Altium Designer) 开发笔记整理分类合集【所有的相关记录,都整理在此】
- java 程序出现标点错误,我是学java的新手,下面代码出现报错,请问是什么原因?如何解决?...
- 修改刷新没反应_【原神】全特产高效率采集线路和刷新时间
- jetty 找不到html页面,记一次jetty 404问题排查修复
- 线程编程 pthread 问题集合
- 汇编指令大全及标志位